Source code for aitemplate.compiler.ops.padding.ndhwc3to8

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
"""
Common NDHWC3to8 padding op
"""

import itertools
from typing import List

import jinja2

from aitemplate import backend
from aitemplate.backend import registry
from aitemplate.compiler.base import Operator, Tensor
from aitemplate.utils import shape_utils

# pylint: disable=C0103,W0221


# Jinja2 template for the text that stores the computed output extents into
# the output dimension names supplied by the caller.
#
# Placeholders: `indent` (prefix of every line) and `y_dim0`..`y_dim3` (the
# left-hand-side names). The right-hand sides NO/DO/HO/WO are the names
# defined by the text rendered from SHAPE_FUNC_TEMPLATE, so that text has to
# be emitted first. Only four assignments are produced: nothing is assigned
# from CO here.
#
# The string literal is runtime text; keep it byte-identical when editing.
SHAPE_ASSIGNMENT_TEMPLATE = jinja2.Template(
    """
{{indent}}{{y_dim0}} = NO;
{{indent}}{{y_dim1}} = DO;
{{indent}}{{y_dim2}} = HO;
{{indent}}{{y_dim3}} = WO;
"""
)

# Jinja2 template for the shape computation of the op.
#
# Placeholders: `indent` (prefix of every line), `dtype` (text placed directly
# in front of each variable name) and `x_dim0`..`x_dim3` (the first four input
# extents). The rendered text binds NI/DI/HI/WI to those extents, copies them
# unchanged to NO/DO/HO/WO, and sets CO to the constant 8.
#
# The same template is used in two ways in this file: rendered with empty
# `indent`/`dtype` it is executed as Python by `ndhwc3to8._infer_shape`
# (the trailing `;` is a valid Python statement separator), and it is also
# handed to the backend `gen_function`. `_infer_shape` additionally passes
# `x_dim4`, which this template does not reference.
#
# The string literal is runtime text; keep it byte-identical when editing.
SHAPE_FUNC_TEMPLATE = jinja2.Template(
    """
{{indent}}{{dtype}}NI = {{x_dim0}};
{{indent}}{{dtype}}DI = {{x_dim1}};
{{indent}}{{dtype}}HI = {{x_dim2}};
{{indent}}{{dtype}}WI = {{x_dim3}};
{{indent}}{{dtype}}NO = NI;
{{indent}}{{dtype}}DO = DI;
{{indent}}{{dtype}}HO = HI;
{{indent}}{{dtype}}WO = WI;
{{indent}}{{dtype}}CO = 8;
"""
)


class ndhwc3to8(Operator):
    """
    Pad the 3-channel input data to 8-channel.

    Calling an instance on a tensor registers that tensor as the op's input
    and returns a new tensor whose first four dimensions take the same values
    as the input's and whose fifth dimension is 8 (see SHAPE_FUNC_TEMPLATE).
    The input shape must have at least five dimensions; only the first five
    are read.
    """

    def __init__(self):
        super().__init__()
        self._attrs["op"] = "ndhwc3to8"
        # Template that computes the output extents; also exec'd in Python by
        # _infer_shape.
        self.shape_eval_template = SHAPE_FUNC_TEMPLATE
        # Template that stores the computed extents into the output dims.
        self.shape_save_template = SHAPE_ASSIGNMENT_TEMPLATE

    def _infer_shape(self, x: List[int]):
        """Compute the output shape for one concrete input shape.

        Parameters
        ----------
        x : List[int]
            One concrete value per input dimension; indices 0..4 are read.

        Returns
        -------
        List[int]
            The five output extents, in the order NO, DO, HO, WO, CO as
            defined by the rendered shape template.

        Raises
        ------
        IndexError
            If `x` has fewer than five entries.
        """
        eval_func = self.shape_eval_template.render(
            indent="",
            dtype="",
            x_dim0=x[0],
            x_dim1=x[1],
            x_dim2=x[2],
            x_dim3=x[3],
            x_dim4=x[4],
        )
        # With empty indent/dtype the rendered text is a sequence of plain
        # `NAME = value;` statements, which is valid Python. Executing it
        # keeps the Python-side shape inference in lockstep with the text
        # handed to the backend.
        # NOTE(review): the dimension values are interpolated verbatim into
        # code that is exec'd; this assumes they are trusted integers taken
        # from the graph's shape variables -- confirm against callers.
        output = {}
        exec(eval_func, output)  # noqa: P204
        return [int(output[name]) for name in ("NO", "DO", "HO", "WO", "CO")]

    def _infer_shapes(self, x: Tensor):
        """Compute the output shape for every combination of input dim values.

        Each input dimension contributes the list stored under its
        `_attrs["values"]`; the output shape is inferred for the Cartesian
        product of those lists, and for each of the five output axes the
        sorted distinct extents are passed to `shape_utils.gen_int_var`.

        Parameters
        ----------
        x : Tensor
            The input tensor; its `_attrs["shape"]` is read.

        Returns
        -------
        list
            Five objects produced by `shape_utils.gen_int_var`, one per
            output axis.
        """
        x_shape_values = [var._attrs["values"] for var in x._attrs["shape"]]
        # run infershape for each combination of candidate dim values
        y_shapes = [
            self._infer_shape(x_shape)
            for x_shape in itertools.product(*x_shape_values)
        ]
        # Index by axis (rather than zip(*y_shapes)) so that five entries are
        # produced even when y_shapes is empty.
        return [
            shape_utils.gen_int_var(sorted({y_shape[axis] for y_shape in y_shapes}))
            for axis in range(5)
        ]

    def __call__(self, x: Tensor) -> Tensor:
        """Apply the op to `x` and return the output tensor.

        Records `x` as the op's only input, infers the output shape, creates
        the output tensor with the same dtype as `x` and with this op as its
        source, and records it as the op's only output.

        Parameters
        ----------
        x : Tensor
            The input tensor.

        Returns
        -------
        Tensor
            The single output tensor (not a list).
        """
        self._attrs["inputs"] = [x]
        self._set_depth()
        output_shape = self._infer_shapes(x)
        output = Tensor(output_shape, src_ops={self}, dtype=x.dtype())
        self._attrs["outputs"] = [output]
        return output

    def _get_op_attributes(self):
        """Return the attributes describing this op instance.

        Returns
        -------
        dict
            `"padded_channels"`: the text after the last `"to"` in the op
            name (the string `"8"` for `"ndhwc3to8"`), and
            `"shape_func_template"`: the shape evaluation template.
        """
        return {
            "padded_channels": self._attrs["op"].split("to")[-1],
            "shape_func_template": self.shape_eval_template,
        }

    def gen_function(self) -> str:
        """Generate the function source for the current backend target.

        Looks up the callable registered under
        `"<target name>.<op name>.gen_function"` and calls it with this op's
        attributes, the target's template path and the two shape templates.

        Returns
        -------
        str
            Whatever the registered backend function returns.
        """
        target = backend.target.Target.current()
        template_path = target.template_path()
        func_key = "{target}.{op}.gen_function".format(
            target=target.name(), op=self._attrs["op"]
        )
        func = registry.get(func_key)
        return func(
            self._attrs,
            template_path,
            self.shape_eval_template,
            self.shape_save_template,
        )