Source code for aitemplate.compiler.ops.layernorm.group_layernorm

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
"""
Operator definition for group_layernorm.
"""
from typing import Any, List

from aitemplate.compiler.base import IntImm, IntVarTensor, Tensor
from aitemplate.compiler.ops.layernorm.layernorm import layernorm
from aitemplate.compiler.tensor_accessor import TensorAccessor
from aitemplate.utils import shape_utils

# pylint: disable=C0103,W0221,W0102,W0223


class group_layernorm(layernorm):
    """group_layernorm.

    For each group, we expect each input to have shapes:
        Input shape: [M0, M1, ..., Mp, N1, N2, ..., ND]
        Normalized_shape: [N1, N2, ..., ND]
        Gamma/Beta, if not None, have the same shape as normalized_shape.
    Every input in the groups must have the same [M0, M1, ..., Mp] dims.
    """

    def __init__(self, normalized_shape: List[List[IntImm]] = None) -> None:
        """Initialize the op.

        Parameters
        ----------
        normalized_shape : List[List[IntImm]], optional
            One normalized shape per group. The first entry (or None) is
            forwarded to the ``layernorm`` constructor; the full list is kept
            as the default used when ``__call__`` receives no
            ``normalized_shapes``.
        """
        super().__init__(normalized_shape[0] if normalized_shape is not None else None)
        self._attrs["op"] = "group_layernorm"
        self._attrs["has_profiler"] = False
        self._attrs["default_normalized_shape"] = normalized_shape

    def _sanity_check(self, all_inputs):
        """Validate the flattened ``inputs + gammas + betas`` list.

        Parameters
        ----------
        all_inputs : list
            Concatenation of inputs, gammas and betas, in that order, with
            each third having the same length.

        Raises
        ------
        NotImplementedError
            If the number of entries is not a multiple of 3.
        AssertionError
            If there are no inputs, if the group count does not match the
            number of normalized shapes, or if the leading (batch) dims
            differ between inputs.
        """
        input_len = len(all_inputs)
        if input_len % 3 != 0:
            raise NotImplementedError(
                "Expect multiples of 3 inputs for group layernorm. "
                "Actual #inputs: {}".format(input_len)
            )
        b = input_len // 3
        inputs, gammas, betas = (
            all_inputs[:b],
            all_inputs[b : 2 * b],
            all_inputs[2 * b :],
        )
        assert len(inputs) > 0
        assert (
            len(inputs)
            == len(gammas)
            == len(betas)
            == len(self._attrs["normalized_shape"])
        )
        # Per-group shape validation is delegated to the layernorm helpers.
        for x, gamma, beta, normalized_shape in zip(
            inputs, gammas, betas, self._attrs["normalized_shape"]
        ):
            (x_shape, gamma_shape, beta_shape) = layernorm.get_input_shapes(
                x, gamma, beta
            )
            layernorm.check_shapes(x_shape, gamma_shape, beta_shape, normalized_shape)

        # check x in inputs have the same batch dims (Ms), it can be dynamic
        # x shape: B * [Ms, Ns], Ns can be different
        rank = inputs[0]._rank()
        norm_len = len(self._attrs["normalized_shape"][0])
        for k in range(rank - norm_len):
            M = inputs[0].shape()[k]
            # Start at the second input so that ``i`` is the real index of the
            # offending tensor in the error message (input_0 is the reference).
            for i, x in enumerate(inputs[1:], 1):
                x_M = x.shape()[k]
                assert M == x_M, (
                    f"found shape mismatch for input_{i}, "
                    f"input_0 shape: {inputs[0].shape()}, input_{i} shape: {x.shape()}"
                )

    def __call__(
        self,
        inputs: List[Tensor],
        gammas: List[Tensor],
        betas: List[Tensor],
        normalized_shapes: List[List[Any]] = None,
        eps: float = 1e-5,
    ) -> List[Tensor]:
        """Wire the op into the graph and create one output per input.

        Parameters
        ----------
        inputs : List[Tensor]
            One input tensor per group.
        gammas : List[Tensor]
            One gamma per group; either all None or all non-None.
        betas : List[Tensor]
            One beta per group; either all None or all non-None.
        normalized_shapes : List[List[Any]], optional
            One normalized shape per group. When None, the shapes given to
            the constructor are used.
        eps : float
            Stored in the ``eps`` attribute; must be a ``float``.

        Returns
        -------
        List[Tensor]
            One output tensor per input, with the dtype of that input.

        Raises
        ------
        NotImplementedError
            If gammas (or betas) mix None and non-None entries.
        AssertionError
            If ``eps`` is not a float or the sanity check fails.
        """
        # inputs is flattened into a single list of tensors
        all_inputs = inputs + gammas + betas
        # 'real_inputs' only contains non-None tensors
        real_inputs = list(inputs)

        # FIXME: currently, only support two cases, either all gammas are None or
        # all gammas are non-None
        self._attrs["gamma_constant"] = "1.0"
        if gammas[0] is not None:
            if any(gamma is None for gamma in gammas):
                raise NotImplementedError(
                    f"expected gamma not to be None, but got None: {gammas}"
                )
            self._attrs["gamma_constant"] = None
            real_inputs.extend(gammas)
        else:
            if any(gamma is not None for gamma in gammas):
                raise NotImplementedError(
                    f"expected all gammas to be None, but got {gammas}"
                )

        # FIXME: currently, only support two cases, either all betas are None or
        # all betas are non-None
        self._attrs["beta_constant"] = "0.0"
        if betas[0] is not None:
            if any(beta is None for beta in betas):
                raise NotImplementedError(
                    f"expected beta not to be None, but got None: {betas}"
                )
            self._attrs["beta_constant"] = None
            real_inputs.extend(betas)
        else:
            if any(beta is not None for beta in betas):
                raise NotImplementedError(
                    f"expected all betas to be None, but got {betas}"
                )

        if normalized_shapes is not None:
            self._attrs["normalized_shape"] = []
            for normalized_shape in normalized_shapes:
                for shape in normalized_shape:
                    # Only add source of dynamic dim to inputs
                    if isinstance(shape, IntVarTensor) and not isinstance(
                        shape._attrs["int_var"], IntImm
                    ):
                        real_inputs.append(shape)
                self._attrs["normalized_shape"].append(
                    shape_utils.convert_shape_to_IntVar(normalized_shape)
                )
        else:
            self._attrs["normalized_shape"] = self._attrs["default_normalized_shape"]

        assert isinstance(eps, float), f"eps must be float, instead it is {type(eps)}"
        self._attrs["eps"] = eps
        self._attrs["inputs"] = real_inputs
        self._attrs["input_accessors"] = []
        self._sanity_check(all_inputs)
        self._set_depth()

        self._attrs["outputs"] = []
        self._attrs["output_accessors"] = []
        # Accessors are recorded only for the data inputs and their outputs,
        # not for gammas/betas.
        for x in inputs:
            output_shape = self._infer_shapes(x)
            output = Tensor(output_shape, src_ops={self}, dtype=x.dtype())
            self._attrs["outputs"].append(output)
            self._attrs["output_accessors"].append(TensorAccessor(output))
            self._attrs["input_accessors"].append(TensorAccessor(x))
        return self._attrs["outputs"]

    def _args_for_pseudo_code(self):
        """Return a one-element list describing the per-group normalized shapes.

        Each group's shape is rendered as its dims' ``symbolic_value()``
        joined by commas.
        """
        res = [
            ",".join([str(s.symbolic_value()) for s in shapes])
            for shapes in self._attrs["normalized_shape"]
        ]
        return [f"normalized_shape={res}"]