# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Frontend for attention module
"""
from aitemplate.compiler import ops
from aitemplate.compiler.ops import flash_attention
from aitemplate.compiler.ops.common.epilogue import FuncEnum
from aitemplate.frontend import Tensor
from aitemplate.frontend.nn.dropout import Dropout
from aitemplate.frontend.nn.linear import Linear
from aitemplate.frontend.nn.module import Module
from aitemplate.frontend.nn.parameter import Parameter
from aitemplate.testing import detect_target
class FlashAttention(Module):
r"""FlashAttention provides an implementation for fused
multi-head attention module:
.. math::
        \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V
.. math::
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
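
    The example below is a minimal sketch of how this module might be wired
    into an AITemplate graph; the tensor name, shapes, and sizes are
    illustrative assumptions, not part of this module::

        # packed QKV of shape (batch_size * seq_len, 3, num_heads, head_dim)
        qkv = Tensor(
            shape=[2 * 64, 3, 8, 64], dtype="float16", name="qkv", is_input=True
        )
        attn = FlashAttention(batch_size=2, max_seq_len=64)
        # the cu_length parameter (int32, length batch_size + 1) holds the
        # cumulative sequence lengths and must be bound before compilation
        y = attn(qkv)  # (batch_size, seq_len, num_heads, head_dim)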
"""
def __init__(
self,
batch_size,
max_seq_len,
dropout=0,
causal=False,
dtype="float16",
):
"""Initialize attention module, create a tensor for seqlen"""
super().__init__()
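        # cu_length: the cumulative sequence lengths, used to index into the packed QKV input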
self.cu_length = Parameter(shape=[batch_size + 1], dtype="int32")
self.op = flash_attention(
batch_size=batch_size,
dropout=dropout,
max_seq_len=max_seq_len,
causal=causal,
)
    def forward(self, *args):
"""forward pass for calling attention op"""
assert len(args) == 1
x = args[0]
return self.op(x, self.cu_length.tensor())
class MultiheadAttention(Module):
r"""Multi-Head Attention.
Allows the model to jointly attend to information
from different representation subspaces as described in the paper:
`Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
Multi-Head Attention is defined as:
.. math::
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
Args:
dim: total dimension of the model
batch_size: batch size
seq_len: sequence length
num_heads: Number of parallel attention heads. Default: 8
qkv_bias: whether to add bias to QKV. Default: False
attn_drop: Dropout probability on attention output weights. Default: ``0.0`` (no dropout).
proj_drop: Dropout probability on projection layers. Default: ``0.0`` (no dropout).
        has_residual: whether an extra residual input is added in the output projection. Default: ``True``.
        causal: whether to apply a causal attention mask. Default: ``False``.
        mask_seq: number of trailing positions excluded from attention (``0`` disables masking). Default: ``0``.
        use_mem_eff: whether to use the memory-efficient attention op instead of flash attention. Default: ``False``.
        dtype: data type of the weights. Default: ``"float16"``.
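
    The example below is a minimal sketch of typical usage; the tensor name,
    shapes, and sizes are illustrative assumptions. With ``has_residual=True``
    (the default), ``forward`` expects a second, residual input that is added
    in the output projection::

        x = Tensor(shape=[2, 64, 512], dtype="float16", name="x", is_input=True)
        mha = MultiheadAttention(
            dim=512, batch_size=2, seq_len=64, num_heads=8, has_residual=False
        )
        y = mha(x)  # (batch, seq_len, hidden)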
"""
USE_CUDA = None
def __init__(
self,
dim,
batch_size,
seq_len,
num_heads=8,
qkv_bias=False,
attn_drop=0.0,
proj_drop=0.0,
has_residual=True,
causal=False,
mask_seq=0,
use_mem_eff=False,
dtype="float16",
):
super().__init__()
assert (
dim % num_heads == 0
), f"dim {dim} should be divisible by num_heads {num_heads}"
if MultiheadAttention.USE_CUDA is None:
MultiheadAttention.USE_CUDA = detect_target().name() == "cuda"
self.num_heads = num_heads
head_dim = dim // num_heads
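        # softmax scale factor 1 / sqrt(head_dim)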
self.scale = head_dim**-0.5
self.causal = causal
self.has_residual = has_residual
self.mask_seq = mask_seq
self.use_mem_eff = use_mem_eff
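        # head_dim values for which the flash attention path can be used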
flash_head_dims = {8, 16, 32, 64, 128}
# simple heuristic, may need refinement
self.use_flash = (
not (seq_len >= 512 and batch_size <= 2)
) and head_dim in flash_head_dims
        # for odd sequence lengths, always use the flash attention path
if seq_len % 2 == 1:
self.use_flash = True
if use_mem_eff:
self.op = ops.mem_eff_attention(
causal=causal,
)
self.use_flash = False
else:
self.op = flash_attention(
batch_size=batch_size,
dropout=attn_drop,
max_seq_len=seq_len,
causal=causal,
)
# cu_length: the cumulative sequence lengths, used to index into hidden_states.
self.cu_length = Parameter(shape=[batch_size + 1], dtype="int32")
if self.mask_seq:
self.output_mask = Parameter(
shape=[mask_seq, num_heads, head_dim], dtype=dtype
)
if self.USE_CUDA:
            # on CUDA, flash_attention takes packed QKV as input;
            # the split + permute is done inside flash_attn
# input: (B, S, H)
# output: (B*S, 3, num_heads, head_dim)
if self.use_flash:
self.qkv = Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype)
else:
self.qkv = Linear(
dim,
dim * 3,
specialization="permute",
shape=(seq_len, 3, self.num_heads),
dtype=dtype,
)
else:
            # on ROCm, CK attention (bmm_softmax_bmm) takes three separate inputs (Q, K, V);
            # here we generate packed QKV that is split later
# input: (B, seqlen, dim) -> (B*seqlen, dim)
# gemm: (B*seqlen, 3*dim)
# reshape to: (B, seqlen, 3, num_heads, head_dim)
# output: (3, B, num_heads, seqlen, head_dim)
self.qkv = Linear(
dim,
dim * 3,
specialization="permute",
shape=(seq_len, 3, self.num_heads),
layout="m2n3",
dtype=dtype,
)
self.attn_drop = Dropout(attn_drop, dtype=dtype)
self.proj = Linear(
dim, dim, specialization="add" if has_residual else None, dtype=dtype
)
self.proj_drop = Dropout(proj_drop, dtype=dtype)
def get_shape(self, x):
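        # return the (static) shape of x as a list of python ints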
shape = [it.value() for it in x._attrs["shape"]]
return shape
def qkv_proj(self, x):
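        # project the input to packed QKV; the output layout depends on the
        # backend and on which attention kernel is used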
if self.USE_CUDA:
if self.use_flash:
batch, seq, hidden = self.get_shape(x)
out = self.qkv(x)
return ops.reshape()(
out, [int(batch * seq), 3, self.num_heads, hidden // self.num_heads]
)
else:
batch, seq, hidden = self.get_shape(x)
x = ops.reshape()(x, [-1, hidden])
return self.qkv(x)
else:
return self.qkv(x)
def attention(self, x):
# fused attention
# output: (B, Seqlen, num_heads, head_dim)
if self.USE_CUDA and self.use_flash:
# input(x): (B*seqlen, 3, num_heads, head_dim)
# output: (B, Seqlen, num_heads, head_dim)
return self.op(x, self.cu_length.tensor())
elif self.USE_CUDA and self.use_mem_eff:
(q, k, v) = ops.split()(x, 1, dim=0)
_, b, num_heads, seqlen, d = self.get_shape(q)
return self.op(
ops.reshape()(q, [b, -1, seqlen, d]),
ops.reshape()(k, [b, -1, seqlen, d]),
ops.reshape()(v, [b, -1, seqlen, d]),
)
else:
# input(q/k/v): (B*num_heads, seqlen, head_dim)
# attn = (B, S, H) * (B, S, H) = (B, S, S) #RCR
# softmax on dim -1 (B, S, S)
# attn@v: (B, S, S) * (B, S, H) = (B, S, H) #RRR
# reshape: (B, num_head, seqlen, head_dim)
# permute: (B, Seqlen, num_heads, head_dim)
if self.USE_CUDA:
scale = Tensor(
shape=[], dtype="float16", name="scale", value=self.scale
)
# [3, b, num_heads, seqlen, d]
_, b, num_heads, seqlen, d = self.get_shape(x)
# [3 * b * num_heads, seqlen, d]
x = ops.reshape()(x, [-1, seqlen, d])
(q, k, v) = ops.split()(x, b * num_heads, dim=0)
qk = ops.bmm_rcr()(q, k)
score = ops.elementwise(FuncEnum.MUL)(qk, scale)
score = ops.softmax()(score, -1)
out = ops.bmm_rrr_permute((num_heads,))(score, v)
else:
(q, k, v) = ops.split()(x, 1, dim=0)
_, _, _, seqlen, d = self.get_shape(q)
OP = ops.bmm_softmax_bmm_permute(
shape=(self.num_heads,),
scale=self.scale,
causal=self.causal,
)
out = OP(
ops.reshape()(q, [-1, seqlen, d]),
ops.reshape()(k, [-1, seqlen, d]),
ops.reshape()(v, [-1, seqlen, d]),
)
return out
    def forward(self, *args):
"""forward pass for calling mha module"""
assert len(args) >= 1
x = args[0]
batch, seq, hidden = self.get_shape(x)
qkv = self.qkv_proj(x)
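        # when mask_seq is set, drop the trailing mask_seq rows before attention;
        # the fixed output_mask rows are appended back afterwards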
if self.mask_seq:
total = self.get_shape(qkv)[0]
qkv = ops.dynamic_slice()(
qkv,
start_indices=[0, 0, 0, 0],
end_indices=[total - self.mask_seq, None, None, None],
)
attn_output = self.attention(qkv)
if self.mask_seq:
attn_output = ops.concatenate()(
[attn_output, self.output_mask.tensor()], dim=0
)
attn_output = ops.reshape()(attn_output, [batch * seq, -1])
if self.has_residual:
assert len(args) == 2
x = self.proj(attn_output, args[1])
else:
x = self.proj(attn_output)
x = self.proj_drop(x)
x = ops.reshape()(x, [batch, seq, hidden])
return x
class CrossAttention(Module):
r"""Cross Multi-head Attention.
Allows the model to jointly attend to information
from different representation subspaces as described in the paper:
`Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
Multi-Head Attention is defined as:
.. math::
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
    Args:
        dim: total dimension of the model
        seq_len: sequence length of the query
        seq_len_kv: sequence length of the key and value
        num_heads: Number of parallel attention heads
        qkv_bias: whether to add bias to the Q/K/V projections. Default: False
        attn_drop: Dropout probability on attention output weights. Default: ``0.0`` (no dropout).
        proj_drop: Dropout probability on projection layers. Default: ``0.0`` (no dropout).
        has_residual: whether an extra residual input is added in the output projection. Default: ``True``.
        causal: whether to apply a causal attention mask. Default: ``False``.
        dtype: data type of the weights. Default: ``"float16"``.
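
    The example below is a minimal sketch of typical usage; the tensor names,
    shapes, and sizes are illustrative assumptions::

        q = Tensor(shape=[2, 64, 512], dtype="float16", name="q", is_input=True)
        k = Tensor(shape=[2, 128, 512], dtype="float16", name="k", is_input=True)
        v = Tensor(shape=[2, 128, 512], dtype="float16", name="v", is_input=True)
        attn = CrossAttention(
            dim=512, seq_len=64, seq_len_kv=128, num_heads=8, has_residual=False
        )
        y = attn(q, k, v)  # (batch, seq_len, dim)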
"""
def __init__(
self,
dim,
seq_len,
seq_len_kv,
num_heads,
qkv_bias=False,
attn_drop=0.0,
proj_drop=0.0,
has_residual=True,
causal=False,
dtype="float16",
):
super().__init__()
assert (
dim % num_heads == 0
), f"dim {dim} should be divisible by num_heads {num_heads}"
self.num_heads = num_heads
self.causal = causal
self.has_residual = has_residual
self.dim = dim
self.op = ops.mem_eff_attention(causal=causal)
self.proj_q = Linear(
dim,
dim,
bias=qkv_bias,
dtype=dtype,
)
self.proj_k = Linear(
dim,
dim,
bias=qkv_bias,
dtype=dtype,
)
self.proj_v = Linear(
dim,
dim,
bias=qkv_bias,
dtype=dtype,
)
self.attn_drop = Dropout(attn_drop, dtype=dtype)
self.proj = Linear(
dim, dim, specialization="add" if has_residual else None, dtype=dtype
)
self.proj_drop = Dropout(proj_drop, dtype=dtype)
def attention(self, q, k, v):
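        # project Q/K/V and reshape/permute to (batch, num_heads, seq, head_dim),
        # the layout expected by mem_eff_attention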
batch = q.shape()[0]
head_dim = self.dim // self.num_heads
query = self.proj_q(q)
key = self.proj_k(k)
value = self.proj_v(v)
query = ops.permute()(
ops.reshape()(query, [batch, -1, self.num_heads, head_dim]), [0, 2, 1, 3]
)
key = ops.permute()(
ops.reshape()(key, [batch, -1, self.num_heads, head_dim]), [0, 2, 1, 3]
)
value = ops.permute()(
ops.reshape()(value, [batch, -1, self.num_heads, head_dim]),
[0, 2, 1, 3],
)
return self.op(query, key, value)
    def forward(self, *args):
"""forward pass for calling mha module"""
assert len(args) >= 3
x = args[0]
batch = x.shape()[0]
attn_output = self.attention(args[0], args[1], args[2])
attn_output = ops.reshape()(attn_output, [batch, -1, self.dim])
if self.has_residual:
assert len(args) == 4
x = self.proj(attn_output, args[3])
else:
x = self.proj(attn_output)
x = self.proj_drop(x)
x = ops.reshape()(x, [batch, -1, self.dim])
return x
class ScaledDotProductAttention(Module):
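    r"""Scaled dot-product attention, implemented as a thin wrapper around
    ``ops.mem_eff_attention`` with ``causal=False``.

    A minimal usage sketch; the tensor names and the assumed
    (batch, num_heads, seq_len, head_dim) layout are illustrative::

        q = Tensor(shape=[2, 8, 64, 64], dtype="float16", name="q", is_input=True)
        k = Tensor(shape=[2, 8, 64, 64], dtype="float16", name="k", is_input=True)
        v = Tensor(shape=[2, 8, 64, 64], dtype="float16", name="v", is_input=True)
        y = ScaledDotProductAttention()(q, k, v)
    """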
def __init__(self) -> None:
super().__init__()
    def forward(self, q, k, v):
attn = ops.mem_eff_attention(causal=False)(q, k, v)
return attn