# Source code for aitemplate.compiler.dtype

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
"""
dtype definitions and utility functions of AITemplate
"""


# Size in bytes of one element of each supported dtype string.
# "float" and "int" are aliases that carry the same size as "float32" and
# "int32" (normalize_dtype maps them to those names).
# get_dtype_size embeds these keys, in insertion order, in its error message.
_DTYPE2BYTE = {
    "bool": 1,
    "float16": 2,
    "float32": 4,
    "float": 4,  # alias of "float32"
    "int": 4,  # alias of "int32"
    "int32": 4,
    "int64": 8,
    "bfloat16": 2,
}


# Maps dtype strings to AITemplateDtype enum in model_interface.h.
# Must be kept in sync!
# We can consider defining an AITemplateDtype enum to use on the Python
# side at some point, but stick to strings for now to keep things consistent
# with other Python APIs.
#
# "float"/"float32" and "int"/"int32" are alias pairs and share one value.
# The trailing comments give the enumerator name that dtype_to_enumerator()
# emits for the same dtype string.
# NOTE(review): the numeric values cannot be checked against
# model_interface.h from this file -- confirm when editing.
_DTYPE_TO_ENUM = {
    "float16": 1,  # kHalf
    "float32": 2,  # kFloat
    "float": 2,  # kFloat (alias of "float32")
    "int": 3,  # kInt (alias of "int32")
    "int32": 3,  # kInt
    "int64": 4,  # kLong
    "bool": 5,  # kBool
    "bfloat16": 6,  # kBFloat16
}


def get_dtype_size(dtype: str) -> int:
    """Look up how many bytes one element of the given dtype occupies.

    Parameters
    ----------
    dtype : str
        A data type string; must be one of the keys of ``_DTYPE2BYTE``.

    Returns
    -------
    int
        Size (in bytes) of this dtype.

    Raises
    ------
    KeyError
        If ``dtype`` is not a key of ``_DTYPE2BYTE``.
    """
    # Single lookup: no entry in the size table is None, so a None result
    # can only mean that the key is absent.
    num_bytes = _DTYPE2BYTE.get(dtype)
    if num_bytes is None:
        raise KeyError(f"Unknown dtype: {dtype}. Expected one of {_DTYPE2BYTE.keys()}")
    return num_bytes
def normalize_dtype(dtype: str) -> str:
    """Map a dtype alias to its explicit-width spelling.

    ``"int"`` becomes ``"int32"`` and ``"float"`` becomes ``"float32"``;
    any other value is handed back untouched (no validation is performed).

    Parameters
    ----------
    dtype : str
        A data type string.

    Returns
    -------
    str
        The normalized dtype string.
    """
    # Compared with ``==`` rather than looked up in a dict so that any
    # input, hashable or not, falls through to the final return.
    for alias, canonical in (("int", "int32"), ("float", "float32")):
        if dtype == alias:
            return canonical
    return dtype
def dtype_str_to_enum(dtype: str) -> int:
    """Translate a dtype string into its AITemplateDtype enum value.

    The values come from ``_DTYPE_TO_ENUM``, which mirrors the enum
    defined in model_interface.h.

    Parameters
    ----------
    dtype : str
        A data type string; must be one of the keys of ``_DTYPE_TO_ENUM``.

    Returns
    -------
    int
        The AITemplateDtype enum value.

    Raises
    ------
    ValueError
        If ``dtype`` is not a key of ``_DTYPE_TO_ENUM``.
    """
    # Single lookup: no entry in the enum table is None, so a None result
    # can only mean that the key is absent.
    enum_value = _DTYPE_TO_ENUM.get(dtype)
    if enum_value is None:
        raise ValueError(
            f"Got unsupported input dtype {dtype}! Supported dtypes are: {list(_DTYPE_TO_ENUM.keys())}"
        )
    return enum_value
def dtype_to_enumerator(dtype: str) -> str:
    """Spell out the AITemplateDtype enumerator for a dtype string.

    The result is the enumerator name (as used in model_interface.h)
    qualified with the ``AITemplateDtype::`` scope, e.g.
    ``"AITemplateDtype::kHalf"`` for ``"float16"``.

    Parameters
    ----------
    dtype : str
        A data type string.

    Returns
    -------
    str
        The AITemplateDtype enum string representation.

    Raises
    ------
    AssertionError
        If ``dtype`` is not one of the recognized dtype strings.
    """
    # Each row pairs the accepted spellings of a dtype with its enumerator.
    enumerator_table = (
        (("float16",), "kHalf"),
        (("float32", "float"), "kFloat"),
        (("int32", "int"), "kInt"),
        (("int64",), "kLong"),
        (("bool",), "kBool"),
        (("bfloat16",), "kBFloat16"),
    )
    for spellings, enumerator in enumerator_table:
        if dtype in spellings:
            return f"AITemplateDtype::{enumerator}"
    raise AssertionError(f"unknown dtype {dtype}")
def is_same_dtype(dtype1: str, dtype2: str) -> bool:
    """Tell whether two dtype strings name the same dtype.

    Both arguments are passed through ``normalize_dtype`` before being
    compared, so an alias and its normalized spelling compare as equal.

    Parameters
    ----------
    dtype1 : str
        A data type string.
    dtype2 : str
        A data type string.

    Returns
    -------
    bool
        True if dtype1 and dtype2 are the same dtype, False otherwise.
    """
    canonical_first = normalize_dtype(dtype1)
    canonical_second = normalize_dtype(dtype2)
    return canonical_first == canonical_second