# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Base class for conv2d.
"""
import itertools
import logging
import os
import re
from collections import OrderedDict
from hashlib import sha1
from operator import itemgetter
from typing import Any, Dict, List, Tuple, Union
import jinja2
from aitemplate import backend
from aitemplate.backend import registry
from aitemplate.backend.target import Target
from aitemplate.compiler.base import (
DynamicProfileStrategy,
IntImm,
IntVar,
Operator,
Tensor,
)
from aitemplate.compiler.ops.conv.cache_entry import ConvQueryEntry, ConvRecordEntry
from aitemplate.compiler.ops.conv.conv_common import (
filter_op_instances,
generate_profiler_sources,
get_profiler_filename,
)
from aitemplate.utils import alignment, environ, shape_utils
# pylint: disable=C0103,W0221,R1732,W0102,W1202,C0301,R1716
# Module-level logger shared by the profiling and cache-lookup code below.
_LOGGER = logging.getLogger(__name__)

# Shape-inference snippet: binds the input/weight dims and the conv parameters
# to short names, then derives the output dims NO/HO/WO (CO is the weight's
# first dim). `_infer_shape` renders it with dtype="" and div="//" and runs the
# result through `exec`, so the text must remain valid Python in that mode.
# `indent`, `dtype` and `div` are left to the renderer; the template is also
# handed to the profiler/function generators.
# NOTE(review): presumably those generators render it as C++ with a typed
# `dtype` prefix and "/" division -- confirm against the backend.
SHAPE_FUNC_TEMPLATE = jinja2.Template(
"""
{{indent}}{{dtype}}NI = {{x_dim0}};
{{indent}}{{dtype}}HI = {{x_dim1}};
{{indent}}{{dtype}}WI = {{x_dim2}};
{{indent}}{{dtype}}CI = {{x_dim3}};
{{indent}}{{dtype}}CO = {{w_dim0}};
{{indent}}{{dtype}}KH = {{w_dim1}};
{{indent}}{{dtype}}KW = {{w_dim2}};
{{indent}}{{dtype}}SH = {{strideh}};
{{indent}}{{dtype}}SW = {{stridew}};
{{indent}}{{dtype}}DH = {{dilateh}};
{{indent}}{{dtype}}DW = {{dilatew}};
{{indent}}{{dtype}}PH = {{padh}};
{{indent}}{{dtype}}PW = {{padw}};
{{indent}}{{dtype}}KHEff = (KH - 1) * DH + 1;
{{indent}}{{dtype}}KWEff = (KW - 1) * DW + 1;
{{indent}}{{dtype}}NO = NI;
{{indent}}{{dtype}}HO = (HI + PH + PH - KHEff) {{div}} SH + 1;
{{indent}}{{dtype}}WO = (WI + PW + PW - KWEff) {{div}} SW + 1;
"""
)
# Copies the computed output dims (NO, HO, WO, CO) into the four output shape
# targets. Stored on the op as `shape_save_template` and passed to the
# backend's gen_function.
SHAPE_ASSIGNMENT_TEMPLATE = jinja2.Template(
"""
{{indent}}{{y_dim0}} = NO;
{{indent}}{{y_dim1}} = HO;
{{indent}}{{y_dim2}} = WO;
{{indent}}{{y_dim3}} = CO;
"""
)
# Condition matching one exact input shape. `_gen_exec_key` renders it (with
# newlines stripped) to build the keys of `_attrs["exec_path"]`, and
# `_invert_exec_key` recovers the four integers from the rendered text.
EXEC_KEY_TEMPLATE = jinja2.Template(
"""
NI == {{x_dim0}} && HI == {{x_dim1}} && WI == {{x_dim2}} && CI == {{x_dim3}}
"""
)
# Condition matching a range of input shapes: inclusive lower/upper bounds for
# N, H and W and an exact value for C. Rendered by `_gen_dyn_exec_key`.
EXEC_DYN_KEY_TEMPLATE = jinja2.Template(
"""
NI >= {{x_dim0_lb}} && NI <= {{x_dim0_ub}} &&
HI >= {{x_dim1_lb}} && HI <= {{x_dim1_ub}} &&
WI >= {{x_dim2_lb}} && WI <= {{x_dim2_ub}} &&
CI == {{x_dim3}}
"""
)
# Wraps `program` in an `if (cond) { ... }` guard. Not rendered in this file;
# it is stored as `exec_cond_template` and passed to the backend's gen_function.
EXEC_COND_TEMPLATE = jinja2.Template(
"""
{{indent}}if ({{cond}}) {
{{indent}} {{program}}
{{indent}}}
"""
)
class conv2d(Operator):
    r"""
    Applies a 2D convolution on input with size (N, H, W, C_in), and produces
    output with size (N, H_out, W_out, C_out), where N is the batch size, H and
    W are the height and width of the image in pixels, and C is the number of
    channels.

    In the simplest case, the output value of the layer with input size
    :math:`(N, H, W, C_{\text{in}})` and output
    :math:`(N, H_{\text{out}}, W_{\text{out}}, C_{\text{out}})`
    can be precisely described as:

    .. math::
        \text{out}(N_i, C_{\text{out}_j}) =
        \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)

    where :math:`\star` is the valid 2D `cross-correlation`_ operator. This
    operator takes only an input and a weight tensor; there is no bias term.

    * :attr:`stride` controls the stride for the cross-correlation.

    * :attr:`pad` controls the amount of implicit zero padding added to both
      sides of the height and width dimensions.

    * :attr:`dilate` controls the spacing between the kernel points; also known
      as the à trous algorithm. It is harder to describe, but the link `here`_
      has a nice visualization of what :attr:`dilate` does.

    * :attr:`group` controls the number of blocked connections from input
      channels to output channels.

    The output spatial size follows the shape function used by this operator:

    .. math::
        H_{\text{out}} = \left\lfloor \frac{H + 2 \cdot \text{pad}_h
            - \text{dilate}_h \cdot (K_h - 1) - 1}{\text{stride}_h} \right\rfloor + 1

    and likewise for :math:`W_{\text{out}}`.

    Args:
        input: input tensor of shape :math:`(N , H , W, \text{in\_channels})`

        weight: filters of shape :math:`(\text{out\_channels} , K_h, K_w, \frac{\text{in\_channels}}{\text{groups}})`

    This operator uses "channels_last" data format. Below is an example and its
    equivalence in PyTorch:

    .. highlight:: python
    .. code-block:: python

        X = Tensor(shape=[N, H, W, C_in], dtype="float16", name="images", is_input=True)
        W = Tensor(shape=[C_out, K_h, K_w, C_in], dtype="float16", name="weight", is_input=True)
        OP = aitemplate.compiler.ops.conv2d(stride=1, pad=1, dilate=1)
        Y = OP(X, W)

    .. highlight:: python
    .. code-block:: python

        X_pt = NHWC2NCHW(X_ait)
        W_pt = NHWC2NCHW(W_ait)
        Y_pt = torch.nn.functional.conv2d(X_pt, W_pt, stride=1, padding=1, dilation=1)
        Y = NCHW2NHWC(Y_pt)

    .. _cross-correlation:
        https://en.wikipedia.org/wiki/Cross-correlation

    .. _`here`:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
    """
def __init__(self, stride, pad, dilate=1, group=1) -> None:
    """Create a conv2d operator and record its configuration.

    Parameters
    ----------
    stride : int or tuple of two ints
        Stride of the convolution. A tuple is interpreted as
        (height stride, width stride).
    pad : int or tuple of two ints
        Padding added to the input. A tuple is interpreted as
        (height padding, width padding).
    dilate : int or tuple of two ints, optional
        Spacing between kernel elements, by default 1. A tuple is
        interpreted as (height dilation, width dilation).
    group : int, optional
        Number of blocked connections from input channels to output
        channels, by default 1.
    """
    super().__init__()
    # The first five entries are the user-facing configuration; the rest are
    # defaults that later steps (alignment extraction, profiling) overwrite.
    initial_attrs = (
        ("op", "conv2d"),
        ("stride", stride),
        ("pad", pad),
        ("dilate", dilate),
        ("group", group),
        ("has_profiler", True),
        ("epilogue_alignment", 1),
        ("epilogue", "LinearCombination"),
        ("workspace", 0),
        ("split_k", None),
    )
    for attr_name, attr_value in initial_attrs:
        self._attrs[attr_name] = attr_value

    # Templates are kept as instance attributes so they can be swapped
    # per instance without touching the module-level defaults.
    self.shape_eval_template = SHAPE_FUNC_TEMPLATE
    self.shape_save_template = SHAPE_ASSIGNMENT_TEMPLATE
    self.exec_key_template = EXEC_KEY_TEMPLATE
    self.exec_dyn_key_template = EXEC_DYN_KEY_TEMPLATE
    self.exec_cond_template = EXEC_COND_TEMPLATE
def _get_params_factory(self):
    """Return stride, pad and dilation expanded into per-axis entries.

    Each of the three attributes may be an int or a 2-tuple; it is normalized
    to (height, width) by ``_maybe_int_to_tuple`` and stored under the keys
    ``strideh``/``stridew``, ``padh``/``padw`` and ``dilateh``/``dilatew``.
    """
    stride_h, stride_w = _maybe_int_to_tuple(self._attrs["stride"], "Stride")
    pad_h, pad_w = _maybe_int_to_tuple(self._attrs["pad"], "Pad")
    dilate_h, dilate_w = _maybe_int_to_tuple(self._attrs["dilate"], "Dilation")
    return {
        "strideh": stride_h,
        "stridew": stride_w,
        "padh": pad_h,
        "padw": pad_w,
        "dilateh": dilate_h,
        "dilatew": dilate_w,
    }
def _infer_shape(self, x: List[int], w: List[int]) -> List[int]:
    """Compute the static output shape for one concrete input/weight shape.

    Parameters
    ----------
    x : List[int]
        Input shape, indexed as [N, H, W, C_in].
    w : List[int]
        Weight shape, indexed as [C_out, K_h, K_w, C_in / group].

    Returns
    -------
    List[int]
        ``[NO, HO, WO, CO]`` as evaluated by the shape template.

    Raises
    ------
    RuntimeError
        If the input channel count is not ``w[3] * group``.
    """
    expected_in_channels = w[3] * self._attrs["group"]
    if x[3] != expected_in_channels:
        raise RuntimeError("X/W Shape mismatch for conv2d")

    # Render the shape template as plain Python (no type prefix, floor
    # division) so it can be evaluated directly.
    shape_source = self.shape_eval_template.render(
        indent="",
        dtype="",
        div="//",
        x_dim0=x[0],
        x_dim1=x[1],
        x_dim2=x[2],
        x_dim3=x[3],
        w_dim0=w[0],
        w_dim1=w[1],
        w_dim2=w[2],
        **self._get_params_factory(),
    )
    namespace = {}
    exec(shape_source, namespace)  # noqa: P204
    return [int(namespace[dim_name]) for dim_name in ("NO", "HO", "WO", "CO")]
def _infer_shapes(self, x: Tensor, w: Tensor) -> List[int]:
    """Infer the output shape dims for every combination of input dim values.

    Records the weight's first three dims in ``_attrs`` as ``CO``, ``KH`` and
    ``KW``, evaluates ``_infer_shape`` once per combination of the input dims'
    candidate values, and builds the output dims from the distinct results.
    The batch dim of the output is the input's own batch dim object; the
    height and width dims also get a ``symbolic_value`` derived from the
    input's symbolic height/width.

    Parameters
    ----------
    x : Tensor
        Input tensor; each shape dim exposes ``_attrs["values"]``.
    w : Tensor
        Weight tensor; only the first value of each dim is used.

    Returns
    -------
    list
        The four output shape dims, in (N, H_out, W_out, C_out) order.
    """
    in_dims = x._attrs["shape"]
    candidate_values = [dim._attrs["values"] for dim in in_dims]
    all_x_shapes = itertools.product(*candidate_values)
    w_shape = [dim._attrs["values"][0] for dim in w._attrs["shape"]]
    self._attrs["CO"] = w_shape[0]
    self._attrs["KH"] = w_shape[1]
    self._attrs["KW"] = w_shape[2]

    # One static inference per concrete input shape.
    y_shapes = [self._infer_shape(x_shape, w_shape) for x_shape in all_x_shapes]

    def distinct_along(axis):
        return sorted({y_shape[axis] for y_shape in y_shapes})

    output_shape = [
        x._attrs["shape"][0],
        shape_utils.gen_int_var(distinct_along(1)),
        shape_utils.gen_int_var(distinct_along(2)),
        shape_utils.gen_int_var(distinct_along(3)),
    ]

    in_h = x._attrs["shape"][1]._attrs["symbolic_value"]
    in_w = x._attrs["shape"][2]._attrs["symbolic_value"]
    # Normalize the conv parameters to (val_h, val_w) before the symbolic math.
    dilate_h, dilate_w = _maybe_int_to_tuple(self._attrs["dilate"], "Dilation")
    stride_h, stride_w = _maybe_int_to_tuple(self._attrs["stride"], "Stride")
    pad_h, pad_w = _maybe_int_to_tuple(self._attrs["pad"], "Pad")

    # Effective kernel extent once dilation is taken into account.
    KHEff = (w_shape[1] - 1) * dilate_h + 1
    KWEff = (w_shape[2] - 1) * dilate_w + 1

    output_shape[1]._attrs["symbolic_value"] = (
        in_h + 2 * pad_h - KHEff
    ) // stride_h + 1
    output_shape[2]._attrs["symbolic_value"] = (
        in_w + 2 * pad_w - KWEff
    ) // stride_w + 1
    return output_shape
def _invert_exec_key(self, key):
    """Return every run of decimal digits found in ``key`` as a list of ints.

    For a key rendered by ``_gen_exec_key`` this yields the four shape values
    in the order they appear in the key.
    """
    return list(map(int, re.findall(r"(\d+)", key)))
def _gen_exec_key(self, shape: List[int]):
    """Render the exact-match execution key for a 4-element input shape.

    The key is the rendered ``exec_key_template`` with all newlines removed.
    """
    rendered = self.exec_key_template.render(
        x_dim0=shape[0],
        x_dim1=shape[1],
        x_dim2=shape[2],
        x_dim3=shape[3],
    )
    return rendered.replace("\n", "")
def _gen_dyn_exec_key(
    self, dim0_lb, dim0_ub, dim1_lb, dim1_ub, dim2_lb, dim2_ub, dim3
):
    """Render the range-based execution key.

    ``dimK_lb``/``dimK_ub`` are the lower/upper bounds substituted for the
    first three input dims and ``dim3`` is the exact value of the last dim.
    Newlines in the rendered ``exec_dyn_key_template`` are removed.
    """
    template_args = {
        "x_dim0_lb": dim0_lb,
        "x_dim0_ub": dim0_ub,
        "x_dim1_lb": dim1_lb,
        "x_dim1_ub": dim1_ub,
        "x_dim2_lb": dim2_lb,
        "x_dim2_ub": dim2_ub,
        "x_dim3": dim3,
    }
    rendered = self.exec_dyn_key_template.render(**template_args)
    return rendered.replace("\n", "")
def _extract_exec_path(self, x: Tensor):
x_shape_values = [var._attrs["values"] for var in x._attrs["shape"]]
# FIXME: we take the max height and weight for profiling at the moment.
# Let's figure out a better profiling strategy later.
# The following attribute is temporarily used to hold the lower bounds of
# all dimensions. We will remove them later once we have a better profiling
# strategy.
self._attrs["dim_lower_bounds"] = [min(vals) for vals in x_shape_values]
x_shape_values = [x_shape_values[0]] + [[max(vs)] for vs in x_shape_values[1:]]
x_shapes = itertools.product(*x_shape_values)
self._attrs["exec_path"] = OrderedDict()
for x_shape in x_shapes:
key = self._gen_exec_key(x_shape)
self._attrs["exec_path"][key] = ""
def _extract_epilogue_alignment(self, output_shape: List[IntVar]) -> None:
    """Store the maximum alignment for the output's last dimension.

    The last output dim must be an ``IntImm``; its first value and the dtype
    of the first input are passed to ``alignment.find_max_alignment`` and the
    result is saved in ``_attrs["epilogue_alignment"]``.

    Raises
    ------
    RuntimeError
        If the last output dimension is not an ``IntImm``.
    """
    last_dim = output_shape[-1]
    if isinstance(last_dim, IntImm):
        self._attrs["epilogue_alignment"] = alignment.find_max_alignment(
            number=last_dim._attrs["values"][0],
            dtype=self._attrs["inputs"][0]._attrs["dtype"],
        )
        return
    raise RuntimeError("Conv output last dimension must be static!")
def __call__(self, x: Tensor, w: Tensor) -> Tensor:
    """Apply conv2d to tensors ``x`` and ``w``.

    Registers the inputs on the operator, infers the output shape, prepares
    the execution paths used for profiling, computes the epilogue alignment
    and creates the output tensor.

    Parameters
    ----------
    x : Tensor
        in shape (N, H, W, C_in)
    w : Tensor
        in shape (C_out, K_h, K_w, C_in)

    Returns
    -------
    Tensor
        The output tensor in shape (N, H_out, W_out, C_out). A single tensor
        is returned, not a list.
    """
    self._attrs["inputs"] = [x, w]
    self._set_depth()
    output_shape = self._infer_shapes(x, w)
    self._extract_exec_path(x)
    # Reads _attrs["inputs"], so it must run after the inputs are registered.
    self._extract_epilogue_alignment(output_shape)
    output = Tensor(output_shape, src_ops={self}, dtype=x._attrs["dtype"])
    self._attrs["outputs"] = [output]
    return output
def _get_op_attributes(self) -> Dict[str, Any]:
    """Return the subset of ``dilate``, ``group``, ``pad`` and ``stride``
    that is present in ``_attrs``, in that order."""
    wanted = ("dilate", "group", "pad", "stride")
    return {name: self._attrs[name] for name in wanted if name in self._attrs}
def _should_build_profiler(self) -> bool:
    """Decide whether profiler sources must be generated and built.

    Returns ``True`` when any input dim is dynamic, when the target uses
    dummy profiling results, or when at least one workload has no usable
    cache entry. Returns ``False`` when the profiler cache is forced or when
    every workload was resolved from the cache; in the latter case
    ``_attrs["exec_path"]`` and ``_attrs["workspace"]`` are updated with the
    cached results.

    Raises
    ------
    RuntimeError
        If the cache is forced while some input dim is dynamic.
    """
    force_cache = environ.force_profiler_cache()
    if self._has_dynamic_input_dims():
        # Dynamic dims always need the profiler binaries, because dynamic
        # profiling runs them; that is incompatible with a forced cache.
        if not force_cache:
            return True
        raise RuntimeError(
            "We cannot force to use the cache as dynamic dims require "
            "us to generate and build the profilers"
        )
    if force_cache:
        # Forced cache: never build profilers.
        return False

    target = backend.target.Target.current()
    workloads = list(self._attrs["exec_path"].keys())
    if target.use_dummy_profiling_results():
        return True

    # Query the cache for every workload; a single miss forces a build, but
    # the hits found along the way are still recorded.
    tmp_key = next(iter(self._attrs["op_instance"].keys()))
    tmp_op = self._attrs["op_instance"][tmp_key]
    needs_build = False
    for wkl in workloads:
        exec_entry_sha1 = sha1(wkl.encode("utf-8")).hexdigest()
        split_k = self._attrs["split_k"]
        if split_k is None:
            split_k = 1
        query = ConvQueryEntry(
            dtype_a=tmp_op.A.element.value - 1,
            dtype_b=tmp_op.B.element.value - 1,
            dtype_c=tmp_op.C.element.value - 1,
            dtype_acc=tmp_op.accumulator_type().value - 1,
            major_a=tmp_op.A.layout.value,
            major_b=tmp_op.B.layout.value,
            major_c=tmp_op.C.layout.value,
            kh=self._attrs["KH"],
            kw=self._attrs["KW"],
            co=self._attrs["CO"],
            op_type=self._attrs["op"],
            device=target._arch,
            epilogue=tmp_op.epilogue_functor.value,
            split_k=split_k,
            exec_entry_sha1=exec_entry_sha1,
            **self._get_params_factory(),
        )
        cache_value = target.query_profile_cache("conv", query.__dict__)
        if cache_value is None or target.force_profile():
            # Cache miss (or forced re-profiling): profilers are required.
            needs_build = True
            continue
        _LOGGER.info(
            f'Load profiling result for {self._attrs["name"]} '
            f"from cache: {cache_value}",
        )
        best_algo, workspace = cache_value
        self._attrs["exec_path"][wkl] = best_algo
        self._attrs["workspace"] = max(self._attrs["workspace"], workspace)
    return needs_build
def gen_profiler(
    self,
    workdir: str = None,
    dynamic_profiling_strategy=DynamicProfileStrategy.HINTS,
):
    """Profiler generator.

    Runs the target's ``<target>.<op>.config`` function to populate the
    candidate op instances, then, if profilers need to be built, filters the
    instances for the shapes in ``_attrs["exec_path"]`` and generates the
    profiler sources.

    Parameters
    ----------
    workdir : str, optional, by default None
        Passed through to ``generate_profiler_sources``.
    dynamic_profiling_strategy: DynamicProfileStrategy, optional
        A dynamic profiling strategy, used to filter generated profiles at compile time.
        See also: :func:`~aitemplate.compiler.transform.profile.profile`
        (Not read by this implementation.)

    Returns
    -------
    The result of ``generate_profiler_sources`` when profilers are built;
    ``None`` when ``_should_build_profiler`` reports that nothing is needed.
    """
    target = backend.target.Target.current()
    func_key = "{target}.{op}.config".format(
        target=target.name(), op=self._attrs["op"]
    )
    func = registry.get(func_key)
    func(self._attrs, dtype=self._attrs["inputs"][0]._attrs["dtype"])

    if not self._should_build_profiler():
        return None

    x_shapes = [
        self._invert_exec_key(exec_key) for exec_key in self._attrs["exec_path"]
    ]
    self._attrs["op_instance"] = filter_op_instances(
        func_attrs=self._attrs,
        x_shapes=x_shapes,
    )
    return generate_profiler_sources(
        func_attrs=self._attrs,
        op_class="conv",
        workdir=workdir,
        shape_template=self.shape_eval_template,
    )
def _gen_profile_cmd(self, profiler_prefix, cfg, x_shape):
    """Build the argv list used to launch a profiler executable.

    Parameters
    ----------
    profiler_prefix : str
        Directory containing the profiler.
    cfg : str
        File name of the profiler inside ``profiler_prefix``.
    x_shape : sequence
        The four input shape values.

    Returns
    -------
    List[str]
        Executable path followed by the shape, kernel, stride/pad/dilation
        (height triple first, then width triple) and group, all as strings.

    Raises
    ------
    RuntimeError
        If the profiler file is not executable.
    """
    exe_path = os.path.join(profiler_prefix, cfg)
    if not os.access(exe_path, os.X_OK):
        raise RuntimeError("Profiler %s is not executable" % exe_path)

    params = self._get_params_factory()
    argv = [
        exe_path,
        x_shape[0],
        x_shape[1],
        x_shape[2],
        x_shape[3],
        self._attrs["KH"],
        self._attrs["KW"],
        self._attrs["CO"],
        params["strideh"],
        params["padh"],
        params["dilateh"],
        params["stridew"],
        params["padw"],
        params["dilatew"],
        self._attrs["group"],
    ]
    return list(map(str, argv))
def _profile_single_workload(self, profiler_prefix, exec_key, devices, force_cache):
    """Return ``(best_algo, workspace)`` for a single static workload.

    The profile cache is consulted first. On a usable hit the cached value is
    returned as is. Otherwise the profiler executable is run, the best result
    is selected, written to the profile cache and returned.

    Parameters
    ----------
    profiler_prefix : str
        Directory containing the built profiler executable.
    exec_key : str
        Execution key of the workload, as produced by ``_gen_exec_key``.
    devices : list
        Devices handed to the profiler runner.
    force_cache : bool
        When true, a cache miss is an error instead of triggering profiling.

    Raises
    ------
    RuntimeError
        If ``force_cache`` is set and the cache has no entry, or if the
        profiler run produced no results.
    Exception
        If the target uses dummy profiling results and no usable cache entry
        was returned.
    """
    target = backend.target.Target.current()
    # Any op instance serves as the representative for the dtype, layout and
    # epilogue fields of the cache key.
    tmp_key = next(iter(self._attrs["op_instance"].keys()))
    tmp_op = self._attrs["op_instance"][tmp_key]
    exec_entry_sha1 = sha1(exec_key.encode("utf-8")).hexdigest()
    split_k = 1 if self._attrs["split_k"] is None else self._attrs["split_k"]
    op_type = self._attrs["op"]

    # Fields shared by the cache query and the cache record, built once so
    # the two cannot drift apart.
    # NOTE(review): the `- 1` offsets presumably map library enum values onto
    # the cache's dtype codes -- confirm against the cache schema.
    entry_fields = dict(
        dtype_a=tmp_op.A.element.value - 1,
        dtype_b=tmp_op.B.element.value - 1,
        dtype_c=tmp_op.C.element.value - 1,
        dtype_acc=tmp_op.accumulator_type().value - 1,
        major_a=tmp_op.A.layout.value,
        major_b=tmp_op.B.layout.value,
        major_c=tmp_op.C.layout.value,
        kh=self._attrs["KH"],
        kw=self._attrs["KW"],
        co=self._attrs["CO"],
        op_type=op_type,
        device=target._arch,
        epilogue=tmp_op.epilogue_functor.value,
        split_k=split_k,
        exec_entry_sha1=exec_entry_sha1,
        **self._get_params_factory(),
    )

    query = ConvQueryEntry(**entry_fields)
    cache_value = target.query_profile_cache("conv", query.__dict__)
    if cache_value is not None and not target.force_profile():
        _LOGGER.info("Load profiling result from cache.")
        return cache_value
    if cache_value is None and force_cache:
        # The message fragments are concatenated into one string; passing
        # them as separate arguments would render the error as a tuple.
        raise RuntimeError(
            "force_cache is enabled but we could not find the following cache "
            f"available on device {target._arch=}, {op_type=}, {exec_entry_sha1=}"
        )
    if target.use_dummy_profiling_results():
        raise Exception(
            "This is a CI run but we could not find the following cache "
            f"available on device {target._arch}\n"
            f"{op_type} {exec_entry_sha1}.\n"
            "Please adjust target.select_minimal_algo function."
        )

    profiler_filename = get_profiler_filename(self._attrs, "conv")
    runner = backend.profiler_runner.Runner(
        devices, self._attrs["name"], timeout=180
    )
    x_shape = self._invert_exec_key(exec_key)
    command = self._gen_profile_cmd(profiler_prefix, profiler_filename, x_shape)
    runner.push(profiler_filename, command)
    runner.join()
    result = runner.pull()
    if len(result) == 0:
        raise RuntimeError(
            f"Profile workload: {exec_key} failed. Results: {result}."
        )

    # Each result is indexed so that element 1 carries the profiling outcome;
    # the minimum by that element is taken as the best candidate.
    out = min(result, key=itemgetter(1))
    best_algo = out[1].op_config
    workspace = out[1].workspace

    # Persist the outcome so later runs can skip profiling this workload.
    cache_record = ConvRecordEntry(
        exec_entry=exec_key,
        algo=best_algo,
        workspace=workspace,
        **entry_fields,  # todo: add split_k into profile
    )
    target.insert_profile_cache("conv", cache_record.__dict__)
    return (best_algo, workspace)
def _has_dynamic_input_dims(self):
    """Return True if any shape dim of any input tensor is not an ``IntImm``."""
    return any(
        not isinstance(dim, IntImm)
        for input_tensor in self._attrs["inputs"]
        for dim in input_tensor._attrs["shape"]
    )
def profile(
    self,
    workdir="./",
    devices=None,
    dynamic_profiling_strategy=DynamicProfileStrategy.HINTS,
):
    """Select the best algorithm for every workload of this operator.

    Static profiling always runs. If any input dim is dynamic, the static
    results are then merged into range-based execution paths.

    Parameters
    ----------
    workdir : str, optional
        Base directory that contains the ``profiler/<op>`` folder,
        by default "./".
    devices : list, optional
        Devices used for static profiling; ``None`` means ``[0]``.
    dynamic_profiling_strategy : DynamicProfileStrategy, optional
        Only ``DynamicProfileStrategy.HINTS`` is supported.

    Raises
    ------
    NotImplementedError
        If an input dim is dynamic and the strategy is not ``HINTS``. Note
        that static profiling has already run by the time this is raised.
    """
    if devices is None:
        devices = [0]

    self._profile_static(workdir, devices)

    if not self._has_dynamic_input_dims():
        return
    if dynamic_profiling_strategy != DynamicProfileStrategy.HINTS:
        raise NotImplementedError(
            "conv2d only supports HINTS dynamic profiling strategy for now! Current strategy: {}".format(
                dynamic_profiling_strategy
            )
        )
    self._profile_dynamic_dim(workdir)
def _profile_static(self, workdir, devices):
    """Fill in an algorithm for every static workload in ``exec_path``.

    If no op instances are present yet, the target's
    ``<target>.<op>.config`` function is run first. Each workload then either
    gets the target's minimal algorithm (dummy-profiling mode without a
    forced cache) or, if it is still unresolved, the result of
    ``_profile_single_workload``.
    """
    workloads = list(self._attrs["exec_path"].keys())
    profiler_prefix = os.path.join(workdir, "profiler", self._attrs["op"])
    target = backend.target.Target.current()

    if "op_instance" not in self._attrs:
        # Candidate ops have not been initialised yet.
        config_key = "{target}.{op}.config".format(
            target=target.name(), op=self._attrs["op"]
        )
        configure = registry.get(config_key)
        configure(self._attrs, dtype=self._attrs["inputs"][0]._attrs["dtype"])

    force_cache = environ.force_profiler_cache()
    exec_path = self._attrs["exec_path"]
    for wkl in workloads:
        _LOGGER.info(
            "Profile: {name}: {wkl}".format(name=self._attrs["name"], wkl=wkl),
        )
        if target.use_dummy_profiling_results() and not force_cache:
            # In CI just choose a minimal config; the workspace size is a
            # hack that simply provides 102400 bytes.
            algo = target.select_minimal_algo(
                list(self._attrs["op_instance"].keys())
            )
            _LOGGER.info(f"Select minimal algo {algo} for CI")
            exec_path[wkl] = algo
            self._attrs["workspace"] = 102400
            continue
        if exec_path[wkl] != "":
            # Already resolved (e.g. loaded from the profile cache).
            continue
        best_algo, workspace = self._profile_single_workload(
            profiler_prefix, wkl, devices, force_cache
        )
        exec_path[wkl] = best_algo
        self._attrs["workspace"] = max(self._attrs["workspace"], workspace)
def _profile_dynamic_dim(self, workdir):
    """Replace the static exec paths with range-based (dynamic) exec paths.

    Reads the static entries of ``_attrs["exec_path"]`` (exec key -> algo)
    together with ``_attrs["dim_lower_bounds"]`` and overwrites
    ``_attrs["exec_path"]`` with keys produced by ``_gen_dyn_exec_key``.

    * If there is a single exec path, or all paths selected the same algo,
      one range key covering all batch values is emitted.
    * Otherwise, for each pair of neighbouring batch values a binary search
      over the batch dim runs the profiler at intermediate batch sizes to
      find where to switch from the lower algo to the upper algo.

    Nothing is changed when there are no exec paths, or when several algos
    exist but the target uses dummy profiling results.

    Parameters
    ----------
    workdir : str
        Base directory that contains the ``profiler/<op>`` folder.
    """

    # extract dynamic dim from exec_path
    def _extract_dynamic_dim(exec_keys):
        # var_dims[i] collects the i-th integer of every exec key, i.e. one
        # list of observed values per input dim.
        var_dims = [[], [], [], []]
        for key in exec_keys:
            dims = self._invert_exec_key(key)
            for i, v in enumerate(dims):
                var_dims[i].append(v)
        return var_dims

    dim_lbs = self._attrs["dim_lower_bounds"]
    dims = _extract_dynamic_dim(self._attrs["exec_path"].keys())
    dim0_lb = dim_lbs[0]
    dim1_lb = dim_lbs[1]
    dim2_lb = dim_lbs[2]
    # dims' upper bounds are the same except the batch dimension
    dim1_ub = dims[1][0]
    dim2_ub = dims[2][0]
    dim3 = dims[3][0]
    num_exec_path = len(self._attrs["exec_path"])
    if num_exec_path < 1:
        return
    algos = list(self._attrs["exec_path"].values())
    if num_exec_path == 1 or len(set(algos)) <= 1:
        # all exec paths point to the same algo
        new_exec_paths = OrderedDict()
        # Because we have a single algo, it's safe to just take the upper
        # bound of dim0 (i.e. batch dim) values.
        dim0_ub = max(dims[0])
        # we need to generate new exec paths that ensure the ranges of
        # likely dynamic heights and widths
        new_key = self._gen_dyn_exec_key(
            dim0_lb, dim0_ub, dim1_lb, dim1_ub, dim2_lb, dim2_ub, dim3
        )
        new_exec_paths[new_key] = algos[0]
        self._attrs["exec_path"] = new_exec_paths
        return

    target = backend.target.Target.current()
    if target.use_dummy_profiling_results():
        # The static exec paths are left untouched in this mode.
        return

    profiler_prefix = os.path.join(workdir, "profiler", self._attrs["op"])
    # NOTE(review): the runner is hard-wired to device list [0] here, unlike
    # static profiling which receives `devices` -- confirm this is intended.
    runner = backend.profiler_runner.Runner([0], self._attrs["name"])
    # generate region
    # Each region pairs two neighbouring batch values with their algos; this
    # relies on exec_path keys and values iterating in the same order.
    regions = []  # lb, ub, lb_algos, ub_algos
    for i in range(len(dims[0]) - 1):
        regions.append([dims[0][i], dims[0][i + 1], algos[i], algos[i + 1]])
    # for each region,
    #   binary search to find cutting point
    #   generate new exec
    special_cases = OrderedDict()
    new_exec_paths = OrderedDict()
    for lb, ub, lb_algo, ub_algo in regions:
        mid = (lb + ub) // 2
        origin_lb = lb
        origin_ub = ub
        # last_mid keeps the last batch size examined; it becomes the boundary
        # between the two emitted ranges.
        last_mid = mid
        while mid > lb and mid < ub:
            mid = (lb + ub) // 2
            # Height/width are fixed at their upper bounds while searching.
            mid_shape = [mid, dim1_ub, dim2_ub, dim3]
            _LOGGER.info(
                "current: lb_algo: {lb_algo}, LB:{lb} MID:{mid} UB:{ub}".format(
                    lb_algo=lb_algo, lb=lb, mid=mid, ub=ub
                ),
            )
            # run the profiler binary with all ops on the mid_shape
            # and fetch the results only for the lb_algo and ub_algo
            profiler_filename = get_profiler_filename(self._attrs, "conv")
            profiler_cmd = self._gen_profile_cmd(
                profiler_prefix, profiler_filename, mid_shape
            )
            runner.push(
                idx=profiler_filename,
                cmd=profiler_cmd,
                return_ops=[str(lb_algo), str(ub_algo)],
            )
            runner.join()
            result = runner.pull()
            # Index the per-op results of the first pulled entry by op config.
            result_dict = {res.op_config: res for res in result[0][1]}

            assert len(result_dict) >= 1

            # if there is only one result, assume ub algo failed.
            if len(result_dict) == 1:
                assert str(ub_algo) not in result_dict
                # last_lb = lb
                lb = mid + 1
            # if there are two results, compare to decide new lb/ub
            else:
                lb_time = result_dict[str(lb_algo)].duration
                ub_time = result_dict[str(ub_algo)].duration
                if lb_time < ub_time:
                    # lb algo can work with larger batch
                    # last_lb = lb
                    lb = mid + 1
                else:
                    # ub algo can work with smaller batch
                    # last_ub = ub
                    ub = mid - 1
            last_mid = mid
            mid = (lb + ub) // 2
        # Both range keys include last_mid as a bound: the lower range ends at
        # it and the upper range starts at it.
        lo_region_key = self._gen_dyn_exec_key(
            origin_lb, last_mid, dim1_lb, dim1_ub, dim2_lb, dim2_ub, dim3
        )
        up_region_key = self._gen_dyn_exec_key(
            last_mid, origin_ub, dim1_lb, dim1_ub, dim2_lb, dim2_ub, dim3
        )
        new_exec_paths[lo_region_key] = lb_algo
        new_exec_paths[up_region_key] = ub_algo
    # find special cases
    # This code is kept in case need fully tested dynamic code
    # So far I find binary search works well.
    # def _find_special_case(lb, ub, algo):
    #     for i in range(lb + 1, ub + 1):
    #         x_shape = [i, dim1, dim2, dim3]
    #         cmd = self._gen_profile_cmd(profiler_prefix, str(algo), x_shape)
    #         runner.push(0, cmd)
    #         runner.join()
    #         out = runner.pull()
    #         if len(out) == 0:
    #             _LOGGER.info("Find specail case: batch=%d" % i)
    #             algo = self._profile_single_workload(profiler_prefix, x_shape, [0])
    #             special_cases[self._gen_exec_key(x_shape)] = algo
    # _LOGGER.info(
    #     "Searching for specail cases between [{lb}, {ub}]".format(lb=origin_lb,
    #     ub=last_mid))
    # _find_special_case(origin_lb, last_mid, lb_algo)
    # _LOGGER.info(
    #     "Searching for specail cases between [{lb}, {ub}]".format(lb=last_mid + 1,
    #     ub=origin_ub))
    # _find_special_case(last_mid, origin_ub, ub_algo)
    # special_cases is empty at this point (the search above is disabled), so
    # the final exec paths are exactly the range-based ones.
    special_cases.update(new_exec_paths)
    self._attrs["exec_path"] = special_cases
def gen_function(self) -> str:
    """Generate the function source for this operator.

    Looks up ``<target>.<op>.gen_function`` in the backend registry for the
    current target and calls it with this operator's attributes and its
    exec-condition, shape-evaluation and shape-assignment templates.

    Returns
    -------
    str
        Whatever the registered backend function returns.
    """
    target = backend.target.Target.current()
    func_key = "{target}.{op}.gen_function".format(
        target=target.name(), op=self._attrs["op"]
    )
    func = registry.get(func_key)
    return func(
        self._attrs,
        self.exec_cond_template,
        self.shape_eval_template,
        self.shape_save_template,
    )
def _maybe_int_to_tuple(x: Union[int, Tuple[int, int]], name: str) -> Tuple[int, int]:
    """Normalize a conv parameter to a ``(height, width)`` pair.

    An int is duplicated into both positions; a tuple of length 2 is returned
    unchanged. ``name`` is only used in the error message.

    Raises
    ------
    ValueError
        If ``x`` is neither an int nor a tuple of length 2.
    """
    if not isinstance(x, int):
        if isinstance(x, tuple) and len(x) == 2:
            return x
        raise ValueError(
            f"{name} should be either int or tuple of 2 ints, but got {x}"
        )
    return x, x