Source code for aitemplate.compiler.ops.tensor.split

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
"""
Split.
"""

from typing import List, Sequence, Union

from aitemplate import backend
from aitemplate.backend import registry
from aitemplate.compiler.base import IntImm, IntVar, Operator, Tensor
from aitemplate.utils import shape_utils
from aitemplate.utils.tensor_utils import wrap_dim

# pylint: disable=C0103,W0221


class split(Operator):
    """Splits the tensor into chunks on the specified dimension.

    Args:
        x (Tensor): tensor to split.
        split_sizes (List[int]): list of sizes for each chunk
        dim (int): dimension along which to split the tensor

    Returns:
        List[Tensor]: the list of output tensors

    Example:

        .. highlight:: python
        .. code-block:: python

            >>> X = Tensor(shape=[2, 1], name="X", is_input=True)
            >>> Y = ops.split()(X, 1, dim=0)
            [Tensor(shape=[IntImm(1), IntImm(1)]), Tensor(shape=[IntImm(1), IntImm(1)])]
    """

    def __init__(self) -> None:
        super().__init__()
        self._attrs["op"] = "split"
        self._attrs["has_profiler"] = False

    def _infer_shapes(
        self, x: Tensor, split_sizes: List[int], dim: int
    ) -> List[List[IntVar]]:
        """Infers the output shapes for split."""
        x_shape = x._attrs["shape"]
        rank = len(x_shape)
        if rank <= 0:
            raise RuntimeError("expected a non-scalar tensor")
        if dim >= rank:
            raise RuntimeError(
                f"split dim ({dim}) expected to be less than rank ({rank})"
            )
        num_splits = len(split_sizes)
        if num_splits < 1:
            raise RuntimeError(
                f"the number of splits expected to be >= 1 but got {num_splits}"
            )
        split_dim_size = x_shape[dim]._attrs["values"][0]
        if sum(split_sizes) != split_dim_size:
            raise RuntimeError(
                f"sum of split_sizes ({split_sizes}) does not match "
                f"split_dim_size ({split_dim_size})"
            )
        output_shapes = []
        for split_size in split_sizes:
            output_shape = x_shape[:dim] + [IntImm(split_size)] + x_shape[dim + 1 :]
            output_shapes.append(output_shape)
        return output_shapes

    def __call__(self, x: Tensor, split_size_or_sections, dim=0) -> List[Tensor]:
        x_shape = x._attrs["shape"]
        self._attrs["inputs"] = [x]
        dim = wrap_dim(dim, x._rank())
        self._attrs["split_dim"] = dim
        self._set_depth()
        if isinstance(split_size_or_sections, (list, tuple)):
            split_sizes = [
                shape_utils.convert_IntVar_to_int(d) for d in split_size_or_sections
            ]
        else:
            split_size = shape_utils.convert_IntVar_to_int(split_size_or_sections)
            if not isinstance(split_size, int):
                raise RuntimeError("split_size expected to be an int")
            # TODO: support split along a dynamic axis
            if not isinstance(x_shape[dim], IntImm):
                raise NotImplementedError("split dynamic axis")
            split_dim_size = x_shape[dim].value()
            if split_dim_size == 0:
                # a special case - it's valid in pytorch
                num_splits = 1
                split_sizes = [0]
            else:
                if split_size == 0:
                    raise RuntimeError("split_size expected to be > 0")
                # ceil division; the last chunk receives the remainder
                num_splits = (split_dim_size + split_size - 1) // split_size
                split_sizes = [split_size] * num_splits
                split_sizes[num_splits - 1] = split_size - (
                    split_size * num_splits - split_dim_size
                )
        self._attrs["split_sizes"] = split_sizes
        output_shapes = self._infer_shapes(x, split_sizes, dim)
        outputs = [
            Tensor(output_shape, src_ops={self}, dtype=x._attrs["dtype"])
            for output_shape in output_shapes
        ]
        self._attrs["outputs"] = outputs
        self._attrs["original_outputs"] = list(outputs)
        # True means the corresponding output tensor will be materialized by the backend.
        self._attrs["output_masks"] = [True] * len(outputs)
        # torch returns a tuple, so we do too
        return tuple(outputs)

    def _get_func(self, fmt_str):
        target = backend.target.Target.current()
        func_key = fmt_str.format(target=target.name(), op=self._attrs["op"])
        return registry.get(func_key)

    def gen_function(self) -> str:
        func = self._get_func("{target}.{op}.gen_function")
        return func(self._attrs)

    def remove_output_at(self, indices: Union[int, Sequence[int]]) -> None:
        """
        This function removes the outputs in indices from the "outputs"
        attribute and sets output_masks[indices] to False. Note that the
        indices are based on the current "outputs".

        Parameters
        ----------
        indices : Union[int, Sequence[int]]
            the index of an output or indices of multiple outputs based on
            the current "outputs"

        Returns
        -------
        None
        """
        if isinstance(indices, int):
            indices = [indices]
        else:
            # the loop below assumes indices are in ascending order
            indices = sorted(indices)
        curr_outputs = self._attrs["outputs"]
        num_curr_outputs = len(curr_outputs)
        assert (
            len(indices) <= num_curr_outputs
        ), f"Expected len(indices) <= num_curr_outputs, but got {len(indices)} and {num_curr_outputs}"
        num_original_outputs = len(self._attrs["original_outputs"])
        num_output_masks = len(self._attrs["output_masks"])
        assert num_original_outputs == num_output_masks, (
            f"original_outputs and output_masks must have the same length, "
            f"but got {num_original_outputs} and {num_output_masks}"
        )
        curr_idx = 0  # index into curr_outputs
        idx = 0  # index into indices
        new_outputs = []
        # skip entries whose output_masks have already been set to False
        for orig_idx in range(num_original_outputs):
            if not self._attrs["output_masks"][orig_idx]:
                continue
            if idx < len(indices) and curr_idx == indices[idx]:
                self._attrs["output_masks"][orig_idx] = False
                idx += 1
            else:
                new_outputs.append(curr_outputs[curr_idx])
            curr_idx += 1
        num_new_outputs = len(new_outputs)
        assert num_new_outputs + len(indices) == num_curr_outputs, (
            f"Expected num_new_outputs + len(indices) == num_curr_outputs, "
            f"but got {num_new_outputs + len(indices)} and {num_curr_outputs}"
        )
        self._attrs["outputs"] = new_outputs

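    # Illustrative note (the scenario below is an assumption for exposition,
    # not taken from the library's tests): this method is presumably called by
    # graph transformations that prune unused outputs. For example, if the op
    # currently has outputs [Y0, Y1, Y2] and Y1 is dead, op.remove_output_at(1)
    # leaves "outputs" == [Y0, Y2] and "output_masks" == [True, False, True],
    # so the backend skips materializing the removed chunk while the original
    # split layout stays recoverable from "original_outputs".
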
    def _inputs_for_pseudo_code(self):
        return self._attrs["inputs"]

    def _args_for_pseudo_code(self):
        return [
            f"split_sizes={self._attrs['split_sizes']}",
            f"dim={self._attrs['split_dim']}",
        ]
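

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; not part of the library source). It mirrors
# the docstring example above and assumes AITemplate's public frontend imports
# (aitemplate.frontend.Tensor, aitemplate.compiler.ops); the helper name is
# hypothetical.
def _split_usage_sketch():
    from aitemplate.compiler import ops
    from aitemplate.frontend import Tensor

    X = Tensor(shape=[10, 4], name="X", is_input=True)

    # Integer split size: ceil(10 / 3) == 4 chunks along dim 0, and the last
    # chunk receives the remainder, so the dim-0 sizes are [3, 3, 3, 1].
    ys = ops.split()(X, 3, dim=0)

    # Explicit sections: the sizes must sum to the dim-0 size (2 + 5 + 3 == 10),
    # otherwise _infer_shapes raises a RuntimeError.
    zs = ops.split()(X, [2, 5, 3], dim=0)
    return ys, zs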