# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Fuse GEMM with elementwise operations
"""
from typing import List
from aitemplate.compiler.base import Tensor
from aitemplate.compiler.ops.common.epilogue import FuncEnum
from aitemplate.compiler.ops.gemm_universal import gemm_rcr_bias_swish
from aitemplate.compiler.transform.fuse_mm_elementwise_patterns import (
get_gemm_rcr_bias_patterns,
get_patterns,
)
from aitemplate.compiler.transform.fuse_utils import (
extract_only_one_op,
is_elementwise_type,
transform_simple_fusion_patterns,
)
from aitemplate.compiler.transform.transform_utils import (
copy_tensor_attributes,
remove_dst_op_from_tensor,
remove_single_tensor_op_from_sorted_graph,
replace_tensor,
sanitize_sorted_graph,
)
# pylint: disable=C0103,C0415,W0612
def _fuse_bmm_mul_or_div_alpha(sorted_graph: List[Tensor]) -> List[Tensor]:
"""This pass fuses bmm and mul (or div) if mul's other operand is a
constant scalar tensor (i.e. which has a valid "value" attribute.
In such a case, we turn this constant value into bmm's alpha.
Note that for div cases, we assign 1/const_val to alpha.
Parameters
----------
sorted_graph : List[Tensor]
input sorted graph
Return
----------
List[Tensor]
modified sorted graph upon success. Otherwise, the original sorted
graph will be returned.
"""
for tensor in sorted_graph:
src_ops = tensor._attrs["src_ops"]
if src_ops is None or len(src_ops) != 1:
continue
src_op = list(src_ops)[0]
if src_op is None:
continue
if not src_op._attrs["op"].startswith("bmm"):
continue
bmm_op = src_op
dst_ops = list(tensor._attrs["dst_ops"])
if not dst_ops or len(dst_ops) != 1:
continue
next_op = dst_ops[0]
if next_op._attrs["op"] != "elementwise":
continue
if next_op._attrs["func"] == FuncEnum.MUL:
is_div = False
elif next_op._attrs["func"] == FuncEnum.DIV:
is_div = True
else:
continue
elem_op = next_op
elem_inputs = elem_op._attrs["inputs"]
if len(elem_inputs) != 1:
continue
elem_args = elem_op._attrs["args"]
if len(elem_args) != 2:
continue
# make sure cst_tensor is the divisor of the DIV op
if is_div and tensor == elem_args[1]:
continue
cst_tensor = elem_args[1] if tensor == elem_args[0] else elem_args[1]
# skip non-constant scalar tensor
if not cst_tensor.is_a_const_num():
continue
cst_val = cst_tensor._attrs["value"]
# let's only consider int and float builtin types. Seems that it doesn't
# make any sense to take other scalar types like str and convert it
# to a float.
if not isinstance(cst_val, (float, int)):
continue
# OK, we are good so let's add cst_val to bmm's alpha attribute
bmm_op._attrs["alpha"] = 1.0 / float(cst_val) if is_div else float(cst_val)
# remove this MUL/DIV
remove_single_tensor_op_from_sorted_graph(elem_op)
return sanitize_sorted_graph(sorted_graph)
def _fuse_gemm_rcr_bias_swish(sorted_graph: List[Tensor]) -> List[Tensor]:
"""
gemm_rcr_bias_swish(A, B) is equivalent to:
x = gemm_rcr_bias(A, B)
x1 = sigmoid(x)
return elementwise(MUL)(x, x1)
"""
new_sorted_graph = []
to_remove = set()
for tensor in sorted_graph:
if tensor in to_remove:
continue
new_sorted_graph.append(tensor)
if tensor._attrs["is_output"]:
continue
gemm_op = extract_only_one_op(tensor._attrs["src_ops"])
if gemm_op is None:
continue
if gemm_op._attrs["op"] != "gemm_rcr_bias":
continue
dst_op = list(tensor._attrs["dst_ops"])
if len(dst_op) != 2:
continue
swish_tensor = None
for idx in range(2):
other_idx = (idx + 1) % 2
if is_elementwise_type(dst_op[idx], FuncEnum.SIGMOID):
if not is_elementwise_type(dst_op[other_idx], FuncEnum.MUL):
continue
is_swish = False
output = dst_op[idx]._attrs["outputs"][0]
mul_inputs = dst_op[other_idx]._attrs["inputs"]
if mul_inputs[0] == output and mul_inputs[1] == tensor:
is_swish = True
if mul_inputs[1] == output and mul_inputs[0] == tensor:
is_swish = True
if not is_swish:
continue
swish_tensor = dst_op[other_idx]._attrs["outputs"][0]
break
if swish_tensor is None:
continue
gemm_inputs = gemm_op._attrs["inputs"]
remove_dst_op_from_tensor(gemm_inputs, gemm_op)
# Output of sigmoid and final mul of swish.
to_remove.add(dst_op[0]._attrs["outputs"][0])
to_remove.add(dst_op[1]._attrs["outputs"][0])
new_tensor = gemm_rcr_bias_swish()(*gemm_inputs)
copy_tensor_attributes(new_tensor, swish_tensor)
replace_tensor(swish_tensor, new_tensor)
new_sorted_graph[-1] = new_tensor
return sanitize_sorted_graph(new_sorted_graph)
def _transform_gemm_bias(sorted_graph: List[Tensor]) -> List[Tensor]:
return transform_simple_fusion_patterns(sorted_graph, get_gemm_rcr_bias_patterns())
def _transform_mm_elementwise(sorted_graph: List[Tensor]) -> List[Tensor]:
fusion_patterns = get_patterns()
return transform_simple_fusion_patterns(sorted_graph, fusion_patterns)
[docs]def fuse_mm_elementwise(
sorted_graph: List[Tensor], workdir: str = None
) -> List[Tensor]:
"""Fuse GEMMs with elementwise operations.
Parameters
----------
sorted_graph : List[Tensor]
Input graph
workdir : str, optional
working dir, by default None
Returns
-------
List[Tensor]
Fused graph
"""
funcs = [
_fuse_bmm_mul_or_div_alpha,
_transform_gemm_bias,
_transform_mm_elementwise,
_fuse_gemm_rcr_bias_swish,
]
for func in funcs:
sorted_graph = func(sorted_graph)
return sorted_graph