# Source code for aitemplate.compiler.transform.fuse_conv_patterns

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
from aitemplate.compiler.ops.common import elementwise
from aitemplate.compiler.ops.common.epilogue import FuncEnum
from aitemplate.compiler.ops.conv import (
    conv2d,
    conv2d_bias,
    conv2d_bias_add,
    conv2d_bias_add_relu,
    conv2d_bias_few_channels,
    conv2d_bias_relu,
    conv2d_bias_relu_few_channels,
    conv2d_bias_sigmoid,
    transposed_conv2d,
    transposed_conv2d_bias,
    transposed_conv2d_bias_relu,
)


def get_conv2d_bias_pattern():
    """
    Return the single (pattern, replacement) entry that maps a conv2d
    followed by an elementwise ADD onto conv2d_bias.
    """
    # The conv2d attributes given here (stride, pad) are not of concern for
    # matching; they will be passed through directly.
    conv_then_add = (conv2d(stride=1, pad=0), elementwise(FuncEnum.ADD))
    return [(conv_then_add, conv2d_bias)]


def get_conv2d_bias_elementwise_patterns():
    """
    We create the patterns of fusion here.

    Each entry is in the form of (pattern, replacement):

    pattern: a tuple of chained operators which we want to match
    replacement: the op used to replace the matched pattern

    The returned list is the concatenation of the conv2d_bias patterns, the
    transposed_conv2d_bias patterns and the transposed_conv2d patterns, in
    that order.
    """
    # NOTE(review): the stride/pad arguments are presumably placeholders that
    # do not take part in matching (see the comment in
    # get_conv2d_bias_pattern) -- confirm against the pattern matcher.
    conv2d_bias_patterns = [
        (
            (
                conv2d_bias(stride=1, pad=0),
                elementwise(FuncEnum.ADD),
                elementwise(FuncEnum.RELU),
            ),
            conv2d_bias_add_relu,
        ),
        (
            (
                conv2d_bias(stride=1, pad=0),
                elementwise(FuncEnum.RELU),
            ),
            conv2d_bias_relu,
        ),
        (
            (
                conv2d_bias(stride=1, pad=0),
                elementwise(FuncEnum.SIGMOID),
            ),
            conv2d_bias_sigmoid,
        ),
    ]

    transposed_conv2d_bias_patterns = [
        (
            (
                transposed_conv2d_bias(stride=1, pad=0),
                elementwise(FuncEnum.RELU),
            ),
            transposed_conv2d_bias_relu,
        ),
    ]

    transposed_conv2d_patterns = [
        (
            (
                transposed_conv2d(stride=1, pad=0),
                elementwise(FuncEnum.ADD),
                elementwise(FuncEnum.RELU),
            ),
            transposed_conv2d_bias_relu,
        ),
        # NOTE(review): this entry repeats the one in
        # transposed_conv2d_bias_patterns above. It is kept so the returned
        # list is unchanged; confirm whether the duplicate is intentional.
        (
            (
                transposed_conv2d_bias(stride=1, pad=0),
                elementwise(FuncEnum.RELU),
            ),
            transposed_conv2d_bias_relu,
        ),
    ]

    fusion_patterns = (
        conv2d_bias_patterns
        + transposed_conv2d_bias_patterns
        + transposed_conv2d_patterns
    )

    return fusion_patterns
def get_cuda_only_conv2d_bias_elementwise_patterns():
    """
    Return additional (pattern, replacement) fusion entries.

    pattern: a tuple of chained operators which we want to match
    replacement: the op used to replace the matched pattern

    The returned list holds the conv2d_bias patterns followed by the
    transposed_conv2d patterns.

    NOTE(review): judging by the function name these entries are meant for
    the CUDA backend only -- confirm at the call site.
    """
    conv2d_bias_patterns = [
        (
            (
                conv2d_bias_few_channels(stride=1, pad=0),
                elementwise(FuncEnum.RELU),
            ),
            conv2d_bias_relu_few_channels,
        ),
        (
            (
                conv2d_bias(stride=1, pad=0),
                elementwise(FuncEnum.ADD),
            ),
            conv2d_bias_add,
        ),
    ]

    transposed_conv2d_patterns = [
        (
            (
                transposed_conv2d(stride=1, pad=0),
                elementwise(FuncEnum.ADD),
            ),
            transposed_conv2d_bias,
        ),
    ]

    return conv2d_bias_patterns + transposed_conv2d_patterns