Source code for aitemplate.frontend.nn.attention

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
"""
Frontend for attention module
"""
from aitemplate.compiler import ops
from aitemplate.compiler.ops import flash_attention
from aitemplate.compiler.ops.common.epilogue import FuncEnum
from aitemplate.frontend import Tensor
from aitemplate.frontend.nn.dropout import Dropout
from aitemplate.frontend.nn.linear import Linear
from aitemplate.frontend.nn.module import Module
from aitemplate.frontend.nn.parameter import Parameter
from aitemplate.testing import detect_target


class FlashAttention(Module):
    r"""FlashAttention provides an implementation for fused multi-head attention module:

    .. math::
        \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right) V

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O

    where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.

    Args:
        batch_size: batch size. Also sizes the ``cu_length`` parameter, which
            holds ``batch_size + 1`` int32 entries.
        max_seq_len: maximum sequence length, forwarded to ``flash_attention``.
        dropout: dropout probability forwarded to ``flash_attention``.
            Default: ``0``.
        causal: forwarded to ``flash_attention``. Default: ``False``.
        dtype: accepted for interface compatibility; not used by this module.
            Default: ``"float16"``.
    """

    def __init__(
        self,
        batch_size,
        max_seq_len,
        dropout=0,
        causal=False,
        dtype="float16",
    ):
        """Initialize attention module, create a tensor for seqlen"""
        super().__init__()
        # int32 tensor of length batch_size + 1, passed to the op on every call.
        # NOTE(review): presumably the cumulative per-batch sequence lengths
        # (as MultiheadAttention's cu_length) -- confirm with the op's contract.
        self.cu_length = Parameter(shape=[batch_size + 1], dtype="int32")
        self.op = flash_attention(
            batch_size=batch_size,
            dropout=dropout,
            max_seq_len=max_seq_len,
            causal=causal,
        )

    def forward(self, *args):
        """forward pass for calling attention op

        Expects exactly one positional argument (the packed input tensor) and
        returns ``flash_attention(x, cu_length)``.
        """
        assert len(args) == 1
        x = args[0]
        return self.op(x, self.cu_length.tensor())


class MultiheadAttention(Module):
    r"""Multi-Head Attention.

    Allows the model to jointly attend to information
    from different representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    Multi-Head Attention is defined as:

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O

    where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.

    Args:
        dim: total dimension of the model
        batch_size: batch size
        seq_len: sequence length
        num_heads: Number of parallel attention heads. Default: 8
        qkv_bias: whether to add bias to QKV. Default: False
        attn_drop: Dropout probability on attention output weights. Default: ``0.0`` (no dropout).
        proj_drop: Dropout probability on projection layers. Default: ``0.0`` (no dropout).
        has_residual: has or has no residual. Default: `True`.
        causal: default: `False`.
        mask_seq: sequence mask, default: ``0``.
        use_mem_eff: use ``ops.mem_eff_attention`` instead of ``flash_attention``.
            Default: ``False``.
        dtype: data type of the module's parameters and of the attention scale
            constant. Default: ``"float16"``.
    """

    # Cached once per process: whether the detected target is CUDA (else ROCm path).
    USE_CUDA = None

    def __init__(
        self,
        dim,
        batch_size,
        seq_len,
        num_heads=8,
        qkv_bias=False,
        attn_drop=0.0,
        proj_drop=0.0,
        has_residual=True,
        causal=False,
        mask_seq=0,
        use_mem_eff=False,
        dtype="float16",
    ):
        super().__init__()
        assert (
            dim % num_heads == 0
        ), f"dim {dim} should be divisible by num_heads {num_heads}"

        if MultiheadAttention.USE_CUDA is None:
            MultiheadAttention.USE_CUDA = detect_target().name() == "cuda"
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5
        self.causal = causal
        self.has_residual = has_residual
        self.mask_seq = mask_seq
        self.use_mem_eff = use_mem_eff
        # Kept so constants built in ``attention`` match the model dtype
        # instead of being hard-coded to float16.
        self.dtype = dtype

        flash_head_dims = {8, 16, 32, 64, 128}
        # simple heuristic, may need refinement
        self.use_flash = (
            not (seq_len >= 512 and batch_size <= 2)
        ) and head_dim in flash_head_dims
        # odd seq try use flash
        # NOTE(review): this override also bypasses the head_dim membership
        # check above, so an odd seq_len with an unsupported head_dim still
        # selects flash -- confirm this is intended.
        if seq_len % 2 == 1:
            self.use_flash = True

        if use_mem_eff:
            self.op = ops.mem_eff_attention(
                causal=causal,
            )
            self.use_flash = False
        else:
            self.op = flash_attention(
                batch_size=batch_size,
                dropout=attn_drop,
                max_seq_len=seq_len,
                causal=causal,
            )
            # cu_length: the cumulative sequence lengths, used to index into hidden_states.
            self.cu_length = Parameter(shape=[batch_size + 1], dtype="int32")

        if self.mask_seq:
            # Appended (concatenated on dim 0) to the attention output in forward().
            self.output_mask = Parameter(
                shape=[mask_seq, num_heads, head_dim], dtype=dtype
            )

        if self.USE_CUDA:
            # on CUDA flash_attention needs packed QKV as input,
            # then do split + permute inside flash_attn
            # input: (B, S, H)
            # output: (B*S, 3, num_heads, head_dim)
            if self.use_flash:
                self.qkv = Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype)
            else:
                self.qkv = Linear(
                    dim,
                    dim * 3,
                    specialization="permute",
                    shape=(seq_len, 3, self.num_heads),
                    dtype=dtype,
                )
        else:
            # on ROCM ck attention (bmm_softmax_bmm) takes three inputs (Q, K, V)
            # here we generate packed QKV for splitting
            # input: (B, seqlen, dim) -> (B*seqlen, dim)
            # gemm: (B*seqlen, 3*dim)
            # reshape to: (B, seqlen, 3, num_heads, head_dim)
            # output: (3, B, num_heads, seqlen, head_dim)
            self.qkv = Linear(
                dim,
                dim * 3,
                specialization="permute",
                shape=(seq_len, 3, self.num_heads),
                layout="m2n3",
                dtype=dtype,
            )

        # NOTE(review): attn_drop is constructed but never applied in forward();
        # the probability only reaches the op via flash_attention(dropout=...).
        self.attn_drop = Dropout(attn_drop, dtype=dtype)
        self.proj = Linear(
            dim, dim, specialization="add" if has_residual else None, dtype=dtype
        )
        self.proj_drop = Dropout(proj_drop, dtype=dtype)

    def get_shape(self, x):
        """Return the static shape of ``x`` as a list of ints (``value()`` of each dim)."""
        shape = [it.value() for it in x._attrs["shape"]]
        return shape

    def qkv_proj(self, x):
        """Project ``x`` to packed Q/K/V in the layout the selected attention path expects."""
        if self.USE_CUDA:
            if self.use_flash:
                batch, seq, hidden = self.get_shape(x)
                out = self.qkv(x)
                # (B*S, 3, num_heads, head_dim), the packed layout flash_attention takes.
                return ops.reshape()(
                    out, [int(batch * seq), 3, self.num_heads, hidden // self.num_heads]
                )
            else:
                # Unpacking still enforces a rank-3 input; only hidden is needed.
                _, _, hidden = self.get_shape(x)
                x = ops.reshape()(x, [-1, hidden])
                return self.qkv(x)
        else:
            return self.qkv(x)

    def attention(self, x):
        """Run the selected attention kernel on packed QKV ``x``."""
        # fused attention
        # output: (B, Seqlen, num_heads, head_dim)
        if self.USE_CUDA and self.use_flash:
            # input(x): (B*seqlen, 3, num_heads, head_dim)
            # output: (B, Seqlen, num_heads, head_dim)
            return self.op(x, self.cu_length.tensor())
        elif self.USE_CUDA and self.use_mem_eff:
            (q, k, v) = ops.split()(x, 1, dim=0)
            _, b, _, seqlen, d = self.get_shape(q)
            return self.op(
                ops.reshape()(q, [b, -1, seqlen, d]),
                ops.reshape()(k, [b, -1, seqlen, d]),
                ops.reshape()(v, [b, -1, seqlen, d]),
            )
        else:
            # input(q/k/v): (B*num_heads, seqlen, head_dim)
            # attn = (B, S, H) * (B, S, H) = (B, S, S) #RCR
            # softmax on dim -1 (B, S, S)
            # attn@v: (B, S, S) * (B, S, H) = (B, S, H) #RRR
            # reshape: (B, num_head, seqlen, head_dim)
            # permute: (B, Seqlen, num_heads, head_dim)
            if self.USE_CUDA:
                # NOTE(review): self.causal is not applied on this path
                # (no mask before softmax) -- confirm callers never need it here.
                scale = Tensor(
                    shape=[], dtype=self.dtype, name="scale", value=self.scale
                )
                # [3, b, num_heads, seqlen, d]
                _, b, num_heads, seqlen, d = self.get_shape(x)
                # [3 * b * num_heads, seqlen, d]
                x = ops.reshape()(x, [-1, seqlen, d])
                (q, k, v) = ops.split()(x, b * num_heads, dim=0)
                qk = ops.bmm_rcr()(q, k)
                score = ops.elementwise(FuncEnum.MUL)(qk, scale)
                score = ops.softmax()(score, -1)
                out = ops.bmm_rrr_permute((num_heads,))(score, v)
            else:
                (q, k, v) = ops.split()(x, 1, dim=0)
                _, _, _, seqlen, d = self.get_shape(q)
                OP = ops.bmm_softmax_bmm_permute(
                    shape=(self.num_heads,),
                    scale=self.scale,
                    causal=self.causal,
                )
                out = OP(
                    ops.reshape()(q, [-1, seqlen, d]),
                    ops.reshape()(k, [-1, seqlen, d]),
                    ops.reshape()(v, [-1, seqlen, d]),
                )
            return out

    def forward(self, *args):
        """forward pass for calling mha module

        ``args[0]`` is the input whose shape unpacks to ``(batch, seq, hidden)``.
        When ``has_residual`` is set, ``args[1]`` is the residual added by the
        output projection. Returns a tensor reshaped to ``(batch, seq, hidden)``.
        """
        assert len(args) >= 1
        x = args[0]
        batch, seq, hidden = self.get_shape(x)
        qkv = self.qkv_proj(x)
        if self.mask_seq:
            # Drop the trailing mask_seq rows along dim 0 before attention.
            # NOTE(review): this assumes dim 0 is B*S, i.e. the flash layout
            # produced by qkv_proj; in the permuted layouts dim 0 has size 3.
            total = self.get_shape(qkv)[0]
            qkv = ops.dynamic_slice()(
                qkv,
                start_indices=[0, 0, 0, 0],
                end_indices=[total - self.mask_seq, None, None, None],
            )
        attn_output = self.attention(qkv)
        if self.mask_seq:
            attn_output = ops.concatenate()(
                [attn_output, self.output_mask.tensor()], dim=0
            )
        attn_output = ops.reshape()(attn_output, [batch * seq, -1])

        if self.has_residual:
            assert len(args) == 2
            x = self.proj(attn_output, args[1])
        else:
            x = self.proj(attn_output)
        x = self.proj_drop(x)
        x = ops.reshape()(x, [batch, seq, hidden])
        return x


class CrossAttention(Module):
    r"""Cross Multi-head Attention.

    Allows the model to jointly attend to information
    from different representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    Multi-Head Attention is defined as:

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O

    where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.

    Queries come from one input and keys/values from (possibly) other inputs;
    all three are projected separately and fed to ``ops.mem_eff_attention``.

    Args:
        dim: total dimension of the model
        seq_len: query sequence length (accepted; not used by this module)
        seq_len_kv: key/value sequence length (accepted; not used by this module)
        num_heads: Number of parallel attention heads.
        qkv_bias: whether to add bias to the Q, K and V projections. Default: False
        attn_drop: Dropout probability on attention output weights. Default: ``0.0`` (no dropout).
        proj_drop: Dropout probability on projection layers. Default: ``0.0`` (no dropout).
        has_residual: has or has no residual. Default: `True`.
        causal: forwarded to ``ops.mem_eff_attention``. Default: `False`.
        dtype: data type of the projection and dropout layers. Default: ``"float16"``.
    """

    def __init__(
        self,
        dim,
        seq_len,
        seq_len_kv,
        num_heads,
        qkv_bias=False,
        attn_drop=0.0,
        proj_drop=0.0,
        has_residual=True,
        causal=False,
        dtype="float16",
    ):
        super().__init__()
        assert (
            dim % num_heads == 0
        ), f"dim {dim} should be divisible by num_heads {num_heads}"
        self.num_heads = num_heads
        self.causal = causal
        self.has_residual = has_residual
        self.dim = dim

        self.op = ops.mem_eff_attention(causal=causal)
        self.proj_q = Linear(
            dim,
            dim,
            bias=qkv_bias,
            dtype=dtype,
        )
        self.proj_k = Linear(
            dim,
            dim,
            bias=qkv_bias,
            dtype=dtype,
        )
        self.proj_v = Linear(
            dim,
            dim,
            bias=qkv_bias,
            dtype=dtype,
        )
        # NOTE(review): attn_drop is constructed but never applied in forward().
        self.attn_drop = Dropout(attn_drop, dtype=dtype)
        self.proj = Linear(
            dim, dim, specialization="add" if has_residual else None, dtype=dtype
        )
        self.proj_drop = Dropout(proj_drop, dtype=dtype)

    def _split_heads(self, t, batch, head_dim):
        """Reshape ``t`` to (batch, -1, num_heads, head_dim) and permute axes 1 and 2."""
        return ops.permute()(
            ops.reshape()(t, [batch, -1, self.num_heads, head_dim]), [0, 2, 1, 3]
        )

    def attention(self, q, k, v):
        """Project q/k/v, split them into heads, and run mem-efficient attention."""
        batch = q.shape()[0]
        head_dim = self.dim // self.num_heads

        # Project all three inputs first, then split heads, in the original op order.
        query = self.proj_q(q)
        key = self.proj_k(k)
        value = self.proj_v(v)

        query = self._split_heads(query, batch, head_dim)
        key = self._split_heads(key, batch, head_dim)
        value = self._split_heads(value, batch, head_dim)
        return self.op(query, key, value)

    def forward(self, *args):
        """forward pass for calling mha module

        ``args[0:3]`` are the query, key and value inputs. When ``has_residual``
        is set, ``args[3]`` is the residual added by the output projection.
        Returns a tensor reshaped to ``(batch, -1, dim)``.
        """
        assert len(args) >= 3
        x = args[0]
        batch = x.shape()[0]
        attn_output = self.attention(args[0], args[1], args[2])
        attn_output = ops.reshape()(attn_output, [batch, -1, self.dim])

        if self.has_residual:
            assert len(args) == 4
            x = self.proj(attn_output, args[3])
        else:
            x = self.proj(attn_output)
        x = self.proj_drop(x)
        x = ops.reshape()(x, [batch, -1, self.dim])
        return x


class ScaledDotProductAttention(Module):
    """Parameter-free module that applies non-causal ``ops.mem_eff_attention``.

    A fresh ``mem_eff_attention(causal=False)`` op is built on every call to
    :meth:`forward`; the module itself holds no state.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, q, k, v):
        """Return ``mem_eff_attention(causal=False)`` applied to ``(q, k, v)``."""
        return ops.mem_eff_attention(causal=False)(q, k, v)