Source code for aitemplate.compiler.ops.gemm_special.bmm_rcr_n1

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
"""
GEMM Specialization for A[RowMajor], B[ColMajor], C[RowMajor]
This is special in template based gemm solution
This is used for `torch.nn.functional.linear`
When use for `linear`, need set A->Data, B->Weight

Special kernel for GEMV case:
A: [B, M, K]
B: [B, N, K]
C: [B, M, N]
where N = 1

This kernel computes C = alpha * A @ B
"""

from aitemplate.compiler.base import IntImm, Tensor
from aitemplate.compiler.ops.gemm_universal import bmm_rcr
from aitemplate.compiler.tensor_accessor import TensorAccessor

# pylint: disable=C0103, W0223, W0221, W0613


[docs]class bmm_rcr_n1(bmm_rcr): def __init__(self): super().__init__() self._attrs["op"] = "bmm_rcr_n1" self._attrs["f_ab_alignment"] = True self._attrs["has_profiler"] = False
[docs] @staticmethod def is_valid_shape(a: Tensor, b: Tensor): """ Check input a/b is valid for bmm_rcr_n1. Requirements: 1) matching dimension of a/b (where a is row major, b is column major) 2) dim N of b needs to be 1 3) dim K of b needs to be multiple of 8 """ if len(a.shape()) != 3 or len(b.shape()) != 3: return False valid = True valid &= a.shape()[0] == b.shape()[0] valid &= a.shape()[2] == b.shape()[2] # check N = 1 BN = b.shape()[1] if not isinstance(BN, IntImm): return False valid &= BN.value() == 1 # check BK is static dim BK = b.shape()[2] if not isinstance(BK, IntImm): return False return valid
def _infer_shapes(self, a: Tensor, b: Tensor): assert self.is_valid_shape( a, b ), "shape (tensor a:{}, tensor b:{}) not valid for bmm_rcr_n1".format( a.shape(), b.shape() ) return super()._infer_shapes(a, b) def __call__(self, a: Tensor, b: Tensor, alpha: float = 1.0) -> Tensor: self._attrs["inputs"] = [a, b] self._attrs["alpha"] = alpha self._set_depth() output_shape = self._infer_shapes(a, b) output = Tensor(output_shape, src_ops={self}, dtype=a.dtype()) self._attrs["outputs"] = [output] self._attrs["input_accessors"] = [TensorAccessor(a), TensorAccessor(b)] self._attrs["output_accessors"] = [TensorAccessor(output)] return output
[docs] def gen_profiler( self, workdir: str = None, dynamic_profiling_strategy=None ) -> None: """This kernel doesn't require profiling.""" return