Source code for aitemplate.compiler.ops.tensor.full

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#

from typing import List, Tuple, Union

from aitemplate import backend
from aitemplate.backend import registry
from aitemplate.compiler.base import IntImm, IntVar, Operator, Tensor
from aitemplate.compiler.dtype import get_dtype_size


class full(Operator):
    """
    Creates a tensor of a given `shape` and `dtype` filled with the specified
    `fill_value`.

    Args:
        shape (int or IntVar or List/Tuple of int or IntVar): the shape of the
            output Tensor. A single int or IntVar is treated as a 1-D shape;
            a tuple is converted to a list.
        fill_value (int or float): the value to fill the output Tensor with.
            It is converted to a Python float before being stored.
        dtype (str): the dtype of the output Tensor. Defaults to "float16".

    Returns:
        Tensor: a tensor of `shape` and `dtype` filled with `fill_value`.

    Raises:
        TypeError: if `shape` is not an int, IntVar, list or tuple, or if
            `fill_value` is not an int or float.
    """

    def __init__(self) -> None:
        super().__init__()
        self._attrs["op"] = "full"
        # This op has no profiler.
        self._attrs["has_profiler"] = False

    def __call__(
        self,
        shape: Union[
            int,
            IntVar,
            List[Union[int, IntVar]],
            Tuple[Union[int, IntVar], ...],
        ],
        fill_value: Union[int, float],
        dtype: str = "float16",
    ) -> Tensor:
        # Promote a single dimension to a 1-D shape.
        if isinstance(shape, (int, IntVar)):
            shape = [shape]
        if not isinstance(shape, (list, tuple)):
            raise TypeError(f"shape must be List[IntVar], but got {shape}.")
        shape = list(shape)
        # Only shapes made entirely of ints / IntImm are eligible for constant
        # folding; any other dimension (e.g. a dynamic IntVar) disables it below.
        static_shape = all(isinstance(dim, (int, IntImm)) for dim in shape)

        # Note: bool is an int subclass, so True/False are accepted here too.
        if not isinstance(fill_value, (int, float)):
            raise TypeError(f"fill_value must be a scalar, but got {fill_value}.")
        fill_value = float(fill_value)

        # Called only for its dtype validation; the returned size is unused.
        get_dtype_size(dtype)

        self._attrs["inputs"] = []
        self._attrs["fill_value"] = fill_value
        # although not used downstream, these attrs
        # are necessary to avoid erroneously deduping
        # legitimately different fill op instances
        self._attrs["shape"] = shape
        self._attrs["dtype"] = dtype
        self._set_depth()

        output = Tensor(
            shape, src_ops={self}, dtype=dtype, skip_constant_folding=not static_shape
        )
        self._attrs["outputs"] = [output]
        return output

    def gen_function(self) -> str:
        """
        Generate backend code for this op.

        Looks up the function registered under
        "<target name>.<op name>.gen_function" for the current target and
        calls it with this op's attrs.

        Returns:
            str: whatever the registered generator returns.
        """
        target = backend.target.Target.current()
        func_key = f"{target.name()}.{self._attrs['op']}.gen_function"
        func = registry.get(func_key)
        return func(self._attrs)