Source code for aitemplate.frontend.nn.conv1d

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
"""
Conv1d Module.
"""

from aitemplate.compiler.ops import conv2d, conv2d_bias, squeeze, unsqueeze
from aitemplate.frontend import Tensor
from aitemplate.frontend.nn.module import Module
from aitemplate.frontend.nn.parameter import Parameter


class Conv1d(Module):
    r"""
    Conv1d module applies a 1D convolution over an input signal composed of
    several input planes.

    .. math::
        \text{out}\left(B_i, \text{:}, \text{channels\_out}_j\right) =
        \text{bias}\left(\text{channels\_out}_j\right) +
        \sum_{k = 0}^{\text{channels\_in} - 1}
        \text{weight}\left(\text{channels\_out}_j, \text{:}, k\right) \star
        \text{input}\left(B_i, \text{:}, k\right)

    The semantics are similar to `PyTorch`_ with the following exception:
    dims 1 and 2 of the weight, input and output are swapped
    (while dim 0 remains the same).

    .. _PyTorch: https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html

    Args:
        in_channels: Number of channels in the input.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Length of the 1D convolution kernel.
        stride: Stride along the sequence dimension.
        padding: Padding applied along the sequence dimension.
        dilation: Dilation along the sequence dimension.
        groups: Number of groups; forwarded to the underlying op as ``group``.
        dtype: Data type of the weight (and bias) parameters.
        bias: When true, a bias parameter of shape ``[out_channels]`` is
            created and the ``conv2d_bias`` op is used instead of ``conv2d``.
        name: Prefix for the parameter names (``{name}_weight`` and
            ``{name}_bias``).

    Raises:
        ValueError: If ``in_channels`` or ``out_channels`` is not divisible
            by ``groups``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        dtype: str = "float16",
        bias: bool = False,
        name: str = "conv1d",
    ):
        super().__init__()

        # The weight shape below uses floor division, which would silently
        # drop channels when the channel count is not a multiple of `groups`.
        # Fail loudly instead of building a mis-shaped weight.
        if in_channels % groups != 0:
            raise ValueError(
                f"in_channels ({in_channels}) must be divisible by "
                f"groups ({groups})"
            )
        if out_channels % groups != 0:
            raise ValueError(
                f"out_channels ({out_channels}) must be divisible by "
                f"groups ({groups})"
            )

        # Channels-last layout: [out_channels, kernel_size, in_channels / groups].
        self.weight = Parameter(
            shape=[out_channels, kernel_size, in_channels // groups],
            dtype=dtype,
            name=f"{name}_weight",
        )
        if bias:
            self.bias = Parameter(
                shape=[out_channels], dtype=dtype, name=f"{name}_bias"
            )
        else:
            self.bias = None

        # note that conv1d is functionally equivalent to conv2d,
        # but we need to reshape the input, weight and output tensors,
        # as well as use the correct stride, padding and dilation for the conv2d op.
        # The second element of each pair addresses the dummy dimension that
        # forward() inserts, so it gets the neutral value (1 / 0 / 1).
        fwd_func = conv2d_bias if bias else conv2d
        self.op = fwd_func(
            stride=(stride, 1),
            pad=(padding, 0),
            dilate=(dilation, 1),
            group=groups,
        )

    def forward(self, x: Tensor) -> Tensor:
        r"""Applies Conv1d on the input tensor of shape
        :math:`(B, \text{seq\_in}, \text{channels\_in})`.

        The output has shape
        :math:`(B, \text{seq\_out}, \text{channels\_out})`, where

        .. math::
            \text{seq\_out} = \left\lfloor\frac{\text{seq\_in} +
            2 \times \text{padding} - \text{dilation} \times
            (\text{kernel\_size} - 1) - 1}{\text{stride}} + 1\right\rfloor

        Args:
            x: Input tensor of shape ``(B, seq_in, channels_in)``.

        Returns:
            The convolved tensor with the dummy dimension removed again.
        """
        # make the conv2d inputs 4d by inserting a size-1 dimension at index 2
        xu = unsqueeze(dim=2)(x)
        wu = unsqueeze(dim=2)(self.weight.tensor())

        if self.bias is None:
            c2d = self.op(xu, wu)
        else:
            c2d = self.op(xu, wu, self.bias.tensor())

        # make the result 3d again
        return squeeze(dim=2)(c2d)