#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
"""
batch_gather operator: gathers elements of an input tensor using an index tensor.
"""
import itertools
from collections import OrderedDict
from typing import List
import jinja2
from aitemplate import backend
from aitemplate.backend import registry
from aitemplate.compiler.base import IntVar, Operator, Tensor
# pylint: disable=C0103,W0221,W0102,W0223
# Template for exec-path keys: a condition on the first two dims of x,
# rendered as "M == <dim0> && K == <dim1>". The surrounding newlines are
# removed by the caller (batch_gather._gen_exec_key strips "\n").
EXEC_KEY_TEMPLATE = jinja2.Template("\nM == {{x_dim0}} && K == {{x_dim1}}\n")
class batch_gather(Operator):
    """
    Gathers values of the `input` tensor specified by `indices`. Dim 0 of
    `indices` corresponds to the indices of `input` elements in dim 0.

    With ``rank = len(indices.shape)``, the output has the shape of `input`
    except that dim ``rank - 1`` is replaced by the last dim of `indices`
    (for ``rank <= 1`` this means dim 0 takes the size of `indices`).

    Args:
        input (Tensor): the source tensor
        indices (Tensor): the indices of elements to gather

    Returns:
        Tensor: the destination tensor
    """

    def __init__(self) -> None:
        super().__init__()
        self._attrs["op"] = "batch_gather"
        self._attrs["has_profiler"] = False
        self.exec_key_template = EXEC_KEY_TEMPLATE

    def _infer_shapes(self, x: Tensor, indices: Tensor) -> List[IntVar]:
        """Infers the output shape of batch_gather.

        Dims 1 .. rank-2 of ``indices`` must equal the same dims of ``x``
        (dim 0 is not compared). The output copies ``x``'s shape list and
        replaces dim ``rank - 1`` with the last dim of ``indices``.

        Raises:
            AssertionError: if a compared dim of ``x`` and ``indices`` differs.
        """
        rank = len(indices._attrs["shape"])
        # TODO: remove this when we're sure we support non-static batch_gather
        # Only the first candidate value of each (possibly dynamic) dim is
        # compared below.
        x_shape_values = [var._attrs["values"][0] for var in x._attrs["shape"]]
        indices_shape = [var._attrs["values"][0] for var in indices._attrs["shape"]]
        # NOTE(review): dim 0 is intentionally skipped here -- confirm whether
        # a dynamic batch dim is the reason.
        for r in range(1, rank - 1):
            assert x_shape_values[r] == indices_shape[r], (
                f"batch_gather(): dim {r} of x ({x_shape_values[r]}) does not "
                f"match dim {r} of indices ({indices_shape[r]})"
            )
        # Shallow copy so that x's own shape list is not mutated.
        out_shapes = x._attrs["shape"][:]
        if rank <= 1:
            # Special case: gather happens along batch dimension
            out_shapes[0] = indices.shape()[0]
        out_shapes[rank - 1] = indices._attrs["shape"][-1]
        return out_shapes

    def __call__(self, x: Tensor, indices: Tensor) -> Tensor:
        """Wires this op into the graph and returns its output tensor.

        Args:
            x (Tensor): the source tensor.
            indices (Tensor): index tensor; dtype must be int, int32 or int64.

        Returns:
            Tensor: new tensor with ``x``'s dtype and the shape computed by
            ``_infer_shapes``; this op is recorded as its source op.

        Raises:
            AssertionError: if ``indices`` has an unsupported dtype, or if
            shapes do not match (see ``_infer_shapes``).
        """
        dtype = indices._attrs["dtype"]
        assert (
            dtype
            in [
                "int",
                "int32",
                "int64",
            ]
        ), f"batch_gather(): Expected dtype int/int32/int64 for index, got dtype {dtype}"
        self._attrs["inputs"] = [x, indices]
        self._set_depth()
        self._extract_exec_path(x)
        output_shape = self._infer_shapes(x, indices)
        output = Tensor(output_shape, src_ops={self}, dtype=x.dtype())
        self._attrs["outputs"] = [output]
        return output

    def _gen_exec_key(self, shape):
        """Renders the exec-path key for one concrete shape of ``x``.

        Only ``shape[0]`` and ``shape[1]`` are used (as M and K), so ``shape``
        needs at least two entries. Newlines from the template are removed
        so the key is a single line.
        """
        return self.exec_key_template.render(
            x_dim0=shape[0],
            x_dim1=shape[1],
        ).replace("\n", "")

    def _extract_exec_path(self, x: Tensor):
        """Fills ``self._attrs["exec_path"]`` with one key per shape combination.

        The exec key depends only on the first two dims of ``x``, so the
        Cartesian product is taken over the candidate values of those two
        dims alone. This produces the same keys, in the same insertion order,
        as a product over every dim, without enumerating the trailing-dim
        combinations that would only overwrite existing entries.
        """
        leading_dim_values = [var._attrs["values"] for var in x._attrs["shape"][:2]]
        self._attrs["exec_path"] = OrderedDict()
        for x_shape in itertools.product(*leading_dim_values):
            key = self._gen_exec_key(x_shape)
            self._attrs["exec_path"][key] = ""

    def gen_function(self) -> str:
        """Generates the backend function source for this op.

        Looks up ``"<target name>.<op>.gen_function"`` in the backend registry
        for the current target and calls it with this op's attributes.
        """
        target = backend.target.Target.current()
        func_key = "{target}.{op}.gen_function".format(
            target=target.name(), op=self._attrs["op"]
        )
        func = registry.get(func_key)
        return func(self._attrs)