# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Common functions/classes for GEMM ops
"""
import itertools
import logging
import math
import os
import re
from collections import OrderedDict
from dataclasses import dataclass
from enum import Enum
from hashlib import sha1
from operator import itemgetter
from time import sleep
from typing import Any, Callable, Dict, List, Union
import jinja2
from aitemplate import backend
from aitemplate.backend import registry
from aitemplate.backend.profiler_runner import ProfileResult
from aitemplate.compiler.base import (
DynamicProfileStrategy,
ExecItem,
IntImm,
IntVar,
Operator,
Tensor,
)
from aitemplate.compiler.dtype import is_same_dtype
from aitemplate.compiler.ops.gemm_universal.cache_entry import (
GemmQueryEntry,
GemmRecordEntry,
)
from aitemplate.compiler.tensor_accessor import TensorAccessor
from aitemplate.utils import alignment, environ
# pylint: disable=C0103,R1711,W0102,W0221,E1120
_LOGGER = logging.getLogger(__name__)
def split_k_result_getter(result):
return result[1].duration
EXEC_COND_TEMPLATE = jinja2.Template(
"""
{{indent}}if ({{cond}}) {
{{indent}} {{program}}
{{indent}}}
"""
)
class Source(Enum):
INPUT = 1
OUTPUT = 2
@dataclass
class DimInfo:
"""Class to record dimension info."""
def __init__(
self,
source: Source,
tensor_idx: int,
dim_idx: Union[int, List[int]],
placeholder: bool = False,
):
"""
source:
Source.INPUT or Source.OUTPUT
tensor_idx:
Depending on source, extract info from inputs[tensor_idx] or outputs[tensor_idx]
dim_idx:
Extract shape from inputs/outputs[tensor_idx][dim_idx]
placeholder:
If True, the dim info might not be accurate at compile time; it is just a placeholder to be filled in later.
This is useful for handling issues such as broadcasting, where the batch dim B might not be exact.
"""
self.source = source
self.tensor_idx = tensor_idx
if isinstance(dim_idx, int):
dim_idx = [dim_idx]
self.dim_idx = dim_idx
self.placeholder = placeholder
source: Source
tensor_idx: int
dim_idx: List[int]
placeholder: bool
def extract_shape_from_accessor(func_attrs, source: Source, idx: int):
if source == Source.INPUT:
if "input_accessors" in func_attrs:
return func_attrs["input_accessors"][idx].original_shapes
return func_attrs["inputs"][idx].shape()
elif source == Source.OUTPUT:
if "output_accessors" in func_attrs:
return func_attrs["output_accessors"][idx].original_shapes
return func_attrs["outputs"][idx].shape()
else:
raise RuntimeError(f"Unknown source, got {source}")
def create_input_batch_diminfo(input_shapes, batch_dims, output_batch):
"""
Create inputs' batch diminfo.
Given input_shapes and the corresponding batch_dims, this function
returns a list of DimInfo for the inputs' batch dimensions.
input_shapes:
A list of input shapes.
batch_dims:
The batch dimension for the corresponding input_shapes.
If the corresponding input's shape has 2 or fewer dimensions, it is ignored.
output_batch:
The batch size for output.
"""
assert len(input_shapes) == len(batch_dims)
batch_diminfo = []
for idx, input_shape in enumerate(input_shapes):
if len(input_shape) > 2:
batch_diminfo.append(
DimInfo(
Source.INPUT,
tensor_idx=idx,
dim_idx=batch_dims[idx],
placeholder=input_shape[batch_dims[idx]] != output_batch,
)
)
return batch_diminfo
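# Illustrative sketch (hypothetical shapes): for a bmm-style op with
#   input_shapes = [[B, M, K], [B, K, N]], batch_dims = [0, 0], output_batch = B,
# this returns
#   [DimInfo(Source.INPUT, tensor_idx=0, dim_idx=0),
#    DimInfo(Source.INPUT, tensor_idx=1, dim_idx=0)],
# both with placeholder=False because each input's batch dim equals output_batch.
# A batch dim that differs from output_batch (e.g. a broadcastable 1) is recorded
# with placeholder=True, and a rank-2 input contributes no batch DimInfo at all.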
def group_gemm_inverse_key_func(key):
m_pattern = re.compile(r"GROUP_\d+_M\s*==\s*(\d+)")
all_m = re.findall(m_pattern, key)
n_pattern = re.compile(r"GROUP_\d+_N\s*==\s*(\d+)")
all_n = re.findall(n_pattern, key)
k_pattern = re.compile(r"GROUP_\d+_K\s*==\s*(\d+)")
all_k = re.findall(k_pattern, key)
assert len(all_m) == len(all_n) == len(all_k)
return (all_m, all_n, all_k)
def gemm_inverse_key_func(key):
tmp = re.findall(r"(\d+)", key)
return [int(x) for x in tmp]
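# Illustrative sketch of the exec-key formats these parsers expect (values are
# hypothetical; keys are produced by gemm._gen_exec_key below):
#   gemm_inverse_key_func("M == 256 && N == 512 && K == 128") -> [256, 512, 128]
#   group_gemm_inverse_key_func(
#       "GROUP_0_M == 16 && GROUP_0_N == 32 && GROUP_0_K == 64"
#   ) -> (["16"], ["32"], ["64"])  # per-group M/N/K values as strings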
def default_align_ab(a, b, dtype):
ab = math.gcd(a, b)
return alignment.find_max_alignment(ab, dtype)
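# Minimal sketch, assuming alignment.find_max_alignment returns the largest
# alignment supported for the dtype that divides its first argument:
#   default_align_ab(120, 64, "float16") -> gcd is 8, so the A/B alignment is 8
#   default_align_ab(120, 63, "float16") -> gcd is 3, so it falls back to 1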
def _to_list(elem):
if isinstance(elem, tuple):
return list(elem)
else:
return [elem]
def _check_with_retries(
condition: Callable[[], bool],
max_attempts: int = 3,
delay_seconds: int = 5,
) -> bool:
"""Check a condition with retries."""
attempts = 0
while True:
if condition():
return True
attempts += 1
if attempts >= max_attempts:
return False
sleep(delay_seconds)
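# Minimal usage sketch (the path below is hypothetical): retry up to 3 times,
# sleeping 5 seconds between failed checks, e.g. while waiting for a freshly
# built profiler binary to become executable:
#   ok = _check_with_retries(lambda: os.access("/tmp/profiler_bin", os.X_OK))
#   # ok is True as soon as the condition holds, False after max_attempts failures.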
class gemm(Operator):
"""Base gemm operators"""
def __init__(
self,
):
super().__init__()
self._attrs["op"] = "gemm"
self._attrs["has_profiler"] = True
self._attrs["f_ab_alignment"] = None
self._attrs["epilogue_alignment"] = 1
self._attrs["epilogue"] = "LinearCombination"
self._attrs["workspace"] = 0
self._attrs["split_k"] = 1
self._attrs["num_sources"] = 0
self._attrs["alpha"] = 1.0
self._attrs["permute_shape"] = ""
self.exec_cond_template = EXEC_COND_TEMPLATE
def _extract_epilogue_alignment(
self, output_shape: List[Any], dynamic_profiling_strategy=None
) -> None:
epilogue_dim = output_shape[-1]
if isinstance(epilogue_dim, int):
shape = epilogue_dim
elif not isinstance(epilogue_dim, IntImm):
# The alignment inferred here will be set to 1 during codegen.
if dynamic_profiling_strategy is None:
return
elif dynamic_profiling_strategy == DynamicProfileStrategy.MAX:
shape = epilogue_dim.upper_bound()
elif dynamic_profiling_strategy == DynamicProfileStrategy.MIN:
shape = epilogue_dim.lower_bound()
else:
raise RuntimeError(
f"Unsupported dynamic profiling strategy: {dynamic_profiling_strategy}"
)
else:
shape = epilogue_dim._attrs["values"][0]
dtype = self._attrs["inputs"][0].dtype()
self._attrs["epilogue_alignment"] = alignment.find_max_alignment(shape, dtype)
def _infer_shapes(self, a: Tensor, b: Tensor):
raise NotImplementedError("_infer_shapes() is not implemented!")
def _gen_exec_key(self, name_value_mapping):
key_strs = []
for name, values in name_value_mapping.items():
if len(values) == 1:
key_strs.append(f"{name} == {values[0]}")
elif len(values) > 1:
key_strs.append(f"{name} >= {values[0]} && {name} <= {values[-1]}")
else:
raise RuntimeError("Gemm input has empty dim values: {}".format(values))
return " && ".join(key_strs)
def _extract_dims(self, for_profiling: bool = False) -> Dict[str, List[DimInfo]]:
"""Extracts a mapping between dim names and a list of DimInfo.
This function will be used in gemm shape inference, gemm padding graph
transformation, gemm profiling, etc.
All subclasses must implement this API.
An example result from gemm_rcr:
{
"M": [
DimInfo(source=INPUT, tensor_idx=0, dim_idx=0),
DimInfo(source=OUTPUT, tensor_idx=0, dim_idx=0),
],
"K": [
DimInfo(source=INPUT, tensor_idx=0, dim_idx=1),
DimInfo(source=INPUT, tensor_idx=1, dim_idx=1),
],
"N": [
DimInfo(source=INPUT, tensor_idx=1, dim_idx=0),
DimInfo(source=OUTPUT, tensor_idx=0, dim_idx=1),
],
}
Parameters
----------
for_profiling: bool
Whether this function is used for generating profiling source code.
If so, some DimInfo entries are simplified; e.g. for gemm, we treat all tensors
as 2D.
"""
raise NotImplementedError("extract_dims() is not implemented!")
def _extract_exec_path(self, dynamic_profiling_strategy):
"""Extracts profiling keys and execution conditions for a given dynamic_profiling_strategy.
This function fills in self._attrs["exec_path"].
Keys are "exec_key"s, and are used for profiling.
Values are ExecItems, where "profiling_key" fields are the same as the corresponding keys,
"exec_cond" fields specify dynamic ranges, and "algo" fields are empty for now.
e.g. for gemm_rrr, input1=[m, k], input2=[k, n]
m = 1, k = 128, n = 256.
self._attrs["exec_path"] = {
"M==1 && K==128 && N==256" : ItemValue(
profiling_key="M==1 && K==128 && N==256",
exec_cond="M==1 && K==128 && N==256",
algo="",
)
}
e.g. for gemm_rrr, input1=[dynamic_m, k], input2=[k, n]
dynamic_m >= 1 and dynamic_m <= 1024, dynamic_profiling_strategy = MAX,
k = 128, n = 256.
self._attrs["exec_path"] = {
"M==1024 && K==128 && N==256" : ItemValue(
profiling_key="M==1024 && K==128 && N==256",
exec_cond="M>=1 && M<=1024 && K==128 && N==256",
algo="",
)
}
Parameters
----------
dynamic_profiling_strategy : DynamicProfileStrategy
See comments for DynamicProfileStrategy.
"""
dim_info_dict: Dict[str, List[DimInfo]] = self._extract_dims()
dim_dict: Dict[str, List[IntVar]] = {}
for name, dim_infos in dim_info_dict.items():
dim_info = None
for d in dim_infos:
if d.placeholder:
continue
if dim_info is None:
dim_info = d
elif d.source == Source.INPUT:
# input should have priority.
dim_info = d
assert dim_info is not None, f"Couldn't find valid dim info for dim {name}"
tensor_list = (
self._attrs["inputs"]
if dim_info.source == Source.INPUT
else self._attrs["outputs"]
)
if dim_info.source == Source.INPUT and "input_accessors" in self._attrs:
dim_dict[name] = _to_list(
itemgetter(*(dim_info.dim_idx))(
self._attrs["input_accessors"][
dim_info.tensor_idx
].original_shapes
)
)
elif dim_info.source == Source.OUTPUT and "output_accessors" in self._attrs:
dim_dict[name] = _to_list(
itemgetter(*(dim_info.dim_idx))(
self._attrs["output_accessors"][
dim_info.tensor_idx
].original_shapes
)
)
else:
dim_dict[name] = _to_list(
itemgetter(*(dim_info.dim_idx))(
tensor_list[dim_info.tensor_idx]._attrs["shape"]
)
)
shape_values_dict = {}
for name, dims in dim_dict.items():
min_value = math.prod([dim.lower_bound() for dim in dims])
max_value = math.prod([dim.upper_bound() for dim in dims])
shape_values_dict[name] = sorted({min_value, max_value})
self._attrs["exec_path"] = OrderedDict()
if dynamic_profiling_strategy == DynamicProfileStrategy.MAX:
max_values = {
name: [max(shape_values)]
for name, shape_values in shape_values_dict.items()
}
exec_item = ExecItem(
profiling_key=self._gen_exec_key(max_values),
exec_cond=self._gen_exec_key(shape_values_dict),
algo="",
)
self._attrs["exec_path"][exec_item.profiling_key] = exec_item
elif dynamic_profiling_strategy == DynamicProfileStrategy.MIN:
min_values = {
name: [min(shape_values)]
for name, shape_values in shape_values_dict.items()
}
exec_item = ExecItem(
profiling_key=self._gen_exec_key(min_values),
exec_cond=self._gen_exec_key(shape_values_dict),
algo="",
)
self._attrs["exec_path"][exec_item.profiling_key] = exec_item
else:
raise NotImplementedError(
"Gemm only supports MIN or MAX dynamic profiling! "
"Current dynamic_profiling_strategy: {}".format(
dynamic_profiling_strategy
)
)
def _get_profiler_filename(self):
"""
generate a filename for a profiler that benchmarks multiple GEMM instances
"""
target = backend.target.Target.current()
op_type = self._attrs["op"]
all_op_names = list(self._attrs["op_instance"].keys())
encoded_str = sha1((";".join(all_op_names)).encode("utf-8")).hexdigest()
if target.use_dummy_profiling_results():
# we don't use cache
return f"{op_type}_{encoded_str}"
else:
cache_ver = target.get_profile_cache_version("gemm")
return f"{op_type}_{encoded_str}_{cache_ver}"
def _should_build_profiler(
self, workloads: List[str], new_op_instance: OrderedDict
):
"""
Check if we should build profilers. If we have a cached
entry for this gemm instance, we update this gemm op's
relevant attributes with the cached result and return False.
"""
# We are forced to use the cache, so we skip building profilers.
if environ.force_profiler_cache():
return False
target = backend.target.Target.current()
build_profiler = True
# Now, let's query if all of our workloads have cache entries. If that
# is the case, it is safe to skip generating and building profilers.
if not target.use_dummy_profiling_results():
tmp_key = next(iter(new_op_instance.keys()))
tmp_op = new_op_instance[tmp_key]
build_profiler = False
for wkl in workloads:
exec_entry_sha1 = sha1(wkl.encode("utf-8")).hexdigest()
query = GemmQueryEntry(
# 1 is subtracted from the type enum values for consistency with the existing
# cache databases; due to the "void" type being added to the DataType enum as
# the very first enum member (and shifting the values of other enum members) in
# https://github.com/NVIDIA/cutlass/commit/7c04f954151f606e60608061e891785fba229ae2
dtype_a=tmp_op.A.element.value - 1,
dtype_b=tmp_op.B.element.value - 1,
dtype_c=tmp_op.C.element.value - 1,
dtype_acc=tmp_op.accumulator_type().value - 1,
major_a=tmp_op.A.layout.value,
major_b=tmp_op.B.layout.value,
major_c=tmp_op.C.layout.value,
op_type=self._attrs["op"],
device=target._arch,
epilogue=tmp_op.epilogue_functor.value,
exec_entry_sha1=exec_entry_sha1,
pshape=self._attrs["permute_shape"],
)
cache_value = target.query_profile_cache("gemm", query.__dict__)
if cache_value is not None and not target.force_profile():
_LOGGER.info(
f'Load profiling result for {self._attrs["name"]} '
f"from cache: {cache_value}",
)
best_algo, workspace, split_k = cache_value
self._attrs["exec_path"][wkl].algo = best_algo
self._attrs["workspace"] = max(self._attrs["workspace"], workspace)
self._attrs["split_k"] = split_k
else:
# cache miss - we will have to generate and build profilers
build_profiler = True
return build_profiler
def gen_profiler(
self, workdir: str = None, dynamic_profiling_strategy=DynamicProfileStrategy.MAX
) -> None:
"""Generate profilers for this gemm op.
Parameters
----------
workdir : str, optional
Output dir of profilers, by default None
dynamic_profiling_strategy: DynamicProfileStrategy, optional
A dynamic profiling strategy, used to filter generated profiles at compile time.
See also: :func:`~aitemplate.compiler.transform.profile.profile`
"""
target = backend.target.Target.current()
# init candidate ops
func_key = "{target}.{op}.config".format(
target=target.name(), op=self._attrs["op"]
)
func = registry.get(func_key)
func(self._attrs, dtype=self._attrs["inputs"][0]._attrs["dtype"])
# init exec path
self._extract_exec_path(dynamic_profiling_strategy)
# init compile-time filter
workloads = list(self._attrs["exec_path"].keys())
ab_alignments = sorted({self._get_ab_alignment(wkl) for wkl in workloads})
assert 1 == len(
ab_alignments
), f"ab_alignments should be the same among all workloads, got {ab_alignments=}"
func_key = "{target}.{op}.filter".format(
target=target.name(), op=self._attrs["op"]
)
# Update epilogue alignment here because it may be different depending on the profiling strategy.
# Note that this alignment is only used in profiling and will be updated
# during the final codegen.
# gemm_permute ops have special output alignment rules, skip here.
if "layout" not in self._attrs:
output_shape = self._attrs["output_accessors"][0].original_shapes
self._extract_epilogue_alignment(output_shape, dynamic_profiling_strategy)
if not self._attrs["op_instance"]:
raise RuntimeError(
f"No GEMM op instances were generated for {self._attrs['op']}."
)
filter_func = registry.get(func_key)
# run compile-time filter
new_op_instance = OrderedDict(
(k, v)
for k, v in self._attrs["op_instance"].items()
if filter_func(k, self._attrs, ab_alignments[0])
)
_LOGGER.debug(
f"Filtered profiler kernels for {self._attrs['op']}: reduced the "
f"number of generated kernels from {len(self._attrs['op_instance'])} "
f"to {len(new_op_instance)}",
)
self._attrs["op_instance"] = new_op_instance
if not self._attrs["op_instance"]:
raise RuntimeError(
f"No GEMM op instances are left after filtering for {self._attrs['op']}. "
"This is probably due to incompatible alignment requirements."
)
build_profiler = self._should_build_profiler(workloads, new_op_instance)
if build_profiler:
# generate profiler
func_key = "{target}.{op}.gen_profiler".format(
target=target.name(), op=self._attrs["op"]
)
func = registry.get(func_key)
profiler_filename = self._get_profiler_filename()
_LOGGER.info(f"generating {profiler_filename=}")
return func(
self._attrs,
workdir,
profiler_filename,
self._extract_dims(for_profiling=True),
)
def _gen_profile_cmd(
self, profiler_prefix, profiler_filename, exec_key, fbuild_cmd
):
exe_path = os.path.join(profiler_prefix, profiler_filename)
if not _check_with_retries(
condition=lambda: os.access(exe_path, os.X_OK),
max_attempts=3,
delay_seconds=5,
):
raise RuntimeError("Profiler %s is not executable" % exe_path)
cmd_args = fbuild_cmd(exec_key)
cmd = [exe_path]
# mnk
cmd.extend(cmd_args)
command = [str(x) for x in cmd]
# profiling gemm/bmm_permute with layout and shape for ROCM
if self._attrs.get("shape") is not None:
if backend.target.Target.current().name() == "rocm":
for x in self._attrs["shape"]:
command.append(str(x))
return command
def _split_k_search_space(self, M, N, K):
"""Get split_k search range = [1] by default"""
space = [1]
# skip split-k search for rocm
if backend.target.Target.current().name() == "rocm":
return set(space)
factor = K // max(M, N)
low_range = max(1, factor // 4)
high_range = min(factor, 32)
if low_range == 1:
low_range += 1
space += list(range(low_range, high_range, 2))
_LOGGER.debug(
f"profiling split-k for gemm instance M={M}, N={N}, K={K} in {set(space)}",
)
return set(space)
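# Worked example (hypothetical shape): for M = 128, N = 128, K = 4096 on a CUDA
# target, factor = 4096 // 128 = 32, low_range = max(1, 32 // 4) = 8 and
# high_range = min(32, 32) = 32, so the search space is {1, 8, 10, 12, ..., 30}.
# On ROCm the search space is always {1}.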
def _get_ab_alignment(self, exec_key):
if self._attrs["op"].startswith("group_gemm"):
all_m, all_n, all_k = group_gemm_inverse_key_func(exec_key)
all_ab_alignments = [
self._attrs["f_ab_alignment"](int(m), int(n), int(k))
for m, n, k in zip(all_m, all_n, all_k)
]
ab_alignment = min(all_ab_alignments)
else:
# exec_key may contain a batch dimension, which we don't care about here
m, n, k = gemm_inverse_key_func(exec_key)[-3:]
ab_alignment = self._attrs["f_ab_alignment"](m, n, k)
if not alignment.valid_alignment(
ab_alignment, self._attrs["inputs"][0].dtype()
):
raise RuntimeError(
f"A / B {ab_alignment=} is not valid! The last dimension of each input tensor needs to be divisible by 2."
f"m: {m}, n: {n}, k: {k}."
)
return ab_alignment
def _profile_single_workload(
self, profiler_prefix, exec_key, profiler_runner, force_cache
):
"""
Schedule profilers for the given profiler path and gemm shape (exec_key),
or get the result from the cache,
or use a dummy result in CI.
"""
target = backend.target.Target.current()
tmp_key = next(iter(self._attrs["op_instance"].keys()))
tmp_op = self._attrs["op_instance"][tmp_key]
exec_entry_sha1 = sha1(exec_key.encode("utf-8")).hexdigest()
split_k = 1 if self._attrs["split_k"] is None else self._attrs["split_k"]
# Because we call gen_profiler to generate and compile all profilers
# before running any of them, we won't be able to update the exec_path
# in gen_profiler even if two gemms have the same problem size (assume that
# we don't have a cache entry for this problem size). Consequently,
# we still need to query the cache here to ensure we won't re-profile
# the second gemm with the same problem size. Note that if we already
# have a cache entry for the problem size before gen_profiler, we will
# setup exec_path correctly in gen_profiler, so we won't get here at all.
query = GemmQueryEntry(
# 1 is subtracted from the type enum values for consistency with the existing
# cache databases; due to the "void" type being added to the DataType enum as
# the very first enum member (and shifting the values of other enum members) in
# https://github.com/NVIDIA/cutlass/commit/7c04f954151f606e60608061e891785fba229ae2
dtype_a=tmp_op.A.element.value - 1,
dtype_b=tmp_op.B.element.value - 1,
dtype_c=tmp_op.C.element.value - 1,
dtype_acc=tmp_op.accumulator_type().value - 1,
major_a=tmp_op.A.layout.value,
major_b=tmp_op.B.layout.value,
major_c=tmp_op.C.layout.value,
op_type=self._attrs["op"],
device=target._arch,
epilogue=tmp_op.epilogue_functor.value,
exec_entry_sha1=exec_entry_sha1,
pshape=self._attrs["permute_shape"],
)
cache_value = target.query_profile_cache("gemm", query.__dict__)
if cache_value is not None and not target.force_profile():
_LOGGER.debug(
f'Load profiling result for {self._attrs["name"]} '
f"from cache: {cache_value}",
)
self._attrs["exec_path"][exec_key].algo = cache_value[0]
self._attrs["workspace"] = max(self._attrs["workspace"], cache_value[1])
self._attrs["split_k"] = cache_value[2]
return
if cache_value is None and force_cache:
op_type = self._attrs["op"]
raise RuntimeError(
"force_cache is enabled but we could not find the following cache ",
f"available on device {target._arch=}, {op_type=}, {exec_entry_sha1=}",
)
if target.use_dummy_profiling_results():
op_type = self._attrs["op"]
raise Exception(
"This is a CI run but we could not find the following cache ",
f"available on device {target._arch}\n",
f"{op_type} {exec_entry_sha1}.\n",
"Please adjust target.select_minimal_algo function.",
)
profiler_filename = self._get_profiler_filename()
def _gen_callback(split_k):
def process_result_callback(result, postprocessing_delegate):
postprocessing_delegate.add_instance(
(result, self._attrs, profiler_filename, exec_key, split_k)
)
return process_result_callback
command = self._gen_profile_cmd(profiler_prefix, profiler_filename, exec_key)
if self._attrs["op"].startswith("group_gemm") or self._attrs["op"].startswith(
"bmm"
):
profiler_runner.push(command, _gen_callback(split_k=1))
else:
m, n, k = gemm_inverse_key_func(exec_key)[-3:]
if "split_k_hints" in self._attrs:
split_k_search_space = self._attrs["split_k_hints"]
else:
split_k_search_space = self._split_k_search_space(m, n, k)
for split_k in split_k_search_space:
gemm_command = command + [str(split_k)]
profiler_runner.push(gemm_command, _gen_callback(split_k))
def profile(
self,
profiler_runner,
workdir="./",
):
"""Selects the fastest kernel configurations.
Parameters
----------
profiler_runner: ProfilerRunner
Profiler runner to schedule async profiler jobs running on separate GPU devices concurrently.
workdir : str
Base dir to keep profiling source codes, by default "./"
"""
workloads = list(self._attrs["exec_path"].keys())
profiler_prefix = os.path.join(workdir, "profiler", self._attrs["op"])
if "op_instance" not in self._attrs:
target = backend.target.Target.current()
# init candidate ops
func_key = "{target}.{op}.config".format(
target=target.name(),
op=self._attrs["op"],
)
func = registry.get(func_key)
func(self._attrs, dtype=self._attrs["inputs"][0]._attrs["dtype"])
target = backend.target.Target.current()
force_cache = environ.force_profiler_cache()
for wkl in workloads:
_LOGGER.info(
"Profile: {name}: {wkl}".format(name=self._attrs["name"], wkl=wkl),
)
# if in CI just choose minimal configs
# workspace is a hack that just provides 102400 bytes
if target.use_dummy_profiling_results() and not force_cache:
algo = target.select_minimal_algo(
list(self._attrs["op_instance"].keys())
)
_LOGGER.info(f"Select minimal algo {algo} for CI")
self._attrs["exec_path"][wkl].algo = algo
self._attrs["workspace"] = 102400
elif self._attrs["exec_path"][wkl].algo != "":
# we have cached best algo
return
else:
self._profile_single_workload(
profiler_prefix, wkl, profiler_runner, force_cache
)
def gen_function(self) -> str:
"""Generates the function code for the gemm op for the current target.
Returns
-------
str
C++ source code of the function
"""
target = backend.target.Target.current()
func_key = "{target}.{op}.gen_function".format(
target=target.name(), op=self._attrs["op"]
)
func = registry.get(func_key)
return func(
self._attrs,
self.exec_cond_template,
self._extract_dims(),
)
def _signature(self) -> str:
"""Generate the unique signature of the gemm op.
Returns
-------
str
The unique signature of the gemm op.
"""
op_name = self._attrs["op"] + ("split_" + str(self._attrs["split_k"]))
signature = sha1(op_name.encode("utf-8")).hexdigest()
return signature
def _align_ab(self, a: Tensor, b: Tensor):
return a, b
def _sanity_check(self, a: Tensor, b: Tensor):
a_shapes = a._attrs["shape"]
if len(a_shapes) < 2:
raise RuntimeError(
"gemm operand A should have >= 2 dimensions! Current shape: {}.".format(
a_shapes
)
)
b_shapes = b._attrs["shape"]
if len(b_shapes) != 2:
raise RuntimeError(
"gemm operand B should have 2 dimensions! Current shape: {}.".format(
b_shapes
)
)
if not is_same_dtype(a.dtype(), b.dtype()):
raise RuntimeError(
"gemm operand A and B should have the same data type! Current A: {atype}, B: {btype}.".format(
atype=a.dtype(), btype=b.dtype()
)
)
def __call__(self, a: Tensor, b: Tensor) -> Tensor:
"""Call the gemm op.
Parameters
----------
a : Tensor
Tensor with correct shape for the gemm operand A.
b : Tensor
Tensor with correct shape for the gemm operand B.
Returns
-------
Tensor
Output tensor for the gemm operation.
"""
a, b = self._align_ab(a, b)
self._attrs["inputs"] = [a, b]
# TensorAccessor(b) is for bmm or rare cases of gemm where b is not a constant weight
self._attrs["input_accessors"] = [TensorAccessor(a), TensorAccessor(b)]
self._set_depth()
self._sanity_check(a, b)
output_shape = self._infer_shapes(a, b)
self._extract_epilogue_alignment(output_shape)
output = Tensor(output_shape, src_ops={self}, dtype=a.dtype())
self._attrs["outputs"] = [output]
self._attrs["output_accessors"] = [TensorAccessor(output)]
return output
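# Illustrative usage sketch (gemm_rcr is one concrete subclass; shapes and dtype
# are hypothetical):
#   A = Tensor(shape=[256, 128], dtype="float16")  # row-major [M, K]
#   B = Tensor(shape=[512, 128], dtype="float16")  # column-major [N, K]
#   C = gemm_rcr()(A, B)                           # output shape [M, N] = [256, 512]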
def _profiler_results_groupby_key(instance):
return (
instance[1]["name"], # unique op name
instance[2], # profiler executable
instance[3], # profiler key (gemm shape)
)
def _profiler_group_reduce_min_key(group):
return group[0][1] # elapsed runtime
class GemmProfilerPostprocessingDelegate:
"""
Object which collects profiler results after profiler executables complete,
updates profiler results cache and the gemm nodes' attrs after all profilers complete.
"""
def __init__(self):
"""
Initialize storage for profiler results
Instance=(
ProfileResult=(best_algo, elapsed_runtime, workspace),
func_attrs,
profiler_filename,
exec_key,
split_k,
)
"""
self._instances = []
def add_instance(self, instance: ProfileResult):
"""
As a profiler executable completes, collect the result
"""
self._instances.append(instance)
def postprocess_results(self):
"""
When all profiler executables complete, find the best instance
(min runtime per op name, profiler executable and exec_key (i.e. gemm shape mnk)
across multiple split_k values)
The best instance is cached, and written into corresponding gemm nodes in the graph
"""
target = backend.target.Target.current()
for _, group in itertools.groupby(
self._instances,
key=_profiler_results_groupby_key,
):
min_runtime_results = min(group, key=_profiler_group_reduce_min_key)
(
(best_algo, runtime, workspace),
func_attrs,
profiler_filename,
exec_key,
split_k,
) = min_runtime_results
func_attrs["exec_path"][exec_key].algo = best_algo
func_attrs["workspace"] = max(func_attrs["workspace"], workspace)
func_attrs["split_k"] = split_k
_LOGGER.info(
f"Profiler ({profiler_filename} {exec_key}) selected kernel: "
f"{best_algo=} {workspace=} {split_k=}",
)
tmp_op = next(iter(func_attrs["op_instance"].values()))
exec_entry_sha1 = sha1(exec_key.encode("utf-8")).hexdigest()
cache_record = GemmRecordEntry(
exec_entry=exec_key,
exec_entry_sha1=exec_entry_sha1,
# 1 is subtracted from the type enum values for consistency with the existing
# cache databases; due to the "void" type being added to the DataType enum as
# the very first enum member (and shifting the values of other enum members) in
# https://github.com/NVIDIA/cutlass/commit/7c04f954151f606e60608061e891785fba229ae2
dtype_a=tmp_op.A.element.value - 1,
dtype_b=tmp_op.B.element.value - 1,
dtype_c=tmp_op.C.element.value - 1,
dtype_acc=tmp_op.accumulator_type().value - 1,
major_a=tmp_op.A.layout.value,
major_b=tmp_op.B.layout.value,
major_c=tmp_op.C.layout.value,
op_type=func_attrs["op"],
epilogue=tmp_op.epilogue_functor.value,
device=target._arch,
algo=best_algo,
workspace=workspace,
split_k=split_k,
pshape=func_attrs["permute_shape"],
)
try:
target.insert_profile_cache("gemm", cache_record.__dict__)
except Exception as e:
_LOGGER.warning(e)