# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Perform fusions for permute+bmm operators.
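
For example (per the pattern table in fuse_permute_bmm_and_gemm below),
bmm_rrr(permute021(A), B) computes the same result as bmm_crr(A, B):
permute021 transposes the last two dims of each batch matrix, which is
equivalent to flipping that operand's row-major / column-major layout tag,
so the explicit permute and its materialized intermediate tensor can be
eliminated.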
"""
from typing import Callable, List, Optional, Set, Tuple, Type, Union

from aitemplate.compiler import ops
from aitemplate.compiler.base import IntImm, IntVar, Operator, Tensor
from aitemplate.compiler.ops.gemm_universal import (
bmm_ccr,
bmm_crr,
bmm_rcr,
bmm_rrr,
gemm_rcr,
gemm_rcr_bias,
gemm_rrr,
gemm_rrr_bias,
)
from aitemplate.compiler.ops.tensor import permute021
from aitemplate.compiler.ops.tensor.permute import permute
from aitemplate.compiler.transform.fuse_utils import extract_only_one_op
from aitemplate.compiler.transform.transform_utils import (
copy_src_op_attributes,
copy_tensor_attributes,
remove_dst_op_from_tensor,
remove_tensor_from_sorted_graph,
replace_tensor,
sanitize_sorted_graph,
)
from aitemplate.utils import alignment

# pylint: disable=C0103,W0612

def _try_extract_one_mm_op(ops: Set[Union[None, Operator]]) -> Union[None, Operator]:
"""
Helper function that returns the matmul op from src_ops() or dst_ops() call.
Return None if there's no bmm ops
"""
if ops is None:
return None
for op in ops:
if op._attrs["op"].startswith("bmm") or op._attrs["op"].startswith("gemm"):
return op
return None


def _fuse_permute_impl(
sorted_graph: List[Tensor],
source: List[Type[Operator]],
targets: List[Union[None, Type[Operator]]],
gemm_condition: Optional[Callable],
permute_condition: Optional[Callable],
) -> Tuple[bool, List[Tensor]]:
"""
Function that fuses [permute021 + bmm] into corresponding bmm op.
Parameters
----------
sorted_graph : List[Tensor]
AIT graph to run fusion
source: List[Type[Operator]]
Combination of permute+bmm ops to be fused.
This should be of len-2
targets: List[Type[Operator]]
To be fused bmm that matches the source.
This should be of len 2, which corresponds to the operator that does
permute A and permute B respectively
gemm_condition: Optional[Callable]
If not None, we apply on the gemm op to check whether it requires fusion.
permute_condition: Optional[Callable]
If not None, we apply on the permute op to check whether it requires fusion.
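
    For example, with source=[permute021, bmm_rrr] and
    targets=[bmm_crr, bmm_rcr] (one of the patterns used by
    fuse_permute_bmm_and_gemm below), bmm_rrr(permute021(A), B) is rewritten
    to bmm_crr(A, B), and bmm_rrr(A, permute021(B)) to bmm_rcr(A, B).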
"""
assert len(source) == 2, "Source should have 2 elements, got {} instead".format(
len(source)
)
new_sorted_graph = []
fused = False
to_replace = {}
for tensor in sorted_graph:
if tensor in to_replace:
new_sorted_graph.append(to_replace[tensor])
replace_tensor(tensor, to_replace[tensor])
del to_replace[tensor]
continue
new_sorted_graph.append(tensor)
if fused:
continue
if tensor._attrs["is_output"]:
continue
permute_op = extract_only_one_op(tensor._attrs["src_ops"])
bmm_op = _try_extract_one_mm_op(tensor._attrs["dst_ops"])
if permute_op is None or bmm_op is None:
continue
if permute_op._attrs["op"] != source[0]()._attrs["op"]:
continue
if bmm_op._attrs["op"] != source[1]()._attrs["op"]:
continue
if gemm_condition is not None and not gemm_condition(bmm_op):
continue
if permute_condition is not None and not permute_condition(permute_op):
continue
assert len(permute_op._attrs["inputs"]) == 1
assert len(bmm_op._attrs["outputs"]) == 1
inputs = list(bmm_op._attrs["inputs"])
if targets[0] is None and inputs[0] == tensor:
continue
if targets[1] is None and inputs[1] == tensor:
continue
input_tensor = permute_op._attrs["inputs"][0]
output_tensor = bmm_op._attrs["outputs"][0]
        # TODO: Check whether the input is a weight, which would enable better
        # compile-time optimization of preprocessing (padding etc.).
permute_shape = tensor.shape()
permute_dtype = tensor.dtype()
prepermute_shape = input_tensor.shape()
prepermute_dtype = input_tensor.dtype()
if (
isinstance(prepermute_shape[-1], IntImm)
and (
not alignment.valid_alignment(
prepermute_shape[-1].value(), prepermute_dtype
)
)
and isinstance(permute_shape[-1], IntImm)
and alignment.valid_alignment(permute_shape[-1].value(), permute_dtype)
):
# We don't run the permute+bmm fusion if the permute op could
# turn an invalid alignment into a valid alignment.
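            # For example (fp16, hypothetical shapes): an odd innermost dim
            # before the permute typically has no valid alignment, while a
            # permuted innermost dim that is a multiple of 8 does; fusing
            # would force the mm kernel onto the badly aligned layout and
            # lose vectorized memory access.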
continue
fused = True
remove_dst_op_from_tensor(bmm_op._attrs["inputs"], bmm_op)
target = None
if inputs[0] == tensor:
target = targets[0]
inputs[0] = input_tensor
elif inputs[1] == tensor:
target = targets[1]
inputs[1] = input_tensor
else:
raise RuntimeError(
"bmm inputs are {}, not matching permute's output tensor {}".format(
inputs, tensor
)
)
if not tensor.dst_ops():
            # Remove the permute op entirely if nothing else consumes its output
remove_dst_op_from_tensor(input_tensor, permute_op)
remove_tensor_from_sorted_graph(tensor)
new_tensor = target()(*inputs)
copy_tensor_attributes(new_tensor, output_tensor)
copy_src_op_attributes(new_tensor, output_tensor)
to_replace[output_tensor] = new_tensor
return (fused, sanitize_sorted_graph(new_sorted_graph))


def fuse_permute_bmm_and_gemm(
    sorted_graph: List[Tensor], workdir: Optional[str] = None
) -> List[Tensor]:
"""Fuse [permute021 + bmm] and [permute(0, 1) + gemm].
Note that for the latter fusion, we require that this pass takes
place before any gemm + elementwise fusions.
Parameters
----------
sorted_graph : List[Tensor]
Input graph
workdir : str, optional
working dir, by default None
Returns
-------
List[Tensor]
Fused graph
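
    For example, following the pattern table below, gemm_rcr(permute021(X), W)
    with X of shape [B, K, M] and W of shape [N, K] is rewritten to
    bmm_ccr(X, W), eliminating the materialized permute output.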
"""

    def _need_broadcast_gemm(op: Operator):
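        # Matches a gemm where at least one operand is 3-d (batched), so the
        # fused op becomes a bmm. cutlass maps the bmm batch dim to grid_z,
        # hence the 65535 bound checked by _valid_shape below.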
if not op._attrs["op"].startswith("gemm"):
return False
inputs = op._attrs["inputs"]
# cutlass's bmm assigns the batch size to grid_z, which
        # cannot exceed 65535
MAX_B_DIM_VAL = 65535

        def _valid_shape(shape: List[Union[IntImm, IntVar]]):
b_dim = shape[0]
if isinstance(b_dim, IntImm):
b_dim_val = b_dim.value()
else:
b_dim_val = b_dim.upper_bound()
return b_dim_val <= MAX_B_DIM_VAL

        input_shape_0 = inputs[0].shape()
input_shape_1 = inputs[1].shape()
if len(input_shape_0) != 3 and len(input_shape_1) != 3:
return False
if len(input_shape_0) == 3 and not _valid_shape(input_shape_0):
return False
if len(input_shape_1) == 3 and not _valid_shape(input_shape_1):
return False
return True

    def _is_transpose(op: Operator):
if op._attrs["op"] != "permute":
return False
dims = op._attrs["dims"]
return dims == [1, 0]
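
    # Each pattern below is (source, targets, gemm_condition,
    # permute_condition): `source` is the [permute, mm] pair to match, and
    # `targets[i]` is the op to rewrite to when the permute feeds the mm's
    # i-th operand (None disables fusion for that position).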
permute_mm_patterns = (
([permute021, bmm_ccr], [bmm_rcr, bmm_crr], None, None),
([permute021, bmm_crr], [bmm_rrr, bmm_ccr], None, None),
([permute021, bmm_rcr], [bmm_ccr, bmm_rrr], None, None),
([permute021, bmm_rrr], [bmm_crr, bmm_rcr], None, None),
([permute021, gemm_rcr], [bmm_ccr, bmm_rrr], _need_broadcast_gemm, None),
([permute021, gemm_rrr], [bmm_crr, bmm_rcr], _need_broadcast_gemm, None),
(
[permute021, gemm_rcr_bias],
[ops.gemm_universal.bmm_ccr_add, ops.gemm_universal.bmm_rrr_add],
_need_broadcast_gemm,
None,
),
(
[permute021, gemm_rrr_bias],
[ops.gemm_universal.bmm_crr_add, None],
_need_broadcast_gemm,
None,
),
([permute, gemm_rcr], [None, gemm_rrr], None, _is_transpose),
([permute, gemm_rrr], [None, gemm_rcr], None, _is_transpose),
)
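
    # Iterate to a fixed point: one rewrite can expose another fusible
    # permute+mm pair, so keep re-applying the patterns until no pattern
    # fires in a full pass.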
graph_transformed = True
while graph_transformed:
graph_transformed = False
for source, targets, gemm_condition, permute_condition in permute_mm_patterns:
fused, sorted_graph = _fuse_permute_impl(
sorted_graph, source, targets, gemm_condition, permute_condition
)
graph_transformed |= fused
return sorted_graph