Source code for aitemplate.compiler.ops.jagged.jagged_lengths_to_offsets

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
"""
Define jagged_lengths_to_offsets op
"""

from typing import List

from aitemplate.backend import registry
from aitemplate.backend.target import Target
from aitemplate.compiler.base import IntVar, Operator, Tensor


[docs]class jagged_lengths_to_offsets(Operator): """ Given a 1D Tensor of lengths of the sequences in a jagged Tensor, returns the corresponding 1D Tensor of offsets. The latter is the inclusive sum of the lengths prepended by a zero. Args: lengths (Tensor): 1D Tensor of sequence lengths, [B]-shaped. Returns: offsets (Tensor): 1D Tensor of sequence offsets, [B+1]-shaped. """ def __init__(self): super().__init__() self._attrs["op"] = "jagged_lengths_to_offsets" self._attrs["has_profiler"] = False def _infer_shape(self, lengths: Tensor) -> List[IntVar]: batch_size = lengths.shape()[0] # the offsets are 1 element longer than the lengths offsets_size = IntVar( values=[ batch_size.lower_bound() + 1, batch_size.upper_bound() + 1, ] ) return [offsets_size] def __call__( self, lengths: Tensor, ) -> Tensor: if len(lengths.shape()) != 1: raise ValueError(f"The lengths Tensor must be 1D, but got {lengths=}.") if lengths._attrs["dtype"] not in ("int32", "int64"): raise ValueError( f"The lengths Tensor must be int32 or int64, but got {lengths=}." ) self._attrs["inputs"] = [lengths] self._set_depth() output_shape = self._infer_shape(lengths) offsets = Tensor( output_shape, src_ops={self}, dtype=lengths._attrs["dtype"], ) # set the workspace to empirically determined large enough value sizeof_dtype = 4 if lengths._attrs["dtype"] == "int32" else 8 self._attrs["workspace"] = max( 2**16, 16 * sizeof_dtype * offsets.shape()[0].upper_bound(), ) self._attrs["outputs"] = [offsets] return offsets
[docs] def gen_function(self) -> str: target = Target.current() func = registry.get(f"{target.name()}.{self._attrs['op']}.gen_function") return func(self._attrs)