# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Defines the jagged_to_padded_dense op, which expands a jagged Tensor into a padded dense Tensor.
"""
import logging
from typing import List
from aitemplate.backend import registry
from aitemplate.backend.target import Target
from aitemplate.compiler.base import IntVar, Operator, Tensor
# Module-level logger named after this module.
# NOTE(review): not referenced by the jagged_to_padded_dense class in this
# file; confirm whether it is still needed.
_LOGGER = logging.getLogger(__name__)
class jagged_to_padded_dense(Operator):
    """
    Expands a jagged Tensor into a padded dense Tensor.

    The first dimension of a jagged Tensor is a JaggedIntVar that carries
    one or more jagged dimensions (JaggedDims). In the output Tensor's
    shape, each of those jagged dimensions becomes a separate static
    dimension (IntImm) equal to that jagged dimension's max_value. The
    remaining (inner) dimensions of the input are appended unchanged.
    Output elements with no counterpart in the input jagged Tensor are
    filled with ``padding_value``.

    Constructor args:
        padding_value (float): fill value for output elements that have
            no corresponding element in the input jagged Tensor.

    Call args:
        x (Tensor): the jagged input Tensor.

    Returns:
        y (Tensor): the dense Tensor obtained by expanding x.
    """

    def __init__(
        self,
        padding_value: float = 0,
    ):
        super().__init__()
        # Record the op name (used as part of the backend registry key in
        # gen_function) and the padding value in the op's attribute dict.
        self._attrs.update(
            {
                "op": "jagged_to_padded_dense",
                "padding_value": padding_value,
            }
        )

    def _infer_shape(self, x: Tensor) -> List[IntVar]:
        """Return the dense output shape for jagged input ``x``.

        The leading dimension's max dense shape comes first, followed by
        the rest of ``x``'s dimensions as they are.
        """
        dims = x.shape()
        # dims[0] is presumably the JaggedIntVar (x is checked to be jagged
        # in __call__); its max dense shape replaces it in the output.
        return dims[0].get_max_dense_shape() + dims[1:]

    def _get_op_attributes(self):
        """Return this op's attributes as a dict (only padding_value)."""
        return {"padding_value": self._attrs["padding_value"]}

    def _args_for_pseudo_code(self):
        """Return the argument strings shown in this op's pseudo-code."""
        padding_value = self._attrs["padding_value"]
        return [f"padding_value={padding_value}"]

    def __call__(
        self,
        x: Tensor,
    ) -> Tensor:
        """Attach ``x`` as this op's input and create its dense output.

        Raises:
            RuntimeError: if ``x`` is not a jagged Tensor.
        """
        if not x.is_jagged():
            raise RuntimeError("Input tensor x must be jagged.")

        self._attrs["inputs"] = [x]
        self._set_depth()
        # The output keeps the input's dtype and records this op as its source.
        output = Tensor(
            self._infer_shape(x),
            src_ops={self},
            dtype=x._attrs["dtype"],
        )
        self._attrs["outputs"] = [output]
        return output

    def gen_function(self) -> str:
        """Generate this op's code via the current target's backend.

        Looks up the function registered as
        ``"<target name>.<op>.gen_function"`` and calls it with this op's
        attribute dict.
        """
        registry_key = f"{Target.current().name()}.{self._attrs['op']}.gen_function"
        codegen = registry.get(registry_key)
        return codegen(self._attrs)