Source code for aitemplate.compiler.ops.gemm_special.batched_dense_vec_jagged_2d_mul

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#

"""
Define batched_dense_vec_jagged_2d_mul op
"""

from typing import List

from aitemplate.backend import registry

from aitemplate.backend.target import Target

from aitemplate.compiler.base import IntVar, Operator, Tensor


class batched_dense_vec_jagged_2d_mul(Operator):
    """
    Compute a dense tensor containing batched matrix multiplication
    of a batched dense vector and a batched jagged matrix.

    Args:
        vectors (Tensor): batched dense vector of shape [B, H, N].
        matrices (Tensor): batched jagged matrix of shape [sum_B(N_B), H, D].

    Returns:
        output (Tensor): dense tensor containing the batched vector /
            jagged matrix multiplication result of shape [B, H, D].
    """

    def __init__(
        self,
    ):
        super().__init__()
        self._attrs["op"] = "batched_dense_vec_jagged_2d_mul"

    def _infer_shape(self, vectors: Tensor, matrices: Tensor) -> List[IntVar]:
        """
        Build the output shape from the inputs.

        The result is ``[B, H, D]``: ``B`` is the batch dim of the jagged
        first dimension of ``matrices``; ``H`` and ``D`` are the second and
        third dimensions of ``matrices``. ``vectors`` is not consulted.
        """
        matrices_shape = matrices.shape()
        jagged_int_var = matrices_shape[0]
        return [jagged_int_var.batch_dim(), matrices_shape[1], matrices_shape[2]]

    def __call__(self, vectors: Tensor, matrices: Tensor) -> Tensor:
        """
        Validate the inputs, register them on the op and create the output.

        Args:
            vectors (Tensor): dense rank-3 tensor.
            matrices (Tensor): jagged rank-3 tensor with a single JaggedDim.

        Returns:
            Tensor: the output tensor, with the shape given by
            ``_infer_shape`` and the dtype of ``vectors``.

        Raises:
            TypeError: if ``matrices`` is not jagged or ``vectors`` is jagged.
            ValueError: if either input is not rank-3.
            RuntimeError: if the batch dims, the second dims or the dtypes
                differ, if ``matrices`` has more than one JaggedDim, or if
                the JaggedDim's max_value differs from the last dim of
                ``vectors``.
        """
        if not matrices.is_jagged():
            raise TypeError(
                f"matrices must be a jagged Tensor, but got a dense Tensor {matrices}."
            )
        if vectors.is_jagged():
            # Fixed message: the requirement is a *dense* vectors tensor.
            raise TypeError(
                f"vectors must be a dense Tensor, but got a jagged Tensor {vectors}."
            )
        if len(vectors.shape()) != 3:
            raise ValueError(f"vectors must be rank-3, but got {vectors}.")
        if len(matrices.shape()) != 3:
            raise ValueError(f"matrices must be rank-3, but got {matrices}.")

        jagged_int_var = matrices.shape()[0]
        if jagged_int_var.batch_dim() != vectors.shape()[0]:
            raise RuntimeError(
                "The batch dim B of the jagged matrices tensor and "
                "dense vectors tensor must be the same, but got "
                f"{jagged_int_var.batch_dim()=} != {vectors.shape()[0]=}."
            )
        if vectors.shape()[1] != matrices.shape()[1]:
            # Both sides now use the `=` specifier so the message is symmetric.
            raise RuntimeError(
                "The second dim H of the jagged matrices tensor and "
                "dense vectors tensor must be the same, but got "
                f"{matrices.shape()[1]=} != {vectors.shape()[1]=}."
            )
        if vectors.dtype() != matrices.dtype():
            raise RuntimeError(
                "vectors and matrices must have the same type, but got "
                f"{vectors.dtype()=} != {matrices.dtype()=}."
            )
        if len(jagged_int_var.jagged_dims()) != 1:
            raise RuntimeError(
                "Jagged matrices tensor must have a "
                f"single JaggedDim, but got {matrices}."
            )

        max_value = jagged_int_var.jagged_dims()[0].max_value()
        if max_value != vectors.shape()[2]:
            # NOTE(review): the message calls `.value()` on the last dim of
            # vectors; confirm that method exists for every dim type that can
            # reach this branch.
            raise RuntimeError(
                "Upper bound (max_value) of the jagged dim in matrices "
                "must be equal to the last dim N in vectors, but got "
                f"{max_value=} != {vectors.shape()[2].value()=}."
            )

        self._attrs["inputs"] = [vectors, matrices]
        self._set_depth()
        output_shape = self._infer_shape(vectors, matrices)
        output = Tensor(output_shape, src_ops={self}, dtype=vectors.dtype())
        self._attrs["outputs"] = [output]
        return output

    def gen_function(self) -> str:
        """
        Look up the gen_function registered for the current target and this
        op, and return the result of calling it with the op's attributes.
        """
        target = Target.current()
        func = registry.get(f"{target.name()}.{self._attrs['op']}.gen_function")
        return func(self._attrs)