Source code for aitemplate.compiler.ops.b2b_bmm.grouped_fmha_style_b2b_bmm

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#

"""
Grouped back-to-back batched gemm fused kernel, implemented in FMHA style.
Computes bmm(causal_masks(alpha1 * activation(alpha0 * bmm(Q, K) [+ bias])), V),

where:
Q: [B_M0, H, K0] (row_major),
K: [B_N0, H, K0] (column_major),
V: [B_N0, H, N1] (row_major),
bias: [B, H, M0, N0] (row_major). Bias can be omitted.
B_M0, B_N0 are jagged dims.
Layouts are fixed for now.

causal_masks have 3 types:
NO_CAUSAL: no causal masks
UPPER_RIGHT_EMPTY: the upper right triangular part of the matrix is 0
LOWER_LEFT_EMPTY: the bottom left triangular part of the matrix is 0
When causal_masks is enabled, M0 must be equal to N0.

Internally this implementation stores the results of Q@K in shared memory.
It supports larger N0 / N1 compared to the classic_b2b_bmm implementation.
"""

from aitemplate.compiler.base import IntImm
from aitemplate.compiler.ops.b2b_bmm.fmha_style_b2b_bmm import (
    CausalType,
    fmha_style_b2b_bmm,
)
from aitemplate.utils import shape_utils


class grouped_fmha_style_b2b_bmm(fmha_style_b2b_bmm):
    """See comments at the head of this file."""

    def __init__(
        self,
        causal_type: CausalType,
        epilogue_math_name: str,
        alpha0: float,
        alpha1: float,
        alpha1_divide_by_seq_len: bool = False,
    ) -> None:
        """Initialize grouped_fmha_style_b2b_bmm op.

        Check aitemplate.compiler.ops.b2b_bmm.b2b_bmm_base for more details
        about these args. All arguments are forwarded unchanged to the parent
        constructor; afterwards the "op" attribute is overridden with this
        op's own name.
        """
        super().__init__(
            causal_type, epilogue_math_name, alpha0, alpha1, alpha1_divide_by_seq_len
        )
        self._attrs["op"] = "grouped_fmha_style_b2b_bmm"

    def _infer_shapes(self):
        """Infer the output shape for grouped_fmha_style_b2b_bmm.

        Validates the first three inputs (Q, K, V) and, when a fourth input is
        present, the bias.

        Returns:
            A list ``[q_shape[0], q_shape[1], v_shape[2]]``, i.e. Q's jagged
            leading dim, the shared second dim (called ``num_heads`` here) and
            V's last dim.

        Raises:
            RuntimeError: if Q / K / V are not all jagged, do not all have
                rank 3, disagree on dim 0 or dim 1, if Q and K disagree on
                dim 2, if the jagged dim does not wrap exactly one jagged
                dimension, if the bias (when given) is not rank 4 / not
                broadcastable to the expected shape / has a mismatching last
                dim, or if the required jagged min / max bounds are not
                ``IntImm``.
        """
        q, k, v = self._attrs["inputs"][0:3]
        if not (q.is_jagged() and k.is_jagged() and v.is_jagged()):
            raise RuntimeError(f"{q=}, {k=}, {v=} must be jagged!")

        q_shape = q._attrs["shape"]
        k_shape = k._attrs["shape"]
        v_shape = v._attrs["shape"]
        if len(q_shape) != len(k_shape) or len(q_shape) != len(v_shape):
            raise RuntimeError(
                f"QKV ranks must be the same! QKV shapes: {q_shape=}, {k_shape=}, {v_shape=}."
            )
        if len(q_shape) != 3:
            raise RuntimeError(
                f"QKV must have rank == 3! Current rank: {len(q_shape)}, QKV shapes: {q_shape=}, {k_shape=}, {v_shape=}."
            )
        if q_shape[0] != k_shape[0] or q_shape[0] != v_shape[0]:
            raise RuntimeError(
                f"QKV must have same jagged_dim (batch_size and seq_length)! QKV shapes: {q_shape=}, {k_shape=}, {v_shape=}."
            )
        if len(q_shape[0].jagged_dims()) != 1:
            raise RuntimeError(f"{len(q_shape[0].jagged_dims())=} must be 1!")
        if q_shape[1] != k_shape[1] or q_shape[1] != v_shape[1]:
            raise RuntimeError(
                f"QKV must have same head size! QKV shapes: {q_shape=}, {k_shape=}, {v_shape=}."
            )
        K0 = q_shape[2]
        if K0 != k_shape[2]:
            raise RuntimeError(
                f"Q K shapes are not compatible! QKV shapes: {q_shape=}, {k_shape=}, {v_shape=}."
            )

        num_heads = q_shape[1]
        output_shape = [q_shape[0], num_heads, v_shape[2]]

        # Exactly one jagged dimension is guaranteed by the check above.
        seq_dim = q_shape[0].jagged_dims()[0]

        if len(self._attrs["inputs"]) == 4:
            batch_size = q_shape[0].batch_dim()
            max_seq_length = seq_dim.max_value()
            bias = self._attrs["inputs"][3]
            bias_shape = bias._attrs["shape"]
            bias_expected_shape = [
                batch_size,
                num_heads,
                max_seq_length,
                max_seq_length,
            ]
            # Validate the rank before any broadcast comparison, and report
            # the rank of the *shape* (the bias tensor object itself is not
            # what the message is meant to measure).
            if len(bias_shape) != 4:
                raise RuntimeError(
                    f"Expected bias rank 4. Current bias rank: {len(bias_shape)}."
                )
            broadcastable, _ = shape_utils.get_broadcast_max_shape(
                bias_shape, bias_expected_shape
            )
            if not broadcastable:
                raise RuntimeError(
                    f"bias shape is not compatible with Q K! "
                    f"QKV shapes: {q_shape=}, {k_shape=}, {v_shape=}, "
                    f"bias shapes: {bias_shape=}, {bias_expected_shape=}."
                )
            if bias_shape[-1] != bias_expected_shape[-1]:
                raise RuntimeError(
                    f"Bias last dim is not broadcastable! Expected shape: {bias_expected_shape[-1]}, current bias shape: {bias_shape}"
                )
            # See comments below.
            if not isinstance(seq_dim.min_value(), IntImm):
                raise RuntimeError(
                    "Jagged dim' min value must be constant!"
                    f"Current value: {q_shape[0].jagged_dims()=}"
                )
        else:
            # Note: jagged_dims min / max values cannot be IntVar, as AIT lacks the feature to set
            # "attributes" dynamically at runtime in general.
            #
            # Assuming the case: Q @ K @ V, Q / K / V are all dense tensor inputs.
            # As a result, Q / K / V have total_length IntVar to represent the first dimension.
            # Then there are make_jagged() ops which take Q / K / V as well as
            # min_seq_len / max_seq_len IntVars as inputs.
            # At runtime, Q / K / V are inputs passed to AIT runtime. However, since
            # min_seq_len / max_seq_len is not bound to any input dimensions,
            # there are no ways for AIT to infer these values. As a result, AIT compilation would
            # fail.
            #
            # To support min_seq_len / max_seq_len IntVars, there must be a way dynamically set
            # them at runtime.
            #
            # When bias is set, max_seq_len can be inferred from bias input.
            if (not isinstance(seq_dim.min_value(), IntImm)) or (
                not isinstance(seq_dim.max_value(), IntImm)
            ):
                raise RuntimeError(
                    "Jagged dim' min / max values must be constant!"
                    f"Current value: {q_shape[0].jagged_dims()=}"
                )

        return output_shape