Source code for aitemplate.compiler.ops.tensor.masked_select

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
"""
Define masked_select op
"""
import logging
from typing import List

from aitemplate.backend import registry

from aitemplate.backend.target import Target

from aitemplate.compiler.base import IntVar, Operator, Tensor

from aitemplate.compiler.dtype import get_dtype_size

from aitemplate.utils import shape_utils


_LOGGER = logging.getLogger(__name__)


class masked_select(Operator):
    """
    Returns a 1D tensor containing elements of the input tensor selected by the
    boolean mask, similar to `torch.masked_select`.

    Args:
        input (Tensor): the source tensor.
        mask (Tensor, boolean): the shapes of the mask tensor and the input tensor
            do not need to match, but they must be broadcastable.

    Returns:
        output: 1D tensor of length equal to the total number of elements in the
            broadcast shape deduced from input and mask. The result is contained
            in the first `num_nonmasked` elements of output. The rest of the
            output tensor is not meaningful.
        num_nonmasked: number of non-masked elements from the input, i.e. the
            length of the significant part of output.
    """

    def __init__(self):
        super().__init__()
        self._attrs["op"] = "masked_select"
        self._attrs["workspace"] = 0
        self._attrs["max_shape"] = None

    def _infer_shape(self, x: Tensor, mask: Tensor) -> List[IntVar]:
        input_shape = x._attrs["shape"]
        mask_shape = mask._attrs["shape"]
        broadcastable, max_shape = shape_utils.get_broadcast_max_shape(
            input_shape, mask_shape
        )
        if not broadcastable:
            raise RuntimeError(
                "Tensor shapes of input and mask are not broadcastable! Shape1: {}, shape2: {}".format(
                    input_shape, mask_shape
                )
            )
        self._attrs["max_shape"] = max_shape
        numel = 1
        for dim in max_shape:
            numel *= dim.upper_bound()
        # Allocate temporary buffer. This empirical formula for its size is deduced
        # by looking at the memory needed to expand input/mask and at the buffer
        # sizes requested by cub::DeviceSelect::Flagged for different input sizes.
        input_workspace_size = (input_shape != max_shape) * get_dtype_size(
            x._attrs["dtype"]
        )
        mask_workspace_size = (mask_shape != max_shape) * 1  # bool
        self._attrs["workspace"] = (
            numel * (input_workspace_size + mask_workspace_size) + numel // 128 + 1024
        )
        _LOGGER.debug(
            f'Allocating {self._attrs["workspace"]} bytes for temporary buffer of masked_select op'
        )
        # Output size can range from 0 (when all mask elements are False) to the
        # total number of elements in the input (when all mask elements are True).
        return [IntVar(values=(0, numel))]

    def __call__(
        self,
        x: Tensor,
        mask: Tensor,
    ) -> Tensor:
        dtype = mask._attrs["dtype"]
        if dtype != "bool":
            raise RuntimeError("Expected mask of dtype bool, but got {}".format(dtype))
        self._attrs["inputs"] = [x, mask]
        self._set_depth()
        output_shape = self._infer_shape(x, mask)
        output = Tensor(output_shape, src_ops={self}, dtype=x._attrs["dtype"])
        self._attrs["outputs"] = [output]
        return output
    def gen_function(self) -> str:
        target = Target.current()
        func = registry.get(f"{target.name()}.{self._attrs['op']}.gen_function")
        return func(self._attrs)
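

# ---------------------------------------------------------------------------
# Usage sketch (not part of the upstream module): a minimal, hypothetical
# example of wiring masked_select into a graph. The tensor names, shapes, and
# dtypes below are illustrative assumptions; only the masked_select call
# signature itself comes from the class above, and the Tensor keyword
# arguments assume the usual AITemplate frontend constructor.
if __name__ == "__main__":
    x = Tensor(shape=[4, 8], dtype="float16", name="x", is_input=True)
    mask = Tensor(shape=[4, 8], dtype="bool", name="mask", is_input=True)

    op = masked_select()
    y = op(x, mask)

    # The output is a 1D tensor whose single dimension is a dynamic
    # IntVar(0, 32); only the first `num_nonmasked` elements written by the
    # backend kernel are meaningful. Since neither operand needs broadcasting
    # here, the workspace formula above reduces to numel // 128 + 1024 bytes.
    print("output shape:", y._attrs["shape"])
    print("workspace bytes:", op._attrs["workspace"])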