# Source code for aitemplate.compiler.ops.tensor.jagged_to_padded_dense

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#

"""
Define jagged_to_padded_dense op
"""

import logging
from typing import List

from aitemplate.backend import registry

from aitemplate.backend.target import Target

from aitemplate.compiler.base import IntVar, Operator, Tensor

_LOGGER = logging.getLogger(__name__)


class jagged_to_padded_dense(Operator):
    """
    Expand a jagged Tensor into a padded dense Tensor.

    The first dimension of the input jagged Tensor is a JaggedIntVar made of
    one or more JaggedDims. In the output shape, each of those JaggedDims is
    replaced by a static dimension (IntImm) equal to its max_value. The
    remaining (inner) dimensions of the input are carried over unchanged.
    Output positions with no corresponding element in the jagged input are
    filled with ``padding_value``.

    Args:
        x (Tensor): the jagged input Tensor.
        padding_value (float): value written to every output element that has
            no counterpart in the jagged input Tensor.

    Returns:
        y (Tensor): the dense Tensor expanded from the jagged input x.
    """

    def __init__(
        self,
        padding_value: float = 0,
    ):
        super().__init__()
        # The op name is used to look up the backend codegen in gen_function.
        self._attrs["op"] = "jagged_to_padded_dense"
        self._attrs["padding_value"] = padding_value

    def _infer_shape(self, x: Tensor) -> List[IntVar]:
        """Build the dense output shape for jagged input ``x``.

        The leading JaggedIntVar is expanded via ``get_max_dense_shape()``;
        every dimension after it is appended as-is.
        """
        dims = x.shape()
        jagged_dim, trailing_dims = dims[0], dims[1:]
        return jagged_dim.get_max_dense_shape() + trailing_dims

    def _get_op_attributes(self):
        """Return the constructor kwargs needed to re-create this op."""
        return dict(padding_value=self._attrs["padding_value"])

    def _args_for_pseudo_code(self):
        """Return the argument strings shown in this op's pseudo-code dump."""
        pad = self._attrs["padding_value"]
        return [f"padding_value={pad}"]

    def __call__(
        self,
        x: Tensor,
    ) -> Tensor:
        """Attach ``x`` as the op's input and create its dense output Tensor.

        Raises:
            RuntimeError: if ``x`` is not a jagged Tensor.
        """
        if not x.is_jagged():
            raise RuntimeError("Input tensor x must be jagged.")

        self._attrs["inputs"] = [x]
        self._set_depth()

        dense_out = Tensor(
            self._infer_shape(x),
            src_ops={self},
            dtype=x._attrs["dtype"],
        )
        self._attrs["outputs"] = [dense_out]
        return dense_out

    def gen_function(self) -> str:
        """Generate this op's function source with the current target's backend."""
        backend_name = Target.current().name()
        codegen = registry.get(f"{backend_name}.{self._attrs['op']}.gen_function")
        return codegen(self._attrs)