# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
This pass performs the following fusion:

    t0 = tensor([1, M, N])
    x0 = expand(t0, [B, M, N])
    x1 = bmm(x0, t1)  # or x1 = bmm(t1, x0)
    ==>
    x1 = bmm(t0, t1)  # or x1 = bmm(t1, t0)

The basic idea behind the transformation is that we leverage bmm's
broadcasting capability to achieve the same functionality as expand.
"""

from typing import List

from aitemplate.compiler.base import Operator, Tensor
from aitemplate.compiler.tensor_accessor import TensorAccessor
from aitemplate.compiler.transform.toposort import toposort
from aitemplate.compiler.transform.transform_utils import (
remove_single_tensor_op_from_sorted_graph,
sanitize_sorted_graph,
)


def _can_fuse(expand_op: Operator, bmm_op: Operator) -> bool:
"""
    Determine whether expand_op and bmm_op can be fused.
"""
from aitemplate.compiler.ops.tensor.expand import ( # inner import to break circular import
ExpandDimensionType,
    )

    expand_output = expand_op._attrs["outputs"][0]
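    # Never fuse if the expanded tensor is itself a graph output; it must stay materialized.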
if expand_output._attrs["is_output"]:
return False
expand_inputs = expand_op._attrs["inputs"]
expand_input_shape = expand_inputs[0]._attrs["shape"]
expand_output_shape = expand_output._attrs["shape"]
# not valid for bmm
if len(expand_output_shape) != 3:
return False
if len(expand_input_shape) == 2:
# In this case, we are expanding the batch dim
assert (
expand_input_shape[0] == expand_output_shape[1]
and expand_input_shape[1] == expand_output_shape[2]
), f"invalid {expand_input_shape=} and {expand_output_shape=}"
return True
# not valid for bmm
if len(expand_input_shape) != 3:
return False
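    # Only an expand of the leading (batch) dim can be absorbed by bmm's broadcasting.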
if expand_op._attrs["dim_types"][0] != ExpandDimensionType.EXPAND_DIM:
return False
bmm_inputs = bmm_op._attrs["inputs"]
bmm_a = bmm_inputs[0]
bmm_b = bmm_inputs[1]
    # The expanded batch dim must match the other bmm operand's batch dim;
    # only then can bmm's broadcasting reproduce the removed expand.
    if expand_output is bmm_a:
        return expand_output_shape[0] == bmm_b._attrs["shape"][0]
    if expand_output is bmm_b:
        return expand_output_shape[0] == bmm_a._attrs["shape"][0]
    return False


def fuse_expand_bmm(sorted_graph: List[Tensor], workdir: str = None) -> List[Tensor]:
"""
Transform expand + bmm into a single bmm op.
Parameters
----------
sorted_graph : List[Tensor]
Input graph
workdir : str, optional
workdir, by default None
Returns
-------
List[Tensor]
Optimized graph
"""
for tensor in sorted_graph:
src_ops = tensor._attrs["src_ops"]
if len(src_ops) != 1:
continue
op = list(src_ops)[0]
if op._attrs["op"] != "expand":
continue
expand_op = op
expand_output = expand_op._attrs["outputs"][0]
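        # The bmm must be the sole consumer of the expand output; any other consumer
        # would still need the materialized, expanded tensor.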
dst_ops = expand_output._attrs["dst_ops"]
if len(dst_ops) != 1:
continue
next_op = list(dst_ops)[0]
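        # The consumer must be one of the bmm_* ops (its op name starts with "bmm_"),
        # since those kernels provide the batch broadcasting this fusion relies on.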
if not next_op._attrs["op"].startswith("bmm_"):
continue
if not _can_fuse(expand_op, next_op):
continue
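        # Detach the expand op from its extra size inputs (IntVar tensors) so they
        # no longer reference it once the op is removed from the graph.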
for int_var_tensor in expand_op._attrs["inputs"][1:]:
int_var_tensor._attrs["dst_ops"].discard(expand_op)
expand_op._attrs["inputs"] = [expand_op._attrs["inputs"][0]]
remove_single_tensor_op_from_sorted_graph(expand_op)
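        # Removing the expand rewired the bmm to read the un-expanded tensor directly,
        # so its existing input accessors are stale.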
old_tensor_accessors = next_op._attrs["input_accessors"]
        assert (
            old_tensor_accessors[0].stride_dim is None
            and old_tensor_accessors[1].stride_dim is None
        ), f"input accessors of {next_op._attrs['name']} are expected to have no stride_dim"
bmm_inputs = next_op._attrs["inputs"]
# refresh tensor accessors, which will be used by codegen
next_op._attrs["input_accessors"] = [TensorAccessor(t) for t in bmm_inputs]
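    # Re-sort and sanitize the graph to account for any ops removed above.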
sorted_graph = toposort(sorted_graph)
return sanitize_sorted_graph(sorted_graph)