# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
dtype definitions and utility functions of AITemplate
"""
# Size in bytes of each supported dtype string; looked up by get_dtype_size().
# "float" and "int" are accepted as aliases and carry the same size as
# "float32" and "int32" respectively.
_DTYPE2BYTE = {
    "bool": 1,
    "float16": 2,
    "float32": 4,
    "float": 4,
    "int": 4,
    "int32": 4,
    "int64": 8,
    "bfloat16": 2,
}
# Maps dtype strings to the integer values of the AITemplateDtype enum in
# model_interface.h. Must be kept in sync with that header!
# The aliases "float" and "int" share the values of "float32" and "int32".
# dtype_str_to_enum() reads this table and rejects any dtype string that is
# not listed here.
# We can consider defining an AITemplateDtype enum to use on the Python
# side at some point, but stick to strings for now to keep things consistent
# with other Python APIs.
_DTYPE_TO_ENUM = {
    "float16": 1,
    "float32": 2,
    "float": 2,
    "int": 3,
    "int32": 3,
    "int64": 4,
    "bool": 5,
    "bfloat16": 6,
}
def get_dtype_size(dtype: str) -> int:
    """Returns size (in bytes) of the given dtype str.

    Parameters
    ----------
    dtype: str
        A data type string.

    Returns
    ----------
    int
        Size (in bytes) of this dtype.

    Raises
    ----------
    KeyError
        If dtype is not one of the keys of the _DTYPE2BYTE table.
    """
    # Check membership explicitly so the error lists the accepted dtype
    # strings instead of only echoing the missing key.
    if dtype not in _DTYPE2BYTE:
        raise KeyError(f"Unknown dtype: {dtype}. Expected one of {_DTYPE2BYTE.keys()}")
    return _DTYPE2BYTE[dtype]
def normalize_dtype(dtype: str) -> str:
    """Returns a normalized dtype str.

    The aliases "int" and "float" are mapped to "int32" and "float32";
    every other value is returned unchanged (no validation is performed).

    Parameters
    ----------
    dtype: str
        A data type string.

    Returns
    ----------
    str
        normalized dtype str.
    """
    if dtype == "int":
        return "int32"
    if dtype == "float":
        return "float32"
    # Not an alias: pass the value through untouched.
    return dtype
def dtype_str_to_enum(dtype: str) -> int:
    """Returns the AITemplateDtype enum value (defined in model_interface.h) of
    the given dtype str.

    Parameters
    ----------
    dtype: str
        A data type string.

    Returns
    ----------
    int
        the AITemplateDtype enum value.

    Raises
    ----------
    ValueError
        If dtype is not one of the keys of the _DTYPE_TO_ENUM table.
    """
    # Check membership explicitly so the error lists the supported dtypes.
    if dtype not in _DTYPE_TO_ENUM:
        raise ValueError(
            f"Got unsupported input dtype {dtype}! Supported dtypes are: {list(_DTYPE_TO_ENUM.keys())}"
        )
    return _DTYPE_TO_ENUM[dtype]
def dtype_to_enumerator(dtype: str) -> str:
    """Returns the string representation of the AITemplateDtype enum
    (defined in model_interface.h) for the given dtype str.

    Parameters
    ----------
    dtype: str
        A data type string.

    Returns
    ----------
    str
        the AITemplateDtype enum string representation, in the form
        "AITemplateDtype::<enumerator>".

    Raises
    ----------
    AssertionError
        If dtype is not one of the dtype strings handled below.
    """
    # dtype string -> enumerator name; "float" and "int" are aliases of
    # "float32" and "int32".
    enumerator_names = {
        "float16": "kHalf",
        "float32": "kFloat",
        "float": "kFloat",
        "int32": "kInt",
        "int": "kInt",
        "int64": "kLong",
        "bool": "kBool",
        "bfloat16": "kBFloat16",
    }
    try:
        enumerator = enumerator_names[dtype]
    except (KeyError, TypeError):
        # TypeError covers unhashable arguments, so any unrecognized input
        # is reported with the same AssertionError.
        raise AssertionError(f"unknown dtype {dtype}") from None
    return f"AITemplateDtype::{enumerator}"
def is_same_dtype(dtype1: str, dtype2: str) -> bool:
    """Tells whether two dtype strings name the same dtype.

    Both arguments are passed through normalize_dtype() before being
    compared, so an alias and its normalized form compare as equal.

    Parameters
    ----------
    dtype1: str
        A data type string.
    dtype2: str
        A data type string.

    Returns
    ----------
    bool
        True if the normalized forms of dtype1 and dtype2 are equal,
        False otherwise.
    """
    normalized_first = normalize_dtype(dtype1)
    normalized_second = normalize_dtype(dtype2)
    return normalized_first == normalized_second