Source code for aitemplate.compiler.ops.tensor.slice_scatter

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
"""
Slice_scatter.
"""

from aitemplate import backend
from aitemplate.backend import registry
from aitemplate.compiler.base import Operator
from aitemplate.compiler.stable_set import StableSet
from aitemplate.compiler.tensor_accessor import TensorAccessor

# pylint: disable=C0103,W0221


[docs]class slice_scatter(Operator): """This op represents a special fusion case where the inputs of a concatenate op all come from slice ops. In such a case, we can remove the concatenate op by placing each slice's output into the correct location in the original concatenate's output. """ @staticmethod def is_valid(cat_op: Operator) -> bool: if cat_op._attrs["op"] != "concatenate": return False if any( input_accessor.stride_dim is not None for input_accessor in cat_op._attrs["input_accessors"] ): return False return all( x._attrs["src_ops"] is not None and len(x._attrs["src_ops"]) == 1 and len(x._attrs["dst_ops"]) == 1 and list(x._attrs["src_ops"])[0]._attrs["op"] == "dynamic_slice" for x in cat_op._attrs["inputs"] ) def _update_inputs_outputs(self, cat_op): self._attrs["inputs"] = [] for slice_op in self._attrs["slice_ops"]: assert ( len(slice_op._attrs["inputs"]) == 1 ), "Slice op should only have 1 input! op: {}".format(slice_op) input_tensor = slice_op._attrs["inputs"][0] # A slice op's output may be fed into the same cat op multiple # times, so we make sure it's removed from the set only once. if slice_op in input_tensor._attrs["dst_ops"]: input_tensor._attrs["dst_ops"].remove(slice_op) input_tensor._attrs["dst_ops"].add(self) self._attrs["inputs"].append(input_tensor) # The original output of this slice_scatter op is the output of the cat_op. # We set the TensorAccessor, but will only use its offset field in the backend. self._attrs["output_accessors"] = [TensorAccessor(cat_op._attrs["outputs"][0])] self._attrs["outputs"] = cat_op._attrs["outputs"] for y in self._attrs["outputs"]: y._attrs["src_ops"] = StableSet({self}) for op in self._attrs["slice_ops"]: op._attrs["outputs"][0]._attrs["src_ops"] = StableSet() op._attrs["outputs"][0]._attrs["dst_ops"] = StableSet() for x in cat_op._attrs["inputs"]: x._attrs["src_ops"] = StableSet() x._attrs["dst_ops"] = StableSet() def __init__(self, scatter_dim: int) -> None: super().__init__() self._attrs["op"] = "slice_scatter" self._attrs["has_profiler"] = False self._attrs["scatter_dim"] = scatter_dim @staticmethod def make_op(cat_op: Operator) -> Operator: assert slice_scatter.is_valid(cat_op) scatter_dim = cat_op._attrs["concat_dim"] new_op = slice_scatter(scatter_dim) slice_ops = [] for x in cat_op._attrs["inputs"]: src_ops = x.src_ops() assert len(src_ops) == 1 slice_op = list(src_ops)[0] slice_ops.append(slice_op) new_op._attrs["slice_ops"] = slice_ops new_op._update_inputs_outputs(cat_op) new_op._set_depth() return new_op def __call__(self): raise RuntimeError("op {} cannot be called directly".format(self._attrs["op"])) def _get_op_attributes(self): raise NotImplementedError("slice_scatter get op attribute not implemented") def _get_func(self, fmt_str): target = backend.target.Target.current() func_key = fmt_str.format(target=target.name(), op=self._attrs["op"]) return registry.get(func_key)
[docs] def gen_function(self) -> str: func = self._get_func("{target}.{op}.gen_function") return func(self._attrs)
def _args_for_pseudo_code(self): return [f"scatter_dim={str(self._attrs['scatter_dim'])}]"]