Source code for aitemplate.compiler.ops.layernorm.batch_layernorm_sigmoid_mul

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
"""
x: [b, m, n]
gamma: [b, n]
beta: [b, n]
"""
from typing import List

from aitemplate.compiler.base import IntImm
from aitemplate.compiler.ops.layernorm.layernorm import layernorm

# pylint: disable=C0103,W0221,W0102,W0223


class batch_layernorm_sigmoid_mul(layernorm):
    """batch_layernorm_sigmoid_mul op.

    This op expects normalized_shape to be 1-D.
    """

    def __init__(self, normalized_shape: List[IntImm] = None) -> None:
        super().__init__(normalized_shape)
        self._attrs["op"] = "batch_layernorm_sigmoid_mul"

    def _sanity_check(self, x, gamma, beta):
        input_len = len(self._attrs["inputs"])
        if input_len < 1 or input_len > 4:
            raise NotImplementedError(
                f"Expect 1 ~ 4 inputs for Layernorm. Actual #inputs: {input_len}"
            )
        (x_shape, gamma_shape, beta_shape) = layernorm.get_input_shapes(x, gamma, beta)
        if len(x_shape) != 3:
            raise NotImplementedError(
                f"Layernorm input must be a 3-d tensor, current shape: {x_shape}"
            )
        if gamma_shape is not None and len(gamma_shape) != 2:
            raise NotImplementedError(
                f"Layernorm gamma must be a 2-d matrix, current shape: {gamma_shape}"
            )
        if beta_shape is not None and len(beta_shape) != 2:
            raise NotImplementedError(
                f"Layernorm beta must be a 2-d matrix, current shape: {beta_shape}"
            )

        # Expected shapes:
        #   x: [b, m, n]
        #   gamma: [b, n]
        #   beta: [b, n]
        if gamma_shape is not None:
            # gamma must match x in both the batch dim (b) and the last dim (n).
            if x_shape[2] != gamma_shape[1] or x_shape[0] != gamma_shape[0]:
                raise RuntimeError(
                    f"Layernorm inputs mismatch! x shape: {x_shape}, "
                    f"gamma shape: {gamma_shape}"
                )
        if beta_shape is not None:
            # beta must match x in both the batch dim (b) and the last dim (n).
            if x_shape[2] != beta_shape[1] or x_shape[0] != beta_shape[0]:
                raise RuntimeError(
                    f"Layernorm inputs mismatch! x shape: {x_shape}, "
                    f"beta shape: {beta_shape}"
                )

        normalized_shape = self._attrs["normalized_shape"]
        if len(normalized_shape) != 1:
            raise NotImplementedError(
                f"Layernorm normalized_shape length must be 1. "
                f"Current normalized_shape: {normalized_shape}"
            )
        # The single normalized dim must equal the last dim of x.
        if normalized_shape[0]._attrs["values"] != x_shape[2]._attrs["values"]:
            raise RuntimeError(
                f"Layernorm normalized shape is not compatible with input shape. "
                f"Normalized shape: {normalized_shape}, input shape: {x_shape}"
            )
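
This file only defines the op node and its shape checks; the fused kernel itself is generated by the backend. As a rough reference for the semantics the op name suggests (layernorm over the last dim, then sigmoid, then an elementwise multiply with the original input), here is a NumPy sketch. The eps default and the exact fusion semantics are assumptions, not confirmed by this file:

import numpy as np

def batch_layernorm_sigmoid_mul_ref(x, gamma, beta, eps=1e-5):
    """Presumed semantics: sigmoid(layernorm(x)) * x, computed per batch.

    x: [b, m, n], gamma: [b, n], beta: [b, n]. eps is an assumed default.
    """
    mean = x.mean(axis=-1, keepdims=True)           # [b, m, 1]
    var = x.var(axis=-1, keepdims=True)             # [b, m, 1]
    ln = (x - mean) / np.sqrt(var + eps)            # normalize over n
    ln = ln * gamma[:, None, :] + beta[:, None, :]  # per-batch affine
    return (1.0 / (1.0 + np.exp(-ln))) * x          # sigmoid, then mul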
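
And a minimal construction sketch against the shape contract enforced by _sanity_check. The Tensor import path and the __call__ signature are inherited from the parent layernorm op and may differ across AITemplate versions, so treat this as illustrative:

from aitemplate.compiler.base import IntImm, Tensor
from aitemplate.compiler.ops.layernorm.batch_layernorm_sigmoid_mul import (
    batch_layernorm_sigmoid_mul,
)

# Shapes follow the module docstring: x: [b, m, n], gamma/beta: [b, n].
b, m, n = 2, 512, 768
x = Tensor(shape=[b, m, n], name="x", is_input=True)
gamma = Tensor(shape=[b, n], name="gamma", is_input=True)
beta = Tensor(shape=[b, n], name="beta", is_input=True)

# normalized_shape must be 1-D and match the last dim of x,
# otherwise _sanity_check raises.
op = batch_layernorm_sigmoid_mul(normalized_shape=[IntImm(n)])
y = op(x, gamma, beta)  # y keeps x's shape: [b, m, n]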