Source code for aitemplate.frontend.nn.fpn_proposal

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
"""
FPNProposal module.
"""
import numpy as np

from aitemplate.compiler import ops
from aitemplate.compiler.base import Tensor
from aitemplate.frontend.nn.proposal import generate_shifted_anchors, Proposal


def generate_fpn_anchors(im_h, im_w, feat_strides, scales, ratios, batch_size, dtype):
    """Enumerate anchors for all levels"""
    anchors = []
    for feat_stride, scale in zip(feat_strides, scales):
        anchors.append(
            generate_shifted_anchors(
                im_h,
                im_w,
                feat_stride,
                np.array(scale, dtype="float32"),
                np.array(ratios, dtype="float32"),
                batch_size,
                dtype,
            )
        )
    return anchors


[docs]class FPNProposal(Proposal): def __init__( self, im_shape, feat_strides=(4, 8, 16, 32, 64), scales=((32,), (64,), (128,), (256,), (512,)), ratios=(0.5, 1, 2), clip_box=True, nms_on=True, rpn_pre_nms_top_n=6000, rpn_post_nms_top_n=300, iou_threshold=0.3, rpn_min_size=0, level=-1, batch_size=1, dtype="float16", ): super().__init__( im_shape, feat_strides, scales, ratios, clip_box, nms_on, rpn_pre_nms_top_n, rpn_post_nms_top_n, iou_threshold, rpn_min_size, level, generate_fpn_anchors, batch_size, dtype, )
[docs] def forward(self, *args): assert len(args) >= 1 reg_list = args[0] bbox_deltas_list = [] anchors_list = [] split_size_or_sections = [] for idx, x_bbox_deltas in enumerate(reg_list): bbox_deltas = ops.reshape()(x_bbox_deltas, [-1, 4]) anchors = Tensor( shape=self._anchors[idx].shape, name="anchors_%d" % (idx + 2) ) bbox_deltas_list.append(bbox_deltas) anchors_list.append(anchors) split_size_or_sections.append(self._anchors[idx].shape[0]) bbox_deltas_cat = ops.concatenate()(bbox_deltas_list, dim=0) anchors_cat = ops.concatenate()(anchors_list, dim=0) proposals = self.box_transform(bbox_deltas_cat, anchors_cat) proposals = ops.split()(proposals, split_size_or_sections, dim=0) if self.nms_on: scores = args[1] scores_r = ops.reshape()(scores, [1, -1]) proposals_r = ops.reshape()(proposals, [1, -1, 4]) dets = ops.nms( self.rpn_pre_nms_top_n, self.rpn_post_nms_top_n, self.iou_threshold, self.rpn_min_size, )(proposals_r, scores_r) # prepare for roi-align for mask head batch_inds = Tensor( shape=[1, self.rpn_post_nms_top_n, 1], dtype=self.dtype, name="batch_inds", value=0, ) ret = ops.reshape()(ops.concatenate()([batch_inds, dets], dim=2), [-1, 5]) return ret, dets return list(proposals)