# Source code for aitemplate.compiler.transform.apply_padding

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
"""
Applies paddings to gemms based on alignment requirements.
"""
import logging
from typing import Callable, Dict, List

from aitemplate.compiler import ops

from aitemplate.compiler.base import _create_host_zero_tensor, IntImm, Operator, Tensor
from aitemplate.compiler.ops.gemm_universal.gemm_common import DimInfo, gemm, Source
from aitemplate.compiler.transform import transform_utils

from aitemplate.utils import alignment


_LOGGER = logging.getLogger(__name__)


def _extract_mnk_name(
    dim_info_dict: Dict[str, DimInfo], source: Source, tensor_idx: int, dim_idx: int
) -> str:
    for name, info_list in dim_info_dict.items():
        for info in info_list:
            if info == DimInfo(source, tensor_idx, [dim_idx]):
                return name
    return None


def get_padding_length(original_length: int, dtype: str) -> int:
    """Return how many elements to append to a dimension of ``original_length``.

    Args:
        original_length: Current size of the dimension.
        dtype: Data type name, forwarded to ``alignment.valid_alignment``.

    Returns:
        ``0`` when ``alignment.valid_alignment(original_length, dtype)``
        accepts the length as is. Otherwise ``1`` for lengths below 16, and
        for longer dimensions the distance to the next multiple of 8 that is
        strictly greater than ``original_length``.
    """
    if alignment.valid_alignment(original_length, dtype):
        return 0

    # TODO(yingz): Tune padding strategy.
    if original_length >= 16:
        # Same value as (original_length // 8 + 1) * 8 - original_length.
        return 8 - original_length % 8
    return 1


def _pad_input_tensor(
    op: Operator,
    tensor_idx: int,
    f_extract_var_name: Callable[[int, int], str],
    alignment_var_to_padding_length: Dict[str, int],
    tensor_list: List[Tensor],
) -> None:
    """Pad one input of ``op`` along every axis that needs alignment padding.

    For each axis of ``op._attrs["inputs"][tensor_idx]`` the axis name is
    obtained from ``f_extract_var_name``. When that name has an entry in
    ``alignment_var_to_padding_length``, a tensor created by
    ``_create_host_zero_tensor`` is concatenated to the input along that axis
    and ``op`` is rewired to consume the concatenated tensor instead.

    Args:
        op: Operator whose input is padded. Its ``inputs`` list and the
            ``dst_ops`` of the affected tensors are modified in place.
        tensor_idx: Index of the input to pad within ``op._attrs["inputs"]``.
        f_extract_var_name: Called as ``f(tensor_idx, dim_idx)``; returns the
            name of that axis, or ``None`` if the axis has no name.
        alignment_var_to_padding_length: Padding length per axis name. Axes
            whose name is absent from this mapping are left untouched.
        tensor_list: Output list; every newly created padding tensor and
            concatenated tensor is appended to it, in that order.
    """
    original_shape = op._attrs["inputs"][tensor_idx]._attrs["shape"]
    for dim_idx, _ in enumerate(original_shape):
        var_name = f_extract_var_name(tensor_idx, dim_idx)
        if var_name is None or var_name not in alignment_var_to_padding_length:
            # This dim doesn't require alignment padding. Skipping.
            continue
        padding_length = alignment_var_to_padding_length[var_name]

        # Re-read the input on every padded axis: a previous iteration may
        # already have replaced it with a concatenated tensor, and the next
        # padding block must match that tensor's current shape.
        tensor = op._attrs["inputs"][tensor_idx]
        # Captured before the graph is rewired so the log shows the tensor
        # as it was when padding started.
        original_tensor_debug_str = str(tensor)

        padding_shape = list(tensor._attrs["shape"])
        padding_shape[dim_idx] = IntImm(padding_length)
        tensor._attrs["dst_ops"].remove(op)

        padding_tensor = _create_host_zero_tensor(
            shape=padding_shape, dtype=tensor.dtype()
        )
        padded_tensor = ops.concatenate()(
            [tensor, padding_tensor],
            dim=dim_idx,
        )
        op._attrs["inputs"][tensor_idx] = padded_tensor
        padded_tensor._attrs["dst_ops"].add(op)
        tensor_list.append(padding_tensor)
        tensor_list.append(padded_tensor)

        _LOGGER.debug(
            "**** Apply padding ****, replace input tensor \n %s \n with \n %s \n",
            original_tensor_debug_str,
            padded_tensor,
        )


def _slice_output_tensor(
    new_output: Tensor, original_output: Tensor, tensor_list: List[Tensor]
) -> Tensor:
    """Cut ``new_output`` back to the shape of ``original_output``.

    When both shapes compare equal, ``new_output`` is returned untouched.
    Otherwise a ``dynamic_slice`` op is applied that starts at 0 on every
    axis and ends at the original size on each axis that differs (``None``
    is passed as the end index for axes that are unchanged). The slice result
    takes over the ``name`` and ``is_output`` attributes of ``new_output``,
    which are then reset to ``None`` / ``False`` on ``new_output``.

    Args:
        new_output: Tensor produced by the padded op.
        original_output: Tensor produced by the op before padding.
        tensor_list: Output list; the sliced tensor is appended when a slice
            is created.

    Returns:
        ``new_output`` itself, or the newly created sliced tensor.
    """
    padded_shape = new_output._attrs["shape"]
    target_shape = original_output._attrs["shape"]
    if padded_shape == target_shape:
        return new_output

    rank = len(padded_shape)
    slice_begin = [0] * rank
    slice_end = [None] * rank
    for axis, (padded_dim, target_dim) in enumerate(zip(padded_shape, target_shape)):
        if padded_dim != target_dim:
            # Only static dims can have been padded, and padding only grows.
            assert isinstance(padded_dim, IntImm) and isinstance(
                target_dim, IntImm
            ), f"new_shape: {padded_shape}, old_shape: {target_shape}"
            assert (
                padded_dim.value() > target_dim.value()
            ), f"new_shape: {padded_shape}, old_shape: {target_shape}"
            slice_end[axis] = target_dim.value()

    sliced_tensor = ops.dynamic_slice()(new_output, slice_begin, slice_end)
    tensor_list.append(sliced_tensor)

    # Hand the identity over to the slice before clearing it on the padded
    # intermediate tensor.
    for attr_name in ("is_output", "name"):
        sliced_tensor._attrs[attr_name] = new_output._attrs[attr_name]
    new_output._attrs["name"] = None
    new_output._attrs["is_output"] = False
    return sliced_tensor


def _is_padding_candidate(op: Operator) -> bool:
    """Return True when ``op`` is a gemm/bmm that this pass may pad."""
    # Exclude non-gemm ops and special gemm kernels.
    if (
        not isinstance(op, gemm)
        or isinstance(op, ops.gemm_rrr_small_nk)
        or isinstance(op, ops.bmm_rcr_n1)
        or isinstance(op, ops.bmm_rrr_k1_tanh)
        or "permute" in op._attrs["op"]
    ):
        return False

    # This pass only works for gemm or bmm. group_gemm is not supported.
    # We don't need to pad our special kernel bmm_rcr_n1, which does
    # not have any alignment constraint.
    op_name = op._attrs["name"]
    if op_name.startswith(("group_gemm", "bmm_rcr_n1")) or "softmax" in op_name:
        return False
    return True


def _collect_padding_lengths(op: Operator, dim_info_dict) -> Dict[str, int]:
    """Map each alignment var of ``op`` to the padding its last dim needs.

    Only inputs whose last axis resolves to a name via ``_extract_mnk_name``
    are considered; names whose padding length is 0 are omitted.

    Raises:
        NotImplementedError: If such a last axis is not an ``IntImm``.
    """
    padding_lengths = {}
    for input_idx, input_tensor in enumerate(op._attrs["inputs"]):
        shape = input_tensor._attrs["shape"]
        alignment_var = _extract_mnk_name(
            dim_info_dict, Source.INPUT, input_idx, len(shape) - 1
        )
        if alignment_var is None:
            # No alignment var is extracted. Skip padding.
            continue
        alignment_dim = shape[-1]
        if not isinstance(alignment_dim, IntImm):
            raise NotImplementedError(
                "Gemm does not support dynamic alignment dimensions "
                "(i.e. alignment==1)! Gemm: {}".format(op)
            )
        padding_length = get_padding_length(
            alignment_dim.value(), input_tensor.dtype()
        )
        if padding_length > 0:
            padding_lengths[alignment_var] = padding_length
    return padding_lengths


def apply_padding(sorted_graph: List[Tensor], workdir: str = None) -> List[Tensor]:
    """
    Applies padding to gemms to use SM80 kernels.
    SM80 kernels require min_alignment == 2.

    For every eligible gemm/bmm op that produces a tensor in ``sorted_graph``,
    the inputs are padded where needed, the op is re-created on the padded
    inputs, and its output is sliced back to the original shape before it
    replaces the original output tensor.

    Args:
        sorted_graph: Tensors of the graph to transform.
        workdir: Not used by this pass.

    Returns:
        The new tensor list after ``transform_utils.sanitize_sorted_graph``.

    Raises:
        NotImplementedError: If an alignment dimension of an eligible op is
            not an ``IntImm``.
    """
    visited_ops = set()
    new_sorted_graph = []
    for tensor in sorted_graph:
        new_tensor_list = [tensor]
        for op in tensor.src_ops():
            if op in visited_ops:
                continue
            # Record the op so the guard above can actually prevent it from
            # being processed a second time.
            visited_ops.add(op)

            if not _is_padding_candidate(op):
                continue

            # Extract alignment var names and padding lengths.
            dim_info_dict = op._extract_dims()
            alignment_var_to_padding_length = _collect_padding_lengths(
                op, dim_info_dict
            )
            if not alignment_var_to_padding_length:
                # No padding is necessary.
                continue

            _LOGGER.debug(
                "**** Apply padding ****, alignment_var_to_padding_length: \n %s \n",
                alignment_var_to_padding_length,
            )
            # Captured now, before the op's inputs are rewired below.
            original_op_debug_str = str(op)

            # Pad A and B.
            for input_idx, _ in enumerate(op._attrs["inputs"][:2]):
                _pad_input_tensor(
                    op,
                    input_idx,
                    lambda tensor_idx, dim_idx: _extract_mnk_name(
                        dim_info_dict, Source.INPUT, tensor_idx, dim_idx
                    ),
                    alignment_var_to_padding_length,
                    new_tensor_list,
                )

            # Pad bias and extra sources if necessary.
            for extra_idx, extra_tensor in enumerate(op._attrs["inputs"][2:]):
                # Handle the bias broadcast case: align the extra input's
                # axes with the trailing axes of output[0].
                rank_offset = len(op._attrs["outputs"][0]._attrs["shape"]) - len(
                    extra_tensor._attrs["shape"]
                )
                _pad_input_tensor(
                    op,
                    extra_idx + 2,  # skip A and B
                    lambda _, dim_idx, rank_offset=rank_offset: _extract_mnk_name(
                        dim_info_dict,
                        Source.OUTPUT,  # bias alignment follows output alignment
                        0,  # always check output[0]
                        dim_idx + rank_offset,
                    ),
                    alignment_var_to_padding_length,
                    new_tensor_list,
                )

            # Replaces the old op with the new op.
            for tensor_input in op._attrs["inputs"]:
                tensor_input._attrs["dst_ops"].discard(op)

            new_op = type(op)(**op._get_op_attributes())
            new_op._attrs["split_k"] = op._attrs["split_k"]
            if "alpha" in op._attrs:
                new_op._attrs["alpha"] = op._attrs["alpha"]
            new_output = new_op(*op._attrs["inputs"])
            new_tensor_list.append(new_output)

            original_output = op._attrs["outputs"][0]
            transform_utils.copy_tensor_attributes(new_output, original_output)

            # Slice output if necessary.
            new_output = _slice_output_tensor(
                new_output, original_output, new_tensor_list
            )
            transform_utils.replace_tensor(original_output, new_output)
            transform_utils.remove_tensor_from_sorted_graph(original_output)

            _LOGGER.debug(
                "**** Apply padding ****, replace op \n %s \n with \n %s \n",
                original_op_debug_str,
                new_op,
            )

        new_sorted_graph.extend(new_tensor_list)

    new_sorted_graph = transform_utils.sanitize_sorted_graph(new_sorted_graph)
    return new_sorted_graph