# Source code for aitemplate.compiler.ops.tensor.batch_gather

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
"""
Batch_gather.
"""

import itertools
from collections import OrderedDict
from typing import List

import jinja2

from aitemplate import backend
from aitemplate.backend import registry
from aitemplate.compiler.base import IntVar, Operator, Tensor

# pylint: disable=C0103,W0221,W0102,W0223

EXEC_KEY_TEMPLATE = jinja2.Template(
    """
M == {{x_dim0}} && K == {{x_dim1}}
"""
)


class batch_gather(Operator):
    """
    Gathers values of the `input` tensor specified by `indices`. Dim 0 of
    `indices` corresponds to the indices of `input` elements in dim 0.

    Args:
        input (Tensor): the source tensor
        indices (Tensor): the indices of elements to gather

    Returns:
        Tensor: the destination tensor
    """

    def __init__(self) -> None:
        super().__init__()
        self._attrs["op"] = "batch_gather"
        # No runtime profiling is needed for this op; codegen is shape-driven.
        self._attrs["has_profiler"] = False
        self.exec_key_template = EXEC_KEY_TEMPLATE

    def _infer_shapes(self, x: Tensor, indices: Tensor) -> List[IntVar]:
        """Infers the output shape for batch_gather.

        The output keeps `x`'s shape, except that the gather dimension
        (dim `rank(indices) - 1`) takes the size of the last dimension
        of `indices`.

        Raises:
            AssertionError: if any middle dimension of `x` and `indices`
                (dims 1 .. rank-2) disagrees on its first static value.
        """
        rank = len(indices._attrs["shape"])
        # TODO: remove this when we're sure we support non-static batch_gather
        # Only the first value of each (possibly dynamic) dim is compared here.
        x_shape_values = [var._attrs["values"][0] for var in x._attrs["shape"]]
        indices_shape = [var._attrs["values"][0] for var in indices._attrs["shape"]]
        for r in range(1, rank - 1):
            assert x_shape_values[r] == indices_shape[r]
        # Copy so we never mutate x's own shape list in place.
        out_shapes = x._attrs["shape"][:]
        if rank <= 1:
            # Special case: gather happens along batch dimension
            out_shapes[0] = indices.shape()[0]
        out_shapes[rank - 1] = indices._attrs["shape"][-1]
        return out_shapes

    def __call__(self, x: Tensor, indices: Tensor) -> Tensor:
        """Builds the batch_gather node and returns its output tensor.

        Args:
            x (Tensor): the source tensor
            indices (Tensor): integer tensor of indices to gather; its
                dtype must be one of int/int32/int64.

        Returns:
            Tensor: output tensor with the gathered values; same dtype as `x`.
        """
        dtype = indices._attrs["dtype"]
        # Tuple membership instead of list; same semantics, lighter literal.
        assert dtype in (
            "int",
            "int32",
            "int64",
        ), f"batch_gather(): Expected dtype int/int32/int64 for index, got dtype {dtype}"
        self._attrs["inputs"] = [x, indices]
        self._set_depth()
        self._extract_exec_path(x)
        output_shape = self._infer_shapes(x, indices)
        output = Tensor(output_shape, src_ops={self}, dtype=x.dtype())
        self._attrs["outputs"] = [output]
        return output

    def _gen_exec_key(self, shape) -> str:
        """Renders the exec-path key ("M == ... && K == ...") for one
        concrete (dim0, dim1) shape, with newlines stripped."""
        return self.exec_key_template.render(
            x_dim0=shape[0],
            x_dim1=shape[1],
        ).replace("\n", "")

    def _extract_exec_path(self, x: Tensor) -> None:
        """Populates `exec_path` with one key per concrete shape of `x`
        (the cartesian product of every dim's candidate values)."""
        x_shape_values = [var._attrs["values"] for var in x._attrs["shape"]]
        x_shapes = itertools.product(*x_shape_values)
        self._attrs["exec_path"] = OrderedDict()
        for x_shape in x_shapes:
            key = self._gen_exec_key(x_shape)
            # Values are filled in later by codegen; keys only matter here.
            self._attrs["exec_path"][key] = ""
[docs] def gen_function(self) -> str: target = backend.target.Target.current() func_key = "{target}.{op}.gen_function".format( target=target.name(), op=self._attrs["op"] ) func = registry.get(func_key) return func(self._attrs)