# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Basic data types of AITemplate.
"""
from __future__ import annotations
import copy
import math
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from functools import reduce
from numbers import Number
from pprint import pformat
from typing import Any, Dict, Iterable, List, Optional, Set, Union
import numpy as np
import sympy
from aitemplate.compiler import symbolic
from aitemplate.compiler.dtype import get_dtype_size, normalize_dtype
from aitemplate.compiler.op_registry import OP_REGISTRY
from aitemplate.compiler.stable_set import StableSet
from aitemplate.utils.tensor_utils import wrap_dim
from aitemplate.utils.torch_utils import torch_dtype_to_string
# pylint: disable=C0206,W0613,C0201,W0102,W0231,W0233
class Node(ABC):
"""Base class of Tensor, Operator, etc."""
def __init__(self) -> None:
"""
Initializes self._attrs field, which is a dict that stores
all attributes for this Node.
Basic attributes include:
* name: str, name of the node.
* depth: int, depth of the node in a graph. None if this is not applicable.
* nop: bool, marks whether this node is a no-operation.
Child classes add their own attributes to this dict.
"""
super().__init__()
self._attrs: Dict[str, Any] = {"name": None, "depth": 0, "nop": False}
def __str__(self) -> str:
"""Returns a string version of this object."""
return pformat(self._attrs, indent=2, depth=2)
def __repr__(self) -> str:
"""Returns a string containing a printable representation of this object."""
return self.__str__()
    @abstractmethod
def pseudo_code(self, with_shape: bool = False) -> str:
"""Returns a string containing pseudo code of this object.
Parameters
----------
with_shape: bool
Marks whether to include shape info in the returned pseudo code.
Returns
----------
str
Pseudo code.
"""
pass
class IntVar(Node):
"""
An IntVar represents a dynamic dimension.
IntVar and IntImm (see below) are used together to represent a Tensor's shape.
IntVar supports basic arithmetic operations, and returns the most conservative
IntVar w.r.t. range of _attrs["values"].
"""
def __init__(
self,
values: List[int],
name: str = None,
symbolic_value: Optional[sympy.Basic] = None,
) -> None:
"""Initializes an IntVar.
Parameters
----------
values : List[int]
A list of possible values of this dynamic dimension.
len(values) must be >= 2.
When len(values) == 2, the values are treated as a lower bound and an upper bound.
Both upper bound and lower bound are inclusive.
This is the default use case.
When len(values) > 2, the first / last values are treated as lower / upper bounds,
            and the other values are used for internal profiling purposes.
This is a legacy use case.
name : str, optional
Name of this dimension, by default None.
This field must be set for dims which are used by input tensors.
symbolic_value: sympy.Basic, optional
The symbolic value for this IntVar. If None is provided, we will generate a symbol for this IntVar.
"""
super().__init__()
self._attrs["name"] = name
if values is None or len(values) < 2:
raise RuntimeError(
"IntVar 'values' field must have at least 2 values! values: {}, name: {}".format(
values, name
)
)
if min(values) < 0:
raise RuntimeError(
"IntVar has < 0 value! values: {}, name: {}".format(values, name)
)
self._attrs["values"] = sorted(set(values))
if len(self._attrs["values"]) == 1:
self._attrs["symbolic_value"] = self._attrs["values"][0]
self._attrs["values"] = self._attrs["values"] * 2
else:
if symbolic_value is None:
symbolic_value = symbolic.create_new_symbol(name, values)
symbolic.store_intvar(symbolic_value.name, self)
self._attrs["symbolic_value"] = symbolic_value
def __str__(self) -> str:
return pformat(self._attrs, indent=2)
def __eq__(self, another: Any) -> bool:
return (
isinstance(another, IntVar)
and self._attrs["symbolic_value"] == another._attrs["symbolic_value"]
)
def __hash__(self) -> int:
return hash(
(
self._attrs["name"],
tuple(self._attrs["values"]),
self._attrs["symbolic_value"],
)
)
def __add__(self, other: Union[Any, IntVar]) -> IntVar:
self_values = self._attrs["values"]
new_sym = self._attrs["symbolic_value"]
if isinstance(other, IntVar):
other_values = other._attrs["values"]
new_sym = new_sym + other._attrs["symbolic_value"]
elif isinstance(other, Number):
other_values = [other]
new_sym = new_sym + other
else:
raise NotImplementedError(f"Unable to do addition on {self} and {other}")
new_values = [
self_values[0] + other_values[0],
self_values[-1] + other_values[-1],
]
if new_values[0] == new_values[1]:
return IntImm(new_values[0])
return IntVar(values=new_values, symbolic_value=new_sym)
def __radd__(self, other: Union[Any, IntVar]) -> IntVar:
return self + other
def __sub__(self, other: Union[Any, IntVar]) -> IntVar:
self_values = self._attrs["values"]
new_sym = self._attrs["symbolic_value"]
if isinstance(other, IntVar):
other_values = other._attrs["values"]
new_sym = new_sym - other._attrs["symbolic_value"]
elif isinstance(other, Number):
other_values = [other]
new_sym = new_sym - other
else:
raise NotImplementedError(f"Unable to do subtraction on {self} and {other}")
new_values = [
max(0, self_values[0] - other_values[-1]),
max(0, self_values[-1] - other_values[0]),
]
if new_values[0] == new_values[1]:
return IntImm(new_values[0])
return IntVar(values=new_values, symbolic_value=new_sym)
def __rsub__(self, other: Union[Any, IntVar]) -> IntVar:
self_values = self._attrs["values"]
new_sym = self._attrs["symbolic_value"]
if isinstance(other, IntVar):
other_values = other._attrs["values"]
new_sym = other._attrs["symbolic_value"] - new_sym
elif isinstance(other, Number):
other_values = [other]
new_sym = other - new_sym
else:
raise NotImplementedError(
f"Unable to do r-subtraction on {self} and {other}"
)
new_values = [
max(0, other_values[0] - self_values[-1]),
max(0, other_values[-1] - self_values[0]),
]
if new_values[0] == new_values[1]:
return IntImm(value=new_values[0])
return IntVar(values=new_values, symbolic_value=new_sym)
def __mul__(self, other: Union[Any, IntVar]) -> IntVar:
self_values = self._attrs["values"]
new_sym = self._attrs["symbolic_value"]
if isinstance(other, IntVar):
other_values = other._attrs["values"]
new_sym = new_sym * other._attrs["symbolic_value"]
elif isinstance(other, Number):
other_values = [other]
new_sym = new_sym * other
else:
raise NotImplementedError(
f"Unable to do multiplication on {self} and {other}"
)
new_values = [
self_values[0] * other_values[0],
self_values[-1] * other_values[-1],
]
if new_values[0] == new_values[1]:
return IntImm(value=new_values[0])
return IntVar(values=new_values, symbolic_value=new_sym)
def __rmul__(self, other: Union[Any, IntVar]) -> IntVar:
return self * other
def __truediv__(self, other: Union[Any, IntVar]) -> IntVar:
self_values = self._attrs["values"]
new_sym = self._attrs["symbolic_value"]
if isinstance(other, IntVar):
other_values = other._attrs["values"]
new_sym = new_sym / other._attrs["symbolic_value"]
elif isinstance(other, Number):
other_values = [other]
new_sym = new_sym / other
else:
raise NotImplementedError(f"Unable to do division on {self} and {other}")
new_values = [
math.floor(self_values[0] / max(1, other_values[-1])),
math.ceil(self_values[-1] / max(1, other_values[0])),
]
if new_values[0] == new_values[1]:
return IntImm(value=new_values[0])
return IntVar(values=new_values, symbolic_value=new_sym)
def __rtruediv__(self, other: Union[Any, IntVar]) -> IntVar:
self_values = self._attrs["values"]
new_sym = self._attrs["symbolic_value"]
if isinstance(other, IntVar):
other_values = other._attrs["values"]
new_sym = other._attrs["symbolic_value"] / new_sym
elif isinstance(other, Number):
other_values = [other]
new_sym = other / new_sym
else:
raise NotImplementedError(f"Unable to do r-division on {self} and {other}")
new_values = [
math.floor(other_values[0] / max(1, self_values[-1])),
math.ceil(other_values[-1] / max(1, self_values[0])),
]
if new_values[0] == new_values[1]:
return IntImm(value=new_values[0])
return IntVar(values=new_values, symbolic_value=new_sym)
    def lower_bound(self) -> int:
"""Returns lower bound of this dynamic dim."""
return self._attrs["values"][0]
    def upper_bound(self) -> int:
"""Returns upper bound of this dynamic dim."""
return self._attrs["values"][-1]
    def symbolic_value(self):
"""Returns the symbolic value of this dynamic dim."""
return self._attrs["symbolic_value"]
    def pseudo_code(self, with_shape=False) -> str:
return (
self._attrs["name"]
if self._attrs["name"] is not None
else f"IntVar({str(self._attrs['values'])})"
)
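
# Illustrative sketch (hypothetical helper, not used by the compiler): IntVar
# arithmetic combines the operands' bounds conservatively, so the resulting
# range covers every combination of operand values.
def _example_intvar_arithmetic() -> IntVar:
    batch = IntVar([1, 32], name="example_batch")  # dynamic dim in [1, 32]
    doubled = batch * 2  # bounds become [2, 64]
    total = doubled + IntVar([4, 8], name="example_extra")  # bounds become [6, 72]
    assert (total.lower_bound(), total.upper_bound()) == (6, 72)
    return total
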
class IntImm(IntVar):
"""
An IntImm represents a static dimension.
IntVar (see above) and IntImm are used together to represent a Tensor's shape.
"""
def __init__(
self,
value: int,
name: str = None,
) -> None:
"""Initializes an IntImm.
Parameters
----------
value : int
Value of this static dimension.
name : str, optional
Name of this dimension, by default None.
This field must be set for dims which are used by input tensors.
"""
if not isinstance(value, int):
raise RuntimeError(
"IntImm only takes an int value! Name: {}, current value: {}".format(
name, value
)
)
Node.__init__(self) # pylint: disable=W0233
self._attrs["name"] = name
self._attrs["values"] = [value]
self._attrs["symbolic_value"] = value
def __eq__(self, another: Union[int, IntVar]) -> bool:
if isinstance(another, int):
return self.value() == another
return (
isinstance(another, IntImm)
and self._attrs["values"] == another._attrs["values"]
)
def __hash__(self) -> int:
return super().__hash__()
    def value(self) -> int:
"""Returns value of this IntImm."""
return self._attrs["values"][0]
    def pseudo_code(self, with_shape=False) -> str:
return str(self.value())
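
# Illustrative sketch: an IntImm behaves like a known constant. Arithmetic between
# static dims collapses back to an IntImm, and comparison against a plain int works.
def _example_intimm_usage() -> IntImm:
    channels = IntImm(64, name="example_channels")
    assert channels.value() == 64 and channels == 64
    halved = channels / 2  # [64] / [2] -> both bounds equal -> IntImm(32)
    assert isinstance(halved, IntImm) and halved.value() == 32
    return halved
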
class JaggedDim(Node):
"""
A class representing a single jagged dimension encoded within a JaggedIntVar.
Each instance contains the min and max value for the variable-length jagged
dimension. It is also associated with the rank-1 offsets Tensor representing
the layout of the jagged dimension within the JaggedIntVar. The offsets are
    associated with the JaggedDim instances after creation, when a jagged
    Tensor is created with the make_jagged op.
See the docstring of the JaggedIntVar class for details.
"""
def __init__(
self,
min_value: IntVar,
max_value: IntVar,
):
"""Initializes a JaggedDim.
Parameters
----------
min_value : IntVar
Minimum possible value of the jagged dimension.
max_value : IntVar
Maximum possible value of the jagged dimension.
"""
if isinstance(min_value, int):
min_value = IntImm(min_value)
if isinstance(max_value, int):
max_value = IntImm(max_value)
if min_value.lower_bound() < 0:
raise ValueError(f"{min_value=}, but must be non-negative.")
if min_value.lower_bound() > max_value.upper_bound():
raise ValueError(f"{min_value=} can't be larger than {max_value=}.")
super().__init__()
self._attrs["values"] = [min_value, max_value]
self._attrs["offsets"] = None
def __eq__(self, another: JaggedDim) -> bool:
return (
isinstance(another, JaggedDim)
and self.min_value() == another.min_value()
and self.max_value() == another.max_value()
and self.offsets() == another.offsets()
)
def __str__(self) -> str:
attrs = dict(self._attrs)
if self._attrs["offsets"] is not None:
attrs["offsets"] = {"name": self._attrs["offsets"]._attrs["name"]}
return str(attrs)
    def min_value(self) -> IntVar:
"""The minimum possible value of the JaggedDim."""
return self._attrs["values"][0]
    def max_value(self) -> IntVar:
"""The maximum possible value of the JaggedDim."""
return self._attrs["values"][1]
    def offsets(self) -> Optional[Tensor]:
"""The rank-1 offsets Tensor associated with the JaggedDim"""
return self._attrs["offsets"]
    def pseudo_code(self, with_shape=False) -> str:
return f"JaggedDim({str(self._attrs['values'])})"
class JaggedIntVar(IntVar):
"""
JaggedIntVar is a specific case of IntVar that encodes one or more jagged
dimensions within itself. JaggedIntVar is used as the first dimension in
jagged Tensors' shape (this is, basically, what makes a Tensor jagged).
E.g., a JaggedIntVar with a single JaggedDim represents a single dynamic
    dimension encoding a batch of variable-length sequences. For the batch
size of B, in some sources this is indicated as sum_B(N_B): the sum of
individual sequence lengths: N_1, N_2, ..., N_B of B sequences. This sum
is represented as a single dynamic dimension: total_length, with B being
defined by the batch_dim.
    Because JaggedIntVar is an IntVar, it can be treated as a regular IntVar by
    AIT ops that are unaware of the jagged Tensor semantics. Ops that are aware
    can interpret the JaggedIntVar as the first dimension of the jagged Tensor
    by specifically processing the underlying batch_dim and jagged_dims.
If there is more than one JaggedDim in a JaggedIntVar, those jagged dimensions
are nested within the single dynamic dimension. E.g., if there are two JaggedDims,
the JaggedIntVar represents a batch of B (batch_dim) variable-length sequences,
each in turn consisting of variable-length sequences. In principle, the nesting
can be arbitrarily deep, but in practice it's usually just a single JaggedDim.
JaggedIntVar should not be created directly. Please use the make_jagged op
for creating a jagged Tensor from a normal Tensor, the offsets, and the
metadata (like batch_dim and jagged_dims). The make_jagged op creates the
corresponding JaggedIntVar under the hood.
"""
def __init__(
self,
total_length: IntVar,
batch_dim: IntVar,
jagged_dims: List[JaggedDim],
):
"""Initializes a JaggedIntVar.
Parameters
----------
total_length : IntVar
The existing IntVar defining the total length sum_B(N_B) of the
JaggedIntVar. The "name" and "values" attributes of the JaggedIntVar
are the same as those of the total_length. This allows transparent
treatment of the jagged Tensor as dense by non-jagged-aware ops.
Must be a dynamic dim (IntVar, not IntImm).
batch_dim : IntVar
The batch dimension B in the sum_B(N_B) representation of the
JaggedIntVar. Specifies the number of (outermost) variable-length
sequences encoded within the JaggedIntVar. Must be a dynamic dim
(IntVar, not IntImm).
jagged_dims : List[JaggedDim]
            One or more jagged dimensions encoded in the JaggedIntVar. Each
JaggedDim specifies the bounds of one level of nested jaggedness
of the JaggedIntVar. See the class docstring for details.
The list must contain at least one JaggedDim. All JaggedDims
in the list must have their offsets already set to the
corresponding rank-1 Tensors.
"""
if total_length is None or type(total_length) != IntVar:
raise TypeError(
"total_length must be dynamic (IntVar), "
f"but given {type(total_length).__name__}."
)
if not jagged_dims or not all(
isinstance(dim, JaggedDim) for dim in jagged_dims
):
raise TypeError(
"jagged_dims must be a non-empty list of JaggedDims, "
f"but given {jagged_dims}."
)
offsets_types = set()
for i, dim in enumerate(jagged_dims):
if dim.offsets() is None:
raise ValueError(
f"JaggedDim {i} in the jagged_dims list has no associated offsets. "
"This probably means that the JaggedIntVar is instantiated directly. "
"Instead, jagged Tensor must be created by calling the make_jagged op."
)
else:
offsets_type = dim.offsets()._attrs["dtype"]
if offsets_type not in ["int32", "int64"]:
raise TypeError(
"The offsets Tensors can be either int32 or int64, "
f"but given the Tensor of type {offsets_type}."
)
offsets_types.add(offsets_type)
if len(offsets_types) > 1:
raise TypeError(
"All offsets Tensors must be of the same type,"
f" but given the Tensors of different types: {offsets_types}."
)
super().__init__(
values=total_length._attrs["values"],
name=total_length._attrs["name"],
symbolic_value=total_length._attrs["symbolic_value"],
)
self._attrs["batch_dim"] = batch_dim
self._attrs["jagged_dims"] = jagged_dims
self._attrs["offsets_type"] = f"{offsets_types.pop()}_t"
self._total_length = total_length
def __eq__(self, another: JaggedIntVar) -> bool:
return (
isinstance(another, JaggedIntVar)
and self.total_length() == another.total_length()
and self.batch_dim() == another.batch_dim()
and self.jagged_dims() == another.jagged_dims()
)
def __hash__(self) -> int:
return hash((self._attrs["name"], tuple(self._attrs["values"])))
    def total_length(self) -> IntVar:
"""The total_length dimension the JaggedIntVar is based on."""
return self._total_length
    def batch_dim(self) -> IntVar:
"""The batch_dim of the JaggedIntVar."""
return self._attrs["batch_dim"]
    def jagged_dims(self) -> List[JaggedDim]:
"""The jagged_dims of the JaggedIntVar."""
return self._attrs["jagged_dims"]
    def offsets_type(self) -> str:
"""The type of the offsets of the JaggedIntVar's jagged_dims."""
return self._attrs["offsets_type"]
    def offsets_var_name(self) -> str:
"""The name of the offsets struct variable in runtime."""
name = self._attrs["name"]
if name is None:
raise RuntimeError("The JaggedIntVar is not named yet")
return f"{name}_jagged_offsets"
    def offsets_struct_type(self) -> str:
"""The type of the offsets struct variable used in runtime."""
num_jagged_dims = len(self.jagged_dims())
return f"ait::JaggedOffsets<{self.offsets_type()}, {num_jagged_dims}>"
    def get_max_dense_shape(self) -> List[IntVar]:
"""
Returns a list of IntVars representing the maximum dense shape
(rectangular volume) that the JaggedIntVar can correspond to.
The result has the batch_dim as the first item and the IntImm
with the max_value of each JaggedDim that follows.
"""
result = [self.batch_dim()]
for dim in self.jagged_dims():
result.append(dim.max_value())
return result
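
# Illustrative sketch (hypothetical helper): jagged-aware code can unpack the
# metadata that the make_jagged op stores in the JaggedIntVar placed as the
# first dimension of a jagged Tensor's shape.
def _example_inspect_jagged_dim0(dim0: JaggedIntVar) -> List[IntVar]:
    # dim0.total_length() is the sum_B(N_B) dimension; dim0.batch_dim() is B.
    assert len(dim0.jagged_dims()) >= 1
    # e.g. "ait::JaggedOffsets<int32_t, 1>" for a single int32-offset jagged dim.
    _ = dim0.offsets_struct_type()
    # Densest (rectangular) shape this jagged dimension could expand to:
    # [batch_dim, IntImm(max_value of each JaggedDim), ...].
    return dim0.get_max_dense_shape()
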
def get_aligned_size(shape: List[IntVar], dtype: str, alignment: int = 64) -> int:
"""Returns aligned size (in bytes) of given shape and dtype.
Parameters
----------
shape: List[IntVar]
A list of IntVars, which represents the shape of a Tensor.
dtype: str
A data type string.
alignment: int
Alignment requirement (in bytes). Default alignment is 64 bytes.
Returns
----------
int
        Size (in bytes) of this shape with dtype, rounded up to a multiple of alignment bytes.
"""
size = reduce(lambda cur, dim: cur * dim.upper_bound(), shape, 1)
size = size * get_dtype_size(dtype)
if size % alignment != 0:
size = int((size // alignment + 1) * alignment)
return size
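
# Illustrative sketch: the size is computed from each dim's upper bound, so a
# dynamic shape reserves enough space for its worst case before rounding up.
def _example_aligned_size() -> int:
    shape = [IntVar([1, 8], name="example_dyn"), IntImm(3)]
    # Worst case: 8 * 3 float16 elements = 48 bytes, rounded up to a 64-byte multiple.
    size = get_aligned_size(shape, "float16", alignment=64)
    assert size == 64
    return size
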
class _ConstantTensorData(ABC):
"""
Represents data to be stored in a Tensor.
The data can be owned or unowned; each subclass should
implement its own setup and cleanup logic.
Note that this class is different from the blobs that are used
in the Python API (e.g. in Run(), compile_model). Those
blobs must represent unowned GPU memory; subclasses of _ConstantTensorData
may be owned or unowned, and may or may not reside in host memory.
Why is this separate class useful? During compilation, we have no way to
allocate memory on the GPU, so host memory must be used to store tensors that
we introduce (e.g. padding tensors). At the same time, when lowering PyTorch models,
we may want to store owned GPU data in the graph.
"""
def __init__(self, dtype: str):
super().__init__()
self.dtype = normalize_dtype(dtype)
@abstractmethod
def to_bytes(self) -> bytes:
"""
Converts the stored data to a byte string.
Called during codegen to save the ConstantTensor to the
.so.
"""
pass
def size(self) -> int:
"""
The number of bytes stored. Should be equal to
len(self.to_bytes()).
"""
return len(self.to_bytes())
def is_dtype(self, dtype: str) -> bool:
return normalize_dtype(dtype) == self.dtype
def __len__(self) -> int:
return self.size()
class _HostConstantTensorData(_ConstantTensorData):
"""
The simplest possible _ConstantTensorData; just a
lightweight wrapper around some host data.
"""
def __init__(self, data: bytes, dtype: str = "float16"):
super().__init__(dtype)
self.data = data
def to_bytes(self) -> bytes:
return self.data
class _TorchConstantTensorData(_ConstantTensorData):
"""
Wraps a torch.Tensor for storage in _ConstantTensorData.
"""
def __init__(self, tensor):
super().__init__(torch_dtype_to_string(tensor.dtype))
self.tensor = tensor
def to_bytes(self) -> bytes:
if self.size() == 0:
return b""
import ctypes
t = self.tensor.contiguous().cpu().detach()
# We used to do tensor().numpy().tobytes() here,
# but numpy doesn't support bfloat16 natively,
# so we obtain the underlying C array.
# Results are flaky when tensor is not bound to a local variable.
raw_array = ctypes.cast(
t.data_ptr(), ctypes.POINTER(ctypes.c_ubyte * self.size())
)
return bytes(raw_array.contents)
def size(self) -> int:
"""
Override size() to avoid D2H copy.
"""
return self.tensor.element_size() * self.tensor.nelement()
class _NumpyConstantTensorData(_ConstantTensorData):
"""
Wraps an ndarray for storage in _ConstantTensorData.
"""
def __init__(self, arr: np.ndarray):
super().__init__(str(arr.dtype))
self.arr = arr
def to_bytes(self) -> bytes:
return self.arr.tobytes()
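
# Illustrative sketch: host byte buffers and numpy arrays can both back a constant
# Tensor; the wrappers only normalize the dtype and expose the raw bytes.
def _example_constant_data() -> _ConstantTensorData:
    host = _HostConstantTensorData(b"\x00" * 8, dtype="float16")  # 4 fp16 zeros
    assert host.size() == 8 and host.is_dtype("float16")
    arr = _NumpyConstantTensorData(np.zeros((2, 2), dtype="float32"))
    assert len(arr.to_bytes()) == 2 * 2 * 4  # 4 float32 elements
    return host
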
class Tensor(Node):
"""
A Tensor represents a piece of data, which is used as an input / output of an Operator.
Both Tensor and Operator are used at model compilation stage.
"""
def __init__(
self,
shape: List[IntVar],
name: Optional[str] = None,
src_ops: Iterable[Node] = None,
dst_ops: Iterable[Node] = None,
dtype: str = "float16",
is_input: bool = False,
is_output: bool = False,
value: Any = None,
is_view_of: Any = None,
is_internal_constant: bool = False,
skip_constant_folding: bool = False,
check_nan_and_inf: bool = False,
check_outputs: bool = False,
original_name: str = None,
) -> None:
"""Initializes a Tensor.
Parameters
----------
shape : List[IntVar]
Shape of this Tensor.
name : str, optional
Name of this Tensor. By default, it's None.
src_ops : Iterable[Node], optional
Source operators of this Tensor which write to this Tensor.
By default, it's an empty set.
dst_ops : Iterable[Node], optional
Destination operators of this Tensor which take this Tensor as
one of their inputs.
By default, it's an empty set.
dtype : str, optional
            Data type of this Tensor. By default, it's "float16".
is_input : bool, optional
Whether this Tensor is an input Tensor of a graph.
Note that constant Tensors (e.g. weights) are NOT input Tensors.
is_output : bool, optional
Whether this Tensor is an output Tensor of a graph.
value : Any, optional
The value of this Tensor. When value is set and shape is an
empty list, this Tensor is used to represent a number.
        is_view_of : Any, optional
            The Tensor that this Tensor is a view of, if any.
        is_internal_constant: bool, optional
            Whether this constant tensor could be modified.
        skip_constant_folding: bool, optional
            Whether to exclude this tensor from constant folding.
        check_nan_and_inf : bool, optional
            Whether or not to check this tensor for NaN or Inf values at runtime.
        check_outputs : bool, optional
            Whether or not to print this tensor's value out at runtime.
        original_name : str, optional
            Original name of this tensor before it was renamed to be AIT-friendly.
"""
super().__init__()
self._attrs["shape"] = self._convert_shape(shape)
self._attrs["name"] = name
self._attrs["src_ops"] = StableSet(src_ops)
self._attrs["dst_ops"] = StableSet(dst_ops)
self._attrs["dtype"] = dtype
self._attrs["is_output"] = is_output
self._attrs["is_input"] = is_input
self._attrs["is_param"] = False
self._attrs["is_internal_constant"] = is_internal_constant
self._attrs["skip_constant_folding"] = skip_constant_folding
# True if this is an internal tensor that aliases an output through
# a view. Set up in mark_param_tensor
self._attrs["has_output_aliases"] = False
# For special views. When an output is a view of an input/constant/other
# output, this attribute points to that view. Note that this is not the
# same as is_view_of if the output is a view of a view. This is set up
# in the mark_param_tensor graph pass.
self._attrs["external_tensor"] = None
# link to original tensor if this tensor is a view
self._attrs["is_view_of"] = is_view_of
if is_view_of:
self._attrs["dtype"] = is_view_of._attrs["dtype"]
self._attrs["value"] = value
src_deps = [src_op._attrs["depth"] for src_op in self._attrs["src_ops"]]
self._attrs["depth"] = max(src_deps) + 1 if len(src_deps) > 0 else 0
# Offset into internal memory slab, set by memory planning
self._attrs["offset"] = None
# Data to be bound for constant folding. See _bind_data.
self._attrs["data"] = None
self._attrs["constant_folding_output_idx"] = None
self._attrs["check_nan_and_inf"] = check_nan_and_inf
self._attrs["check_outputs"] = check_outputs
self._attrs["original_name"] = original_name
def __str__(self) -> str:
output = {}
for key in self._attrs.keys():
if key in ("src_ops", "dst_ops") and self._attrs[key] is not None:
output[key] = [x._attrs["name"] for x in self._attrs[key]]
else:
output[key] = self._attrs[key]
return pformat(output, indent=2)
def _convert_shape(self, shape: List[Union[int, IntVar]]) -> List[IntVar]:
"""
Converts from a list of ints / IntVars to a list of IntVars.
"""
ret = []
for v in shape:
if isinstance(v, int):
ret.append(IntImm(v))
elif isinstance(v, IntVar):
ret.append(v)
else:
raise RuntimeError(f"Unsupported dim type: {type(v)}, dim: {v}")
return ret
    def shape(self) -> List[IntVar]:
"""
Returns the shape of the tensor.
It should not be used directly in IR.
"""
return self._attrs["shape"]
def _rank(self) -> int:
"""
Returns the rank of the tensor.
It should not be used directly in IR.
"""
return len(self._attrs["shape"])
def _size(self, dim) -> IntVar:
"""
Gets the size of tensor at dim=dim.
dim must be between [-rank, rank - 1].
It should not be used directly in IR, use ops.size(dim) instead.
"""
return self._attrs["shape"][wrap_dim(dim, self._rank())]
    def dtype(self) -> str:
"""Returns Tensor's data type str."""
return self._attrs["dtype"]
    def src_ops(self) -> Set[Operator]:
"""Returns a set of source operators which write to this Tensor."""
return self._attrs["src_ops"]
    def dst_ops(self) -> Set[Operator]:
"""Returns a set of destination operators which read from this Tensor."""
return self._attrs["dst_ops"]
    def is_a_const_num(self) -> bool:
"""Returns whether this Tensor represents a constant number."""
return len(self._attrs["shape"]) == 0 and self._attrs["value"] is not None
    def is_jagged(self) -> bool:
"""Whether the Tensor is jagged (the first dim is JaggedIntVar)."""
return len(self._attrs["shape"]) > 0 and isinstance(
self._attrs["shape"][0], JaggedIntVar
)
    def size_bytes(self, alignment: int = 1) -> int:
"""Returns actual size (in bytes) of this Tensor."""
return get_aligned_size(self._attrs["shape"], self.dtype(), alignment)
    def pseudo_code(self, with_shape=True) -> str:
name = self._attrs["name"]
if name is None:
name = "None"
args = [f"name={name}"]
if with_shape:
shapes = ", ".join([dim.pseudo_code() for dim in self._attrs["shape"]])
args.append(f"shape=[{shapes}]")
data = self._attrs["data"]
if data is not None:
args.append(f"data=({data.size()} bytes)")
if self.is_jagged():
args.append("jagged=True")
return f"Tensor({', '.join(args)})"
def _bind_data(self, data: _ConstantTensorData) -> None:
"""
Bind some data to this tensor.
- This tensor must not have any src_ops().
        - The provided data's size in bytes must match the maximum size of this tensor.
Tensors with bound data can participate in constant folding.
"""
if self.src_ops():
raise ValueError(
f"Cannot bind tensor {self._attrs['name']}; {len(self.src_ops())=} > 0"
)
dtype = self._attrs["dtype"]
if not data.is_dtype(dtype):
raise ValueError(
f"data's dtype did not match: expected {dtype}, got {data.dtype}"
)
tensor_size = self.size_bytes(alignment=1)
if tensor_size != len(data):
raise ValueError(
(
"ConstantTensor's maximum size is not equal to len(data)! "
f"Got {len(data)=}, but expected at least {tensor_size} bytes. "
"Check that the ConstantTensor's size and dtype are correct."
)
)
self._attrs["data"] = data
def __deepcopy__(self, memo):
result = Tensor(self.shape())
memo[id(self)] = result
result._attrs = copy.deepcopy(self._attrs, memo)
return result
def __add__(self, other: Any) -> Tensor:
return OP_REGISTRY.get("ADD")(self, other)
def __radd__(self, other: Any) -> Tensor:
return OP_REGISTRY.get("ADD")(other, self)
def __sub__(self, other: Any) -> Tensor:
return OP_REGISTRY.get("SUB")(self, other)
def __rsub__(self, other: Any) -> Tensor:
return OP_REGISTRY.get("SUB")(other, self)
def __mul__(self, other: Any) -> Tensor:
return OP_REGISTRY.get("MUL")(self, other)
def __rmul__(self, other: Any) -> Tensor:
return OP_REGISTRY.get("MUL")(other, self)
def __truediv__(self, other: Any) -> Tensor:
return OP_REGISTRY.get("DIV")(self, other)
def __rtruediv__(self, other: Any) -> Tensor:
return OP_REGISTRY.get("DIV")(other, self)
def __neg__(self) -> Tensor:
return OP_REGISTRY.get("MUL")(-1, self)
def _create_host_zero_tensor(
shape: List[Union[int, IntVar]],
name: str = None,
dst_ops: Set[Node] = None,
dtype: str = "float16",
is_output: bool = False,
is_internal_constant: bool = True,
):
"""
Create a zero tensor stored on the host machine.
"""
shape = [dim if isinstance(dim, IntVar) else IntImm(dim) for dim in shape]
zeros = _HostConstantTensorData(
b"\x00" * get_aligned_size(shape, dtype, alignment=1), dtype=dtype
)
tensor = Tensor(shape, name, dst_ops=dst_ops, dtype=dtype, is_output=is_output)
tensor._attrs["is_internal_constant"] = is_internal_constant
tensor._bind_data(zeros)
return tensor
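
# Illustrative sketch: a graph input with a dynamic batch dimension, next to a
# host-backed zero constant produced by the helper above.
def _example_tensors() -> Tensor:
    batch = IntVar([1, 32], name="example_bs")
    x = Tensor([batch, IntImm(16)], name="x", dtype="float16", is_input=True)
    assert x._rank() == 2 and not x.is_jagged()
    zeros = _create_host_zero_tensor([4, 8], name="zeros", dtype="float16")
    assert zeros.size_bytes() == 4 * 8 * 2  # 64 bytes of bound float16 data
    return x
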
class IntVarTensor(Tensor):
"""
A special tensor which represents an IntImm / IntVar.
This Tensor can be used as inputs of some Operators (e.g. reshape, layernorm).
    An IntVarTensor is used here instead of an IntVar to keep references to
    src_ops and dst_ops.
"""
def __init__(
self,
int_var: IntVar,
name: str = None,
src_ops: Set[Node] = None,
dst_ops: Set[Node] = None,
dtype: str = "float16",
is_input: bool = False,
is_output: bool = False,
value: Any = None,
is_view_of: Any = None,
) -> None:
"""Initializes an IntVar Tensor.
Parameters
----------
int_var: IntVar
The underlying IntVar variable.
"""
shape = []
super().__init__(
shape,
name,
src_ops,
dst_ops,
dtype=dtype,
is_input=is_input,
is_output=is_output,
)
self._attrs["int_var"] = int_var
self._attrs["symbolic_value"] = int_var._attrs["symbolic_value"]
    def pseudo_code(self, with_shape=True) -> str:
return f"IntVarTensor({self._attrs['int_var'].pseudo_code()})"
def __deepcopy__(self, memo):
result = IntVarTensor(self._attrs["int_var"])
memo[id(self)] = result
result._attrs = copy.deepcopy(self._attrs, memo)
return result
def __add__(self, other: Any) -> Tensor:
return OP_REGISTRY.get("INT_ADD")(self, other)
def __radd__(self, other: Any) -> Tensor:
return OP_REGISTRY.get("INT_ADD")(other, self)
def __sub__(self, other: Any) -> Tensor:
return OP_REGISTRY.get("INT_SUB")(self, other)
def __rsub__(self, other: Any) -> Tensor:
return OP_REGISTRY.get("INT_SUB")(other, self)
def __mul__(self, other: Any) -> Tensor:
return OP_REGISTRY.get("INT_MUL")(self, other)
def __rmul__(self, other: Any) -> Tensor:
return OP_REGISTRY.get("INT_MUL")(other, self)
def __truediv__(self, other: Any) -> Tensor:
return OP_REGISTRY.get("INT_DIV")(self, other)
def __rtruediv__(self, other: Any) -> Tensor:
return OP_REGISTRY.get("INT_DIV")(other, self)
class DynamicProfileStrategy(Enum):
"""Dynamic profiling strategy enum.
Instances are used to select profiling strategy when there are dynamic dims.
"""
# Always use an IntVar's min value to profile.
MIN = 1
# Always use an IntVar's max value to profile.
MAX = 2
# Profiling according to an IntVar's value list.
    # For testing purposes only.
HINTS = 3
@dataclass
class ExecItem:
"""A data class to store profiling info."""
profiling_key: str
exec_cond: str
algo: str
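
# Illustrative only: the strings below are made-up placeholders rather than the
# output of a real profiler run; ExecItem just records which algorithm to use
# under which runtime condition.
def _example_exec_item() -> ExecItem:
    return ExecItem(
        profiling_key="M=64",  # hypothetical profiling bucket
        exec_cond="M <= 64",  # runtime condition that selects this entry
        algo="example_algo_0",  # hypothetical kernel configuration name
    )
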
class Operator(Node):
"""Base class for all operators"""
def __init__(self) -> None:
"""Initializes the operator."""
super().__init__()
self._attrs["inputs"] = None
self._attrs["has_profiler"] = False
def __call__(self, *args: List[Tensor]) -> List[Tensor]:
"""Performs offline shape inference and constructs the model graph.
Parameters
        ----------
*args : List[Tensor]
Input tensors.
Returns
-------
List[Tensor]
Output tensors.
Raises
------
NotImplementedError
"""
raise NotImplementedError
def __deepcopy__(self, memo):
result = type(self)(**self._get_op_attributes())
memo[id(self)] = result
result._attrs = copy.deepcopy(self._attrs, memo)
return result
def _set_depth(self) -> None:
"""
Sets operator depth and dst_ops.
This function must be called by each operator subclass inside
__call__() method once self._attrs["inputs"] is set.
"""
max_depth = 0
if self._attrs["inputs"] is not None:
for inp in self._attrs["inputs"]:
max_depth = max(max_depth, inp._attrs["depth"])
inp._attrs["dst_ops"].add(self)
self._attrs["depth"] = max_depth
def __str__(self) -> str:
"""Generates a debug string."""
output = {}
for key in self._attrs.keys():
if (
key in ("inputs", "args", "outputs", "original_inputs")
and self._attrs[key] is not None
):
output[key] = [x._attrs["name"] for x in self._attrs[key]]
else:
output[key] = self._attrs[key]
return pformat(output, indent=2)
    def gen_profiler(
self, workdir: str = None, dynamic_profiling_strategy=None
) -> None:
"""Generates source files for profiling purpose.
Parameters
----------
workdir : str, optional
The directory to generate source files.
dynamic_profiling_strategy: DynamicProfileStrategy, optional
A dynamic profiling strategy, used to filter generated profiles at compile time.
See also: :func:`~aitemplate.compiler.transform.profile.profile`
"""
return
    def profile(
self,
workdir="./",
devices=None,
dynamic_profiling_strategy=DynamicProfileStrategy.MAX,
) -> None:
"""Selects the fastest kernel configurations.
Parameters
----------
workdir : str, optional
The directory which contains source files, by default "./"
devices: list, optional
A list of device ids which can be used for profiling.
dynamic_profiling_strategy: DynamicProfileStrategy, optional
Profiling strategy used when there are dynamic dims.
            By default, MAX is used, i.e., to profile a dynamic range, its upper bound will be used.
"""
return
    def gen_function(self) -> str:
"""Generates function source code string.
Returns
-------
str : a string which contains C++ function implementation source code.
Raises
------
NotImplementedError
"""
raise NotImplementedError("gen_function is not defined for {}".format(self))
# APIs below are for graph transformations.
def _get_op_attributes(self) -> Dict[str, Any]:
"""
Returns a dictionary of the core attributes of the op.
The core attributes are attributes that are required to create an op, for
        example, the FuncEnum for an elementwise op.
This is used when we need to copy the op with identical behaviour.
Returns
-------
Dict of attributes
"""
return {}
# APIs below are for pseudo code generation.
def _inputs_for_pseudo_code(self):
return self._attrs["inputs"]
def _outputs_for_pseudo_code(self):
return self._attrs["outputs"]
def _args_for_pseudo_code(self):
return [f"{key}={value}" for key, value in self._get_op_attributes().items()]
def _pseudo_code_helper(self, node: Any, with_shape: bool) -> str:
if isinstance(node, list):
if len(node) > 3 and isinstance(node[0], Tensor):
return ",\n".join(self._pseudo_code_helper(n, with_shape) for n in node)
else:
return ", ".join(self._pseudo_code_helper(n, with_shape) for n in node)
if isinstance(node, Node):
return node.pseudo_code(with_shape)
return str(node)
    def pseudo_code(self, with_shape=True):
args = self._pseudo_code_helper(self._args_for_pseudo_code(), with_shape)
inputs = self._pseudo_code_helper(self._inputs_for_pseudo_code(), with_shape)
outputs = self._pseudo_code_helper(self._outputs_for_pseudo_code(), with_shape)
name = self._attrs.get("name", None)
return f"# {name}\n({outputs}) \n= {self._attrs['op']}({args})(\n{inputs})\n"