Source code for aitemplate.compiler.ops.attention.flash_attention

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
"""
Flash attention.
"""

import itertools
from collections import OrderedDict
from typing import List

import jinja2

from aitemplate import backend
from aitemplate.backend import registry
from aitemplate.compiler.base import Operator, Tensor
from aitemplate.utils import shape_utils

# pylint: disable=C0103,W0221,W0102,W0223

SHAPE_FUNC_TEMPLATE = jinja2.Template(
    """
{{indent}}{{dtype}}total = {{x_dim0}};
{{indent}}{{dtype}}num_heads = {{x_dim2}};
{{indent}}{{dtype}}head_sizes = {{x_dim3}};
{{indent}}{{dtype}}NO = total;
{{indent}}{{dtype}}HO = num_heads;
{{indent}}{{dtype}}WO = head_sizes;
"""
)

EXEC_KEY_TEMPLATE = jinja2.Template(
    """
total == {{x_dim0}} && num_heads == {{x_dim2}} && head_sizes == {{x_dim3}}
"""
)


[docs]class flash_attention(Operator): r"""FlashAttention provides an implementation for fused multi-head attention module: .. math:: \text{Attention}(Q, K, V) = \text{softmax}(\frac{QK}{\sqrt(d)}) * V .. math:: \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`. """ def __init__(self, batch_size, dropout, max_seq_len, causal) -> None: """Initialize attention module""" super().__init__() assert dropout == 0 self._attrs["op"] = "flash_attention" self._attrs["has_profiler"] = False self._attrs["batch_size"] = batch_size self._attrs["dropout"] = dropout self._attrs["max_seq_len"] = max_seq_len self._attrs["seq_len"] = 512 self._attrs["head_size"] = -1 self._attrs["causal"] = causal self._attrs["workspace"] = 0 self.exec_key_template = EXEC_KEY_TEMPLATE self.shape_eval_template = SHAPE_FUNC_TEMPLATE def _infer_shape(self, x: List[int], w: List[int]): eval_func = self.shape_eval_template.render( indent="", dtype="", div="//", x_dim0=x[0], x_dim1=x[1], x_dim2=x[2], x_dim3=x[3], ) output = {} exec(eval_func, output) # noqa: P204 return [ int(output["NO"]), int(output["HO"]), int(output["WO"]), ] def _infer_shapes(self, x: Tensor, w: Tensor): """infer the output shape for attention module""" x_shape_values = [var._attrs["values"] for var in x._attrs["shape"]] x_shapes = itertools.product(*x_shape_values) w_shape = [var._attrs["values"][0] for var in w._attrs["shape"]] # run infer shape for each y_shapes = [] for x_shape in x_shapes: y_shape = self._infer_shape(x_shape, w_shape) y_shapes.append(y_shape) def unique(vector): return sorted(set(vector)) output_shape = [ shape_utils.gen_int_var(unique([d[0] for d in y_shapes])), shape_utils.gen_int_var(unique([d[1] for d in y_shapes])), shape_utils.gen_int_var(unique([d[2] for d in y_shapes])), ] return output_shape def __call__(self, x: Tensor, cu_seqlens: Tensor) -> Tensor: """call the op Parameters ---------- x : float16 QKV tensor shape: (batch*seqlen, 3, num_heads, head_size) cu_seqlens : int seq lens tensor shape (batch_size + 1) Returns ---------- Tensor """ self._attrs["inputs"] = [x, cu_seqlens] self._set_depth() self._extract_exec_path(x) output_shape = self._infer_shapes(x, cu_seqlens) output = Tensor(output_shape, src_ops={self}) batch_size = self._attrs["batch_size"] max_seq_len = self._attrs["max_seq_len"] total = x._attrs["shape"][0]._attrs["values"][0] num_heads = x._attrs["shape"][2]._attrs["values"][0] head_size = x._attrs["shape"][3]._attrs["values"][0] assert head_size in [8, 16, 32, 64, 128] self._attrs["head_size"] = head_size base_N = 256 # SM80 if max_seq_len <= 128: seq_len = 128 elif max_seq_len <= 256: seq_len = 256 else: seq_len = ((max_seq_len + base_N - 1) // base_N) * base_N self._attrs["seq_len"] = seq_len self._attrs["workspace"] = ( 4 * num_heads * (total * head_size + batch_size * seq_len) ) self._attrs["outputs"] = [output] return output def _get_op_attributes(self): target_attrs = ["batch_size", "dropout", "max_seq_len", "causal"] attr = {} for target_attr in target_attrs: if target_attr in self._attrs: attr[target_attr] = self._attrs[target_attr] return attr def _gen_exec_key(self, shape): """rendering shape info""" return self.exec_key_template.render( x_dim0=shape[0], x_dim1=shape[1], x_dim2=shape[2], x_dim3=shape[3], ).replace("\n", "") def _extract_exec_path(self, x: Tensor): x_shape_values = [var._attrs["values"] for var in x._attrs["shape"]] x_shapes = itertools.product(*x_shape_values) self._attrs["exec_path"] = OrderedDict() for x_shape in x_shapes: key = self._gen_exec_key(x_shape) self._attrs["exec_path"][key] = ""
[docs] def gen_function(self) -> str: """call backend functions""" target = backend.target.Target.current() func_key = "{target}.{op}.gen_function".format( target=target.name(), op=self._attrs["op"] ) func = registry.get(func_key) return func(self._attrs)