Source code for aitemplate.compiler.ops.gemm_universal.group_gemm_rcr

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
"""
Grouped GEMM Specialization for A[RowMajor], B[ColMajor], C[RowMajor]
"""

import logging
import re
from collections import OrderedDict
from typing import List

import jinja2

from aitemplate.backend import registry
from aitemplate.backend.target import Target
from aitemplate.compiler.base import ExecItem, Tensor
from aitemplate.compiler.ops.gemm_universal import gemm_common as common
from aitemplate.compiler.ops.gemm_universal.gemm_rcr import gemm_rcr
from aitemplate.compiler.ops.tensor import concatenate

from aitemplate.compiler.stable_set import StableSet
from aitemplate.compiler.tensor_accessor import TensorAccessor

# pylint: disable=C0103,W0223,W0221,W0613


_LOGGER = logging.getLogger(__name__)

SHAPE_EVAL_TEMPLATE = jinja2.Template(
    """
{% for operand_dim in group_operand_dims %}
{% set output_addr = output_addr_cals[loop.index - 1] %}
{% set input_a_addr = input_a_addr_cals[loop.index - 1] %}
{{indent}}{{dtype}}GROUP_{{loop.index0}}_AM = {{operand_dim[0]}};
{{indent}}{{dtype}}GROUP_{{loop.index0}}_AK = {{operand_dim[1]}};
{{indent}}{{dtype}}GROUP_{{loop.index0}}_BN = {{operand_dim[2]}};
{{indent}}{{dtype}}GROUP_{{loop.index0}}_BK = {{operand_dim[3]}};
{{indent}}{{dtype}}GROUP_{{loop.index0}}_CM = GROUP_{{loop.index0}}_AM;
{{indent}}{{dtype}}GROUP_{{loop.index0}}_CN = GROUP_{{loop.index0}}_BN;

{{input_a_addr}}
{{output_addr}}

{{indent}}{{dtype}}GROUP_{{loop.index0}}_M = {{operand_dim[0]}};
{{indent}}{{dtype}}GROUP_{{loop.index0}}_K = {{operand_dim[1]}};
{{indent}}{{dtype}}GROUP_{{loop.index0}}_N = {{operand_dim[2]}};
{% endfor %}
{% for operand_dim in group_operand_dims %}
{{indent}}{{operand_dim[4]}} = GROUP_{{loop.index0}}_M;
{{indent}}{{operand_dim[5]}} = GROUP_{{loop.index0}}_N;
{% endfor %}
"""
)


EXEC_KEY_TEMPLATE = jinja2.Template(
    """
{% for mnk in group_mnk %} {% if loop.index0 != 0 %} && {% endif %}
GROUP_{{loop.index0}}_M == {{mnk[0]}} &&
GROUP_{{loop.index0}}_N == {{mnk[1]}} &&
GROUP_{{loop.index0}}_K == {{mnk[2]}}
{% endfor %}
"""
)


[docs]class group_gemm_rcr(common.gemm): """Grouped GEMM Specialization: GEMM_RCR(A, B) This operator is equivalent to the following pytorch code: .. highlight:: python .. code-block:: python # group 1 A1 = torch.randn(M1, K1).cuda().half() B1 = torch.randn(N1, K1).cuda().half() y1 = torch.nn.functional.linear(A1, B1) ... # group n An = torch.randn(Mn, Kn).cuda().half() Bn = torch.randn(Nn, Kn).cuda().half() yn = torch.nn.functional.linear(An, Bn) """ def __init__(self): super().__init__() self.shape_eval_template = SHAPE_EVAL_TEMPLATE self._attrs["op"] = "group_gemm_rcr" # this is a state flag will be codegen self._attrs["int_state_flag"] = 0 def cal_align_ab(m, n, k): return common.default_align_ab(k, k, self._attrs["inputs"][0].dtype()) self._attrs["f_ab_alignment"] = cal_align_ab def _invert_exec_key(self, key): tmp = re.findall(r"==\s*(\d+)", key) return [int(x) for x in tmp] def _extract_exec_path(self, dynamic_profiling_strategy=None): # FIXME: Make this API properly support dynamic_profiling_strategy. if dynamic_profiling_strategy is not None: return # check batch dim same for each group batch_dim = self._attrs["inputs"][0]._attrs["shape"][0] for i in range(self._attrs["groups"]): if batch_dim != self._attrs["inputs"][i * 2]._attrs["shape"][0]: raise RuntimeError( "Batch dim is different in groups. Inputs: {}".format( self._attrs["inputs"] ) ) # for each batch create exec_path self._attrs["exec_path"] = OrderedDict() for m_value in batch_dim._attrs["values"]: group_mnk = [] for i in range(self._attrs["groups"]): b = self._attrs["inputs"][i * 2 + 1] mnk = [m_value] mnk.append(b._attrs["shape"][0]._attrs["values"][0]) mnk.append(b._attrs["shape"][1]._attrs["values"][0]) group_mnk.append(mnk) exec_key = EXEC_KEY_TEMPLATE.render(group_mnk=group_mnk).replace("\n", "") self._attrs["exec_path"][exec_key] = ExecItem( profiling_key=exec_key, exec_cond=exec_key, algo="", ) def _gen_profile_cmd(self, profiler_prefix, cfg, exec_key): def fbuild_cmd(exec_key): mnk_flat = self._invert_exec_key(exec_key) cmd = [] cmd.append(self._attrs["groups"]) cmd.extend(mnk_flat) return cmd return super()._gen_profile_cmd(profiler_prefix, cfg, exec_key, fbuild_cmd) def _concat_strided_outputs(self, outputs, output_stride_dim): """a temporary function to concatenate strided outputs""" cat_op = concatenate() cat_output = cat_op(outputs, dim=output_stride_dim) cat_output._attrs["src_ops"] = StableSet([self]) offset = 0 for idx, output_tensor in enumerate(outputs): self._attrs["output_accessors"][idx].update_base_tensor( cat_output, output_stride_dim, offset ) offset += output_tensor._attrs["shape"][output_stride_dim]._attrs["values"][ 0 ] from aitemplate.compiler.transform import transform_utils transform_utils.remove_tensor_from_sorted_graph(output_tensor) return cat_output @staticmethod def _one_input_accessors( input_accessors: List[TensorAccessor], num_inputs_per_group: int, idx: int ) -> List[TensorAccessor]: return [ a for i, a in enumerate(input_accessors) if i % num_inputs_per_group == idx ] def input_a_accessors(self) -> List[TensorAccessor]: return self._one_input_accessors( self._attrs["input_accessors"], num_inputs_per_group=2, idx=0 ) def input_b_accessors(self) -> List[TensorAccessor]: return self._one_input_accessors( self._attrs["input_accessors"], num_inputs_per_group=2, idx=1 ) def __call__(self, operand_groups: List[List[Tensor]], output_stride_dim=None): # FIXME: when output_stride_dim is specified, we will concat the outputs of the # grouped gemm along the output_stride_dim axis. It's a temporary solution for # a pattern where the outputs of a grouped gemm can be concatenated # to form a single larger tensor. We will write a pass to detect such a # pattern automatically. self._attrs["inputs"] = [] ret = [] epilogue_alignment = 8 for a, b in operand_groups: op = gemm_rcr() c = op(a, b) c._attrs["src_ops"] = StableSet([self]) a._attrs["dst_ops"].remove(op) b._attrs["dst_ops"].remove(op) epilogue_alignment = min( op._attrs["epilogue_alignment"], epilogue_alignment ) ret.append(c) self._attrs["inputs"].append(a) self._attrs["inputs"].append(b) self._set_depth() self._attrs["input_accessors"] = [ TensorAccessor(a) for i, a in enumerate(self._attrs["inputs"]) ] self._attrs["output_accessors"] = [TensorAccessor(c) for c in ret] self._attrs["groups"] = len(ret) if output_stride_dim is not None: # FIXME: replace this manual concat with an automated pass if output_stride_dim != 1: raise RuntimeError( "only support cases where output_stride_dim equals to 1" ) self._attrs["output_stride_dim"] = output_stride_dim ret = self._concat_strided_outputs(ret, output_stride_dim) self._attrs["outputs"] = [ret] else: self._attrs["outputs"] = ret self._attrs["epilogue_alignment"] = epilogue_alignment self._extract_exec_path() # This is a lazy way to allocate space for args # Reserve 12 * 4 * len(groups) byte for each field # 12 is read of sizeof(GemmCoord) # problem_sizes_device # ptrA/B/C/D # lda/b/c/d # problem_sizes_device: N * GemmCoord -> N * 3 * sizeof(int64_t) -> 32 * N # ptrA/B/C/D: N * 8 for each # lda/b/c/d: N * 8 for each # total: N * 8 * 4 + N * 8 * 4 + N * 8 * 4 # total: 3 * 32 * N args_size = 96 * self._attrs["groups"] self._attrs["unique_workspace"] = args_size return ret
[docs] def gen_profiler( self, workdir: str = None, dynamic_profiling_strategy=None ) -> None: """Generate profiler for the op Parameters ---------- workdir : str, optional [description], by default None dynamic_profiling_strategy: DynamicProfileStrategy, optional A dynamic profiling strategy, used to filter generated profiles at compile time. See also: :func:`~aitemplate.compiler.transform.profile.profile` """ target = Target.current() # init candidate ops func_key = "{target}.{op}.config".format( target=target.name(), op=self._attrs["op"] ) func = registry.get(func_key) func(self._attrs) # init exec path self._extract_exec_path(dynamic_profiling_strategy) # init compile-time filter workloads = list(self._attrs["exec_path"].keys()) ab_alignments = sorted({self._get_ab_alignment(wkl) for wkl in workloads}) assert 1 == len( ab_alignments ), f"ab_alignments should be the same among all workloads, got {ab_alignments=}" func_key = "{target}.{op}.filter".format( target=target.name(), op=self._attrs["op"] ) filter_func = registry.get(func_key) # run compile-time filter new_op_instance = OrderedDict( (k, v) for k, v in self._attrs["op_instance"].items() if filter_func(k, self._attrs, ab_alignments[0]) ) _LOGGER.debug( f"Filtered profiler kernels for {self._attrs['op']}: reduced the " f"number of generated kernels from {len(self._attrs['op_instance'])} " f"to {len(new_op_instance)}", ) _LOGGER.debug( f"Group_gemm profiler valid configs: {sorted(new_op_instance.keys())}", ) self._attrs["op_instance"] = new_op_instance build_profiler = super()._should_build_profiler(workloads, new_op_instance) if build_profiler: func_key = "{target}.{op}.gen_profiler".format( target=target.name(), op=self._attrs["op"] ) func = registry.get(func_key) profiler_filename = self._get_profiler_filename() _LOGGER.info(f"generating {profiler_filename=}") return func( self._attrs, workdir, profiler_filename, self.shape_eval_template )
[docs] def gen_function(self) -> str: """Generate function for the op Returns ------- str C++ source code of the function """ target = Target.current() func_key = "{target}.{op}.gen_function".format( target=target.name(), op=self._attrs["op"] ) func = registry.get(func_key) return func( self._attrs, self.exec_cond_template, self.shape_eval_template, )