Source code for aitemplate.compiler.ops.tensor.where

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#


from aitemplate import backend
from aitemplate.backend import registry
from aitemplate.compiler.base import Operator, Tensor
from aitemplate.compiler.dtype import normalize_dtype


class where(Operator):
    """Select elements from one of two operands according to a bool mask.

    Return a tensor of elements selected from either ``input_tensor`` or
    ``other_tensor``, depending on ``condition``.

    Parameters:
        condition (A bool Tensor): When True (nonzero), yield input,
            otherwise yield other.
        input_tensor (Tensor or Scalar): value (if input is a scalar) or
            values selected at indices where condition is True.
        other_tensor (Tensor or Scalar): value (if other is a scalar) or
            values selected at indices where condition is False.
        dtype: output dtype; required when both ``input_tensor`` and
            ``other_tensor`` are scalars, and not consulted otherwise.

    Returns:
        Tensor: A tensor of shape equal to the shape of condition.
    """

    def __init__(self) -> None:
        """Create the operator and record its op name as ``"where"``."""
        super().__init__()
        self._attrs["op"] = "where"

    def __call__(
        self,
        condition: Tensor,
        input_tensor: Tensor,
        other_tensor: Tensor,
        dtype: str = "",
    ) -> Tensor:
        """Wire the operator up and return its output tensor.

        Args:
            condition: Bool tensor; its shape is used as the output shape.
            input_tensor: Tensor with the same shape as ``condition``, or a
                Python ``int``/``float`` scalar.
            other_tensor: Tensor with the same shape as ``condition``, or a
                Python ``int``/``float`` scalar.
            dtype: Output dtype, used only when both operands are scalars.

        Returns:
            A new Tensor with the shape of ``condition`` and the dtype shared
            by the tensor operands (or ``dtype`` when there are none).

        Raises:
            AssertionError: If ``condition`` is not a bool Tensor, an operand
                is neither a Tensor nor an int/float, a tensor operand's
                shape differs from ``condition``'s, the tensor operands
                disagree on dtype, or both operands are scalars and ``dtype``
                is empty.
        """
        assert isinstance(
            condition, Tensor
        ), f"condition needs to be a tensor, but got {type(condition)}"
        assert (
            condition.dtype() == "bool"
        ), f"condition needs to be a bool tensor, but got {condition.dtype()}"

        output_shape = condition.shape()
        operands = [input_tensor, other_tensor]

        # First pass: validate the tensor operands and settle on one dtype.
        # Scalars are skipped here so that they are only wrapped once the
        # dtype is known, instead of being built with dtype=None whenever a
        # scalar comes before the first tensor operand.
        inputs = []
        common_dtype = None
        for operand in operands:
            if isinstance(operand, (int, float)):
                continue
            assert isinstance(
                operand, Tensor
            ), f"Unsupported data type: {type(operand)}"
            assert (
                operand.shape() == output_shape
            ), f"Tensor shape should be the same, {operand.shape()} != {output_shape}"
            operand_dtype = normalize_dtype(operand.dtype())
            if common_dtype is None:
                common_dtype = operand_dtype
            else:
                assert (
                    common_dtype == operand_dtype
                ), f"Expect tensor of the same dtype, got {common_dtype} and {operand_dtype}"
            inputs.append(operand)

        # With no tensor operand there is nothing to infer the dtype from,
        # so the caller has to supply it explicitly.
        if not inputs:
            assert dtype != "", "dtype needs to be provided for scalars"
            common_dtype = normalize_dtype(dtype)

        # Second pass: build the argument list in call order, wrapping each
        # scalar in a shape-[] Tensor that carries the resolved dtype.
        args = [
            Tensor(shape=[], value=operand, dtype=common_dtype)
            if isinstance(operand, (int, float))
            else operand
            for operand in operands
        ]
        # Every argument (tensor operands included) is stamped with the
        # normalized common dtype.
        for arg in args:
            arg._attrs["dtype"] = common_dtype

        # "args" holds every operand (scalars wrapped); "inputs" holds only
        # the operands that were passed in as Tensors.
        self._attrs["args"] = [condition, *args]
        self._attrs["inputs"] = [condition, *inputs]
        self._set_depth()

        output = Tensor(
            shape=output_shape,
            src_ops={self},
            dtype=common_dtype,
        )
        self._attrs["outputs"] = [output]
        return output

    def gen_function(self) -> str:
        """Call the current target's registered ``gen_function`` for this op.

        The registry key is ``"<target name>.<op>.gen_function"``; the
        registered callable receives this operator's ``_attrs`` and its
        result is returned unchanged.
        """
        target = backend.target.Target.current()
        func_key = f"{target.name()}.{self._attrs['op']}.gen_function"
        func = registry.get(func_key)
        return func(self._attrs)