# Source code for aitemplate.compiler.ops.tensor.argmax

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
"""
Argmax.
"""
import itertools
import logging
import os
import re
from collections import OrderedDict
from operator import itemgetter
from typing import List

import jinja2
import numpy as np

from aitemplate import backend
from aitemplate.backend import registry
from aitemplate.compiler.base import Operator, Tensor

# pylint: disable=C0103,W0221,W0102,W0223


# Module-level logger; the ``argmax`` op below logs its profiling commands
# and profiled workloads through it.
_LOGGER = logging.getLogger(__name__)

# Template for ``argmax`` profiling exec keys. ``x_dim0`` is filled with the
# size of the last input dimension and ``x_dim1`` with the number of
# instances (total elements // last dimension) by ``argmax._gen_exec_key``,
# which also strips newlines from the rendered text. The numbers are later
# recovered from the key with a ``\d+`` regex (``argmax._invert_exec_key``),
# so the order x_dim0, x_dim1 matters.
EXEC_KEY_TEMPLATE = jinja2.Template(
    """
instance_size == {{x_dim0}} &&  instance_num == {{x_dim1}}
"""
)


class argmax(Operator):
    """
    Returns the indices of the maximum value of all elements across a
    dimension in the input tensor. If there are multiple maximal values then
    the indices of the first maximal value are returned.

    Args:
        input (Tensor): the source tensor
        dim (int): optional, the dimension to reduce. Default: 0

    Returns:
        Tensor: a long tensor that contains the indices of the maximum values

    Note:
        Within this class ``dim`` is only recorded in ``_attrs["dim"]``.
        Shape inference drops the *last* input dimension, and the profiling
        exec keys treat the last dimension as the reduced one, regardless of
        ``dim``. NOTE(review): confirm the backend kernel honours ``dim``
        before passing anything other than the last axis.
    """

    def __init__(self, dim=0) -> None:
        """Initialize the op.

        Parameters
        ----------
        dim : int, optional
            Dimension to reduce; stored as ``_attrs["dim"]``. Default: 0.
        """
        super().__init__()
        self._attrs["op"] = "argmax"
        # This op is always profiled: ``profile`` determines the workspace
        # size. (Assigned before "dim" to keep the _attrs key order stable.)
        self._attrs["has_profiler"] = True
        self._attrs["dim"] = dim
        # Updated by ``profile`` to the largest workspace over all workloads.
        self._attrs["workspace"] = 0
        self.exec_key_template = EXEC_KEY_TEMPLATE

    def _infer_shapes(self, x: Tensor):
        """Infer the output shape: the input shape without its last dim."""
        return x._attrs["shape"][:-1]

    def __call__(self, x: Tensor) -> Tensor:
        """Call the op.

        Parameters
        ----------
        x : Tensor
            input tensor

        Returns
        ----------
        Tensor
            An int64 tensor whose shape is ``x``'s shape minus its last
            dimension.
        """
        self._attrs["inputs"] = [x]
        self._set_depth()
        output_shape = self._infer_shapes(x)
        # Record one exec-path entry per concrete input shape for profiling.
        self._extract_exec_path(x)
        output = Tensor(output_shape, src_ops={self}, dtype="int64")
        self._attrs["outputs"] = [output]
        return output

    def gen_function(self) -> str:
        """Generate the kernel function through the backend registry.

        Looks up ``"<target>.<op>.gen_function"`` for the current target and
        calls it with this op's ``_attrs``.
        """
        target = backend.target.Target.current()
        func_key = "{target}.{op}.gen_function".format(
            target=target.name(), op=self._attrs["op"]
        )
        func = registry.get(func_key)
        return func(self._attrs)

    def gen_profiler(self, workdir: str = None, dynamic_profiling_strategy=None):
        """Generate the profiler through the backend registry.

        Parameters
        ----------
        workdir : str, optional
            Directory forwarded to the backend profiler generator.
        dynamic_profiling_strategy : optional
            Accepted for interface compatibility; not used by this op.

        Returns
        -------
        The value returned by the registered
        ``"<target>.<op>.gen_profiler"`` function.
        """
        target = backend.target.Target.current()
        func_key = "{target}.{op}.gen_profiler".format(
            target=target.name(), op=self._attrs["op"]
        )
        func = registry.get(func_key)
        return func(self._attrs, workdir)

    def _gen_exec_key(self, shape: List[int]):
        """Render the exec key for one concrete input shape.

        The last dimension is the ``instance_size``; the product of all
        remaining dimensions is the ``instance_num``. Newlines from the
        template are removed.
        """
        elem_cnt = np.prod(shape)
        instance_size = shape[-1]
        instance_num = elem_cnt // instance_size
        return self.exec_key_template.render(
            x_dim0=instance_size, x_dim1=instance_num
        ).replace("\n", "")

    def _extract_exec_path(self, x: Tensor):
        """Fill ``_attrs["exec_path"]`` with one key per input shape.

        Every combination of the candidate values of each shape dimension
        (``var._attrs["values"]``) yields one exec key. The mapped values are
        empty strings.
        """
        x_shape_values = [var._attrs["values"] for var in x._attrs["shape"]]
        self._attrs["exec_path"] = OrderedDict(
            (self._gen_exec_key(x_shape), "")
            for x_shape in itertools.product(*x_shape_values)
        )

    def _invert_exec_key(self, key):
        """Recover the integers (instance_size, instance_num) from an exec key."""
        return [int(num) for num in re.findall(r"(\d+)", key)]

    def _gen_profile_cmd(self, profiler_prefix, cfg, x_shape):
        """Build the profiler command line for one workload.

        Parameters
        ----------
        profiler_prefix : str
            Directory containing the profiler executable.
        cfg : str
            Name of the profiler executable inside ``profiler_prefix``.
        x_shape : list
            ``[instance_size, instance_num]`` as returned by
            ``_invert_exec_key``.

        Raises
        ------
        RuntimeError
            If the profiler executable is not executable.
        """
        exe_path = os.path.join(profiler_prefix, cfg)
        if not os.access(exe_path, os.X_OK):
            raise RuntimeError("Profiler %s is not executable" % exe_path)
        command = [str(arg) for arg in (exe_path, x_shape[0], x_shape[1])]
        _LOGGER.info("profiling cmd: %s", command)
        return command

    def _profile_single_workload(self, profiler_prefix, exec_key, devices):
        """Run the profiler for one exec key and return its workspace size.

        Raises
        ------
        RuntimeError
            If the profiler runner returns no results.
        """
        runner = backend.profiler_runner.Runner(devices, self._attrs["name"])
        cfg = self._attrs["op"]
        x_shape = self._invert_exec_key(exec_key)
        command = self._gen_profile_cmd(profiler_prefix, cfg, x_shape)
        runner.push(cfg, command)
        runner.join()
        result = runner.pull()

        if not result:
            raise RuntimeError(
                "Profile workload: " f"{exec_key}" " failed. " f"Results: {result}."
            )

        # NOTE(review): assumes each result entry is (cfg, profile_result) and
        # that profile_result objects are orderable; confirm against Runner.
        out = min(result, key=itemgetter(1))
        return out[1].workspace

    def profile(
        self,
        workdir="./",
        devices=None,
        dynamic_profiling_strategy=None,
    ):
        """Get the Argmax Op workspace.

        Profiles every exec key recorded in ``_attrs["exec_path"]`` and stores
        the largest workspace seen in ``_attrs["workspace"]``.

        Parameters
        ----------
        workdir : str, optional
            Base dir to keep profiling source codes, by default "./"
        devices: list, optional
            Devices used for profiling, by default device 0 will be used.
        dynamic_profiling_strategy: DynamicProfileStrategy, optional
            Accepted for interface compatibility; this op ignores it and
            profiles every recorded workload.
        """
        if devices is None:
            devices = [0]
        # Snapshot the keys before profiling.
        workloads = list(self._attrs["exec_path"].keys())
        profiler_prefix = os.path.join(workdir, "profiler", self._attrs["op"])
        for wkl in workloads:
            _LOGGER.info("Profile: %s: %s", self._attrs["name"], wkl)
            workspace = self._profile_single_workload(profiler_prefix, wkl, devices)
            self._attrs["workspace"] = max(self._attrs["workspace"], workspace)