# Source code for aitemplate.compiler.ops.vision_ops.nms.batched_nms

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
"""
Batched non-maximum suppression (NMS) operator.
"""

import itertools
from typing import List

import jinja2

from aitemplate import backend
from aitemplate.backend import registry
from aitemplate.compiler.base import (  # noqa
    _create_host_zero_tensor,
    IntImm,
    Operator,
    Tensor,
)
from aitemplate.utils import shape_utils

# pylint: disable=C0103,W0221,W0102,W0223

# Execution-key template for this op: renders a condition binding the input's
# first dim (x_dim0) to M and its second dim (x_dim1) to K. The leading and
# trailing newlines are part of the template text.
EXEC_KEY_TEMPLATE = jinja2.Template("\nM == {{x_dim0}} && K == {{x_dim1}}\n")


class batched_nms(Operator):
    r"""
    Performs non-maximum suppression (NMS) on the boxes according to their
    intersection-over-union (IoU) in a batched fashion.

    NMS iteratively removes lower-scoring boxes which have an IoU greater than
    ``iou_threshold`` with another (higher-scoring) box.

    Note: if multiple boxes have the exact same score and satisfy the IoU
    criterion with respect to a reference box, the selected box is not
    guaranteed to be the same for different backends.

    * :attr:`iou_threshold` identifies the intersection-over-union (IoU)
      threshold which is used to discard all overlapping boxes with
      IoU > iou_threshold. By default 0.5.

    * :attr:`keep_n` identifies the number of boxes to return, by default -1
      to return all.

    Args:
        boxes (Tensor[N, 4]): boxes to perform NMS on. They are expected to be
            in ``(x1, y1, x2, y2)`` format with ``0 <= x1 < x2`` and
            ``0 <= y1 < y2``, and have been sorted in decreasing order of
            scores.

    Returns:
        Tensor: "keep" (Tensor[N]) in which each element indicates if the
        corresponding box is removed (element=0) or not (element=1).
    """

    def __init__(self, iou_threshold=0.5, keep_n=-1) -> None:
        """Op initialization.

        Parameters
        ----------
        iou_threshold : float
            IoU threshold above which overlapping boxes are suppressed.
            Stored in ``self._attrs["iou_threshold"]``.
        keep_n : int
            Number of boxes to keep (-1 keeps all). Stored in
            ``self._attrs["keep_n"]``.
        """
        super().__init__()
        self._attrs["op"] = "batched_nms"
        # No profiling step is generated for this op.
        self._attrs["has_profiler"] = False
        self._attrs["keep_n"] = keep_n
        self._attrs["iou_threshold"] = iou_threshold
        self.exec_key_template = EXEC_KEY_TEMPLATE

    def _infer_shape(self, x: List[int]):
        """Infer the output shape for one concrete input shape.

        The output holds one element per input box, i.e. ``[x[0]]``.
        """
        return [x[0]]

    def _infer_shapes(self, x: Tensor):
        """Infer the (possibly dynamic) output shape of ``x``.

        Runs :meth:`_infer_shape` on every combination of the candidate values
        of the input's dims, then builds one int var per output dim from the
        sorted distinct values that dim can take.
        """
        x_shape_values = [var._attrs["values"] for var in x._attrs["shape"]]
        y_shapes = [
            self._infer_shape(x_shape)
            for x_shape in itertools.product(*x_shape_values)
        ]
        # zip(*y_shapes) transposes the per-combination shapes into
        # per-output-dim tuples of candidate values.
        return [
            shape_utils.gen_int_var(values=sorted(set(dim_values)))
            for dim_values in zip(*y_shapes)
        ]

    def __call__(self, x: Tensor) -> Tensor:
        """Attach this op to the boxes tensor ``x`` and return its output.

        Parameters
        ----------
        x : Tensor
            Boxes to run NMS on (documented above as ``[N, 4]``).

        Returns
        -------
        Tensor
            An ``int64`` tensor whose shape is ``[N]`` (first dim of ``x``).

        Side effects: besides ``x``, a zero-initialized ``int64`` scratch
        tensor created by ``_create_host_zero_tensor`` is appended to this
        op's inputs.
        """
        self._attrs["inputs"] = [x]
        self._set_depth()
        output_shape = self._infer_shapes(x)
        output = Tensor(output_shape, src_ops={self}, dtype="int64")

        # Size the scratch buffer for the LARGEST box count the first dim can
        # take. For a dynamic dim, values[0] is only the lower bound, which
        # would under-allocate when more boxes are fed at runtime; for a
        # static dim, ``values`` has a single entry so max() is unchanged.
        boxes_num = max(x._attrs["shape"][0]._attrs["values"])
        # Exact integer ceil(boxes_num / 64); 64 presumably matches the bit
        # width of each int64 mask word in the scratch buffer.
        col_blocks = (boxes_num + 64 - 1) // 64
        tmp_space = col_blocks * boxes_num
        tmp_c = _create_host_zero_tensor(
            [IntImm(tmp_space)], dst_ops={self}, dtype="int64"
        )
        self._attrs["inputs"].append(tmp_c)
        self._attrs["outputs"] = [output]
        return output

    def _get_op_attributes(self):
        """Return the constructor arguments needed to recreate this op."""
        return {
            "iou_threshold": self._attrs["iou_threshold"],
            "keep_n": self._attrs["keep_n"],
        }

    def gen_function(self) -> str:
        """Generate backend code via the registered ``<target>.batched_nms.gen_function``."""
        target = backend.target.Target.current()
        func_key = "{target}.{op}.gen_function".format(
            target=target.name(), op=self._attrs["op"]
        )
        func = registry.get(func_key)
        return func(self._attrs)