Source code for aitemplate.compiler.ops.tensor.slice_reshape_scatter

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
"""
Slice_reshape_scatter.
"""

from typing import Optional

from aitemplate import backend
from aitemplate.backend import registry
from aitemplate.compiler.base import IntImm, IntVar, Operator
from aitemplate.compiler.stable_set import StableSet

from aitemplate.compiler.tensor_accessor import TensorAccessor

# pylint: disable=C0103,C0415,W0221


[docs]class slice_reshape_scatter(Operator): """represent slice + concat + reshape + concat pattern with slice + concat """ @staticmethod def is_valid(cat_op: Operator, reshape_op: Operator, cat_op_2: Operator) -> bool: assert cat_op._attrs["op"] == "concatenate" assert reshape_op._attrs["op"] == "reshape" assert cat_op_2._attrs["op"].startswith("concatenate") # only handle cases where two cat ops have the same concat_dim cat_dim = cat_op._attrs["concat_dim"] if cat_dim != cat_op_2._attrs["concat_dim"]: return False cat_output_shape = cat_op._attrs["outputs"][0]._attrs["shape"] cat_output_rank = len(cat_output_shape) if cat_output_rank <= 1: return False cat_output_shape_2 = cat_op_2._attrs["outputs"][0]._attrs["shape"] cat_output_rank_2 = len(cat_output_shape_2) # only handle cases where we are concatenating the last dim if cat_dim != cat_output_rank - 1: return False if cat_output_rank >= cat_output_rank_2: return False if not all( d1._attrs["values"][0] == d2._attrs["values"][0] for (d1, d2) in zip( cat_output_shape[:cat_dim], cat_output_shape_2[:cat_dim] ) ): return False reshape_to_shape = reshape_op._attrs["outputs"][0]._attrs["shape"] # skip dynamic shape if not all(isinstance(d, (IntImm, IntVar)) for d in reshape_to_shape): return False if not all( d1._attrs["values"][0] == d2._attrs["values"][0] for (d1, d2) in zip(cat_output_shape[:cat_dim], reshape_to_shape[:cat_dim]) ): return False return all( x._attrs["src_ops"] is not None and len(x._attrs["src_ops"]) == 1 and list(x._attrs["src_ops"])[0]._attrs["op"] == "dynamic_slice" for x in cat_op._attrs["inputs"] ) def _update_inputs_outputs(self, cat_op, reshape_op, cat_op_2): from aitemplate.compiler.transform import transform_utils idx = -1 for i, input_tensor in enumerate(cat_op_2._attrs["inputs"]): if input_tensor == reshape_op._attrs["outputs"][0]: idx = i break assert idx >= 0 # The original output of this slice_reshape_scatter op is the output # of the reshape op. self._attrs["output_accessors"] = [ TensorAccessor(reshape_op._attrs["outputs"][0]) ] cat_op_2.remove_input_at(idx) transform_utils.remove_single_tensor_op_from_sorted_graph(reshape_op) self._attrs["inputs"] = [ op._attrs["inputs"][0] for op in self._attrs["slice_ops"] ] cat_op_2_outputs = cat_op_2._attrs["outputs"] assert len(cat_op_2_outputs) == 1, ( f'{cat_op_2._attrs["name"]=} may only have one output, but got more ' f"{cat_op_2_outputs=}" ) self._attrs["outputs"] = cat_op_2_outputs # setup output TensorAccessor offset = 0 cat_dim = cat_op_2._attrs["concat_dim"] orig_idx = -1 for i, input_tensor in enumerate(cat_op_2._attrs["original_inputs"]): if input_tensor == reshape_op._attrs["outputs"][0]: orig_idx = i break input_tensor_shape = input_tensor._attrs["shape"] offset += input_tensor_shape[cat_dim].value() assert orig_idx >= 0, ( f'could not find {input_tensor._attrs["name"]=} in the original_inputs' "of cat_op_2" ) self._attrs["output_accessors"][0].update_base_tensor( cat_op_2_outputs[0], cat_dim, offset ) for x in self._attrs["inputs"]: x._attrs["dst_ops"].add(self) for y in self._attrs["outputs"]: y._attrs["src_ops"].add(self) for op in self._attrs["slice_ops"]: op._attrs["outputs"][0]._attrs["src_ops"] = StableSet() op._attrs["outputs"][0]._attrs["dst_ops"] = StableSet() for x in cat_op._attrs["inputs"]: x._attrs["src_ops"] = StableSet() x._attrs["dst_ops"] = StableSet() for y in cat_op._attrs["outputs"]: y._attrs["src_ops"] = StableSet() y._attrs["dst_ops"] = StableSet() def __init__(self, scatter_dim: int, element_func: Optional[str] = None) -> None: super().__init__() self._attrs["element_func"] = element_func self._attrs["op"] = "slice_reshape_scatter" self._attrs["has_profiler"] = False self._attrs["scatter_dim"] = scatter_dim @staticmethod def make_op(cat_op: Operator, reshape_op: Operator, cat_op_2: Operator) -> Operator: assert slice_reshape_scatter.is_valid(cat_op, reshape_op, cat_op_2) element_func = None if cat_op_2._attrs["op"] == "concatenate_tanh": element_func = "fast_tanh" scatter_dim = cat_op._attrs["concat_dim"] new_op = slice_reshape_scatter(scatter_dim, element_func) slice_ops = [] for x in cat_op._attrs["inputs"]: src_ops = x.src_ops() assert len(src_ops) == 1 slice_op = list(src_ops)[0] slice_ops.append(slice_op) new_op._attrs["slice_ops"] = slice_ops new_op._update_inputs_outputs(cat_op, reshape_op, cat_op_2) new_op._set_depth() return new_op def __call__(self): raise RuntimeError("op {} cannot be called directly".format(self._attrs["op"])) def _get_func(self, fmt_str): target = backend.target.Target.current() func_key = fmt_str.format(target=target.name(), op=self._attrs["op"]) return registry.get(func_key)
[docs] def gen_function(self) -> str: func = self._get_func("{target}.{op}.gen_function") return func(self._attrs, self._attrs["element_func"])