# Source code for aitemplate.compiler.ops.tensor.relational

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#

from enum import Enum

from aitemplate import backend
from aitemplate.backend import registry
from aitemplate.compiler.base import Operator, Tensor
from aitemplate.compiler.dtype import normalize_dtype


class RelationalEnum(Enum):
    """Comparison kinds selectable by the relational operators in this module.

    Each member's value is the textual symbol of the comparison it names.
    """

    GE = ">="  # greater than or equal
    LE = "<="  # less than or equal
    LT = "<"  # strictly less than
    GT = ">"  # strictly greater than
    EQ = "=="  # equal
    NE = "!="  # not equal


class relational(Operator):
    """
    Relational operator that supports comparing a tensor to another tensor or a constant

    This base class only records the op name. The concrete comparison is chosen
    by storing a ``RelationalEnum`` member under ``self._attrs["func"]`` (done by
    the subclasses in this module); calling an instance without it fails.

    Parameters:
        left (Tensor): the tensor to compare

        right (Tensor or int or float): the tensor or scalar value to compare.
            A tensor operand must have the same dtype and the same shape as
            ``left``: neither type promotion nor broadcasting is supported.

    Returns:
        Tensor: a tensor of bool with the same shape as ``left``
    """

    def __init__(self) -> None:
        super().__init__()
        self._attrs["op"] = "relational"

    def __call__(self, left: Tensor, right: Tensor) -> Tensor:
        """Validate the operands, wire this op into the graph and return its output.

        Args:
            left: Tensor on the left-hand side of the comparison.
            right: Tensor, ``int`` or ``float`` on the right-hand side. A Python
                number is wrapped into a shape-``[]`` Tensor of ``left``'s dtype.

        Returns:
            A new ``Tensor`` of dtype ``"bool"`` shaped like ``left`` whose
            ``src_ops`` is this operator.

        Raises:
            AssertionError: if no comparison function is registered, if an
                operand has an unsupported type, or if a tensor ``right``
                differs from ``left`` in dtype or shape.
        """
        # Use .get(): the base class never sets "func", so a plain index would
        # raise KeyError instead of reaching the intended assertion message.
        assert self._attrs.get("func") is not None, "No function registered"
        assert isinstance(
            left, Tensor
        ), "Relational expects left operand to be a Tensor"
        common_dtype = normalize_dtype(left.dtype())
        # Store the normalized dtype back on the operand.
        left._attrs["dtype"] = common_dtype

        if isinstance(right, (int, float)):
            # Scalar constant: wrap it so both operands are Tensors from here on.
            right = Tensor(shape=[], value=right, dtype=common_dtype)
        else:
            assert isinstance(
                right, Tensor
            ), "Relational expects right operand to be a Tensor or constant"
            # Report the offending (right) dtype; left's dtype is, by
            # construction, the expected one.
            assert (
                normalize_dtype(right.dtype()) == common_dtype
            ), f"Type promotions are not supported; got dtype {right.dtype()}, but expected {common_dtype}"
            assert (
                left.shape() == right.shape()
            ), "Relational does not support broadcasting yet. It expects tensor of same shape."
            right._attrs["dtype"] = common_dtype

        # "args" always holds both operands; "inputs" leaves out a right operand
        # for which is_a_const_num() is true.
        self._attrs["args"] = [left, right]
        self._attrs["inputs"] = [left] if right.is_a_const_num() else [left, right]
        self._set_depth()
        output = Tensor(left.shape(), src_ops=[self], dtype="bool")
        self._attrs["outputs"] = [output]
        return output

    def gen_function(self) -> str:
        """Generate this op's function source via the current target's backend.

        Looks up ``"<target name>.<op>.gen_function"`` in the backend registry
        and calls the registered function with this operator's ``_attrs``.

        Returns:
            Whatever the registered backend function returns (annotated ``str``).
        """
        target = backend.target.Target.current()
        func_key = f"{target.name()}.{self._attrs['op']}.gen_function"
        func = registry.get(func_key)
        return func(self._attrs)


class ge(relational):
    """Relational operator whose comparison function is ``RelationalEnum.GE`` (``>=``)."""

    def __init__(self) -> None:
        super().__init__()
        # Select the comparison kind; relational.__call__ asserts that it is set.
        self._attrs["func"] = RelationalEnum.GE


class le(relational):
    """Relational operator whose comparison function is ``RelationalEnum.LE`` (``<=``)."""

    def __init__(self) -> None:
        super().__init__()
        # Select the comparison kind; relational.__call__ asserts that it is set.
        self._attrs["func"] = RelationalEnum.LE


class gt(relational):
    """Relational operator whose comparison function is ``RelationalEnum.GT`` (``>``)."""

    def __init__(self) -> None:
        super().__init__()
        # Select the comparison kind; relational.__call__ asserts that it is set.
        self._attrs["func"] = RelationalEnum.GT


class lt(relational):
    """Relational operator whose comparison function is ``RelationalEnum.LT`` (``<``)."""

    def __init__(self) -> None:
        super().__init__()
        # Select the comparison kind; relational.__call__ asserts that it is set.
        self._attrs["func"] = RelationalEnum.LT


class eq(relational):
    """Relational operator whose comparison function is ``RelationalEnum.EQ`` (``==``)."""

    def __init__(self) -> None:
        super().__init__()
        # Select the comparison kind; relational.__call__ asserts that it is set.
        self._attrs["func"] = RelationalEnum.EQ


class ne(relational):
    """Relational operator whose comparison function is ``RelationalEnum.NE`` (``!=``)."""

    def __init__(self) -> None:
        super().__init__()
        # Select the comparison kind; relational.__call__ asserts that it is set.
        self._attrs["func"] = RelationalEnum.NE