# Source code for aitemplate.compiler.transform.fuse_ops

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
"""
Perform operator fusions.
"""

import collections
import itertools
import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set

from aitemplate.compiler.base import Operator, Tensor
from aitemplate.compiler.ops.common import elementwise, fused_elementwise
from aitemplate.compiler.ops.common.epilogue import FuncEnum
from aitemplate.compiler.ops.groupnorm.groupnorm import group_norm
from aitemplate.compiler.ops.groupnorm.groupnorm_swish import group_norm_swish
from aitemplate.compiler.ops.layernorm import layernorm_sigmoid_mul
from aitemplate.compiler.transform import transform_utils
from aitemplate.compiler.transform.fuse_utils import transform_simple_fusion_patterns
from aitemplate.compiler.transform.toposort import toposort

# pylint: disable=C0103,W0612


# Module-level logger; used for debug output of the fusion passes below.
_LOGGER = logging.getLogger(__name__)


class SimpleDisjointSet:
    """
    Greedily groups nodes into disjoint sets.

    Every registered node maps to the set object of the group it belongs to.
    All members of one group share the very same set object, which is what
    lets ``get_node_groups`` de-duplicate groups by object identity.
    """

    def __init__(self):
        # node -> the (shared) set object holding every member of its group.
        self.node_to_set_mapping: Dict[Any, Set[Any]] = {}

    def add(self, node: Any, dependent_nodes: Optional[Set[Any]]) -> None:
        """
        Register ``node`` and try to merge it with the groups of its dependents.

        A node that is already registered is left untouched. Dependents that
        are ``None`` or not registered yet are ignored. The group of every
        remaining dependent is merged into ``node``'s group, unless
        ``_detect_cycle`` reports that the merged group would contain a cycle,
        in which case that dependent's group is skipped.

        Args:
            node: the node to register.
            dependent_nodes: candidate nodes whose groups ``node`` may join;
                ``None`` or an empty collection makes ``node`` a singleton.
        """
        mapping = self.node_to_set_mapping
        if node in mapping:
            return

        if not dependent_nodes:
            mapping[node] = {node}
            return

        # Seed the group with the node itself so that cycle detection also
        # takes the node into account when deciding whether to merge.
        group = {node}
        for dep in dependent_nodes:
            if dep is None or dep not in mapping:
                continue
            candidate = mapping[dep]
            if _detect_cycle(group | candidate):
                continue
            group |= candidate
            # Re-point every absorbed member at the merged group object.
            for member in candidate:
                mapping[member] = group
        mapping[node] = group

    def get_node_groups(self) -> List[Set[Any]]:
        """Return every distinct group once, in order of first appearance."""
        unique_by_id = {}
        for group in self.node_to_set_mapping.values():
            unique_by_id.setdefault(id(group), group)
        return list(unique_by_id.values())


def _find_fusable_elementwise_ops(op: Operator) -> Set[Operator]:
    """
    Given an elementwise op, returns the set of parent elementwise ops
    which can be fused with this elementwise op.

    A parent op is dropped when:
      * it is not an elementwise op;
      * its output is a graph output, or is also consumed by a
        non-elementwise op;
      * it is an ancestor, but not a direct parent, of another candidate;
      * it is an ancestor of an op that has already been dropped.
    The last two rules are re-applied until the candidate set stops shrinking.

    Args:
        op: the elementwise op whose parent ops are inspected.

    Returns:
        The subset of ``op``'s parent ops that may be fused with ``op``.
    """
    # Get parent ops.
    dependent_ops = set()
    for input_tensor in op._attrs["inputs"]:
        dependent_ops.update(input_tensor._attrs["src_ops"])
    original_ops = set(dependent_ops)

    # First, filter out all non-elementwise ops.
    to_be_removed_set = set()
    for parent_op in dependent_ops:
        if parent_op._attrs["op"] != "elementwise":
            to_be_removed_set.add(parent_op)
            continue
        # Assuming there are two elementwise ops, op1 and op2, where op1 is a
        # parent op of op2. If op1's output is an output tensor, or if op1 is
        # consumed by other non-elementwise ops, op1 cannot be fused with op2.
        output = parent_op._attrs["outputs"][0]
        if output._attrs["is_output"] or any(
            next_op._attrs["op"] != "elementwise" for next_op in output.dst_ops()
        ):
            to_be_removed_set.add(parent_op)

    dependent_ops = dependent_ops - to_be_removed_set

    # Then get all connected elementwise ops at the last layer. Removing an op
    # can make further ops unfusable, so iterate until a fixed point.
    while True:
        for op1 in dependent_ops:
            # If op1 is an ancestor of op2 but not a parent of op2,
            # op1 and op2 cannot be fused. Remove op1 and only keep op2.
            if any(
                op1 is not op2
                and transform_utils.is_ancestor(op1, op2)
                and not transform_utils.is_parent(op1, op2)
                for op2 in dependent_ops
            ):
                to_be_removed_set.add(op1)

            # If op1 is an ancestor of a removed op,
            # op1 and op cannot be fused. Remove op1.
            if any(
                transform_utils.is_ancestor(op1, removed_op)
                for removed_op in to_be_removed_set
            ):
                to_be_removed_set.add(op1)

        prev_len = len(dependent_ops)
        dependent_ops = dependent_ops - to_be_removed_set
        if len(dependent_ops) == prev_len:
            break

    # Lazy %-style arguments: the op sets are only stringified when DEBUG
    # logging is actually enabled.
    _LOGGER.debug(
        "original op set: %s, to_be_removed_set: %s, final_set: %s",
        original_ops,
        to_be_removed_set,
        dependent_ops,
    )
    return dependent_ops


@dataclass
class FusedElementwiseInfo:
    """
    Everything needed to build one fused_elementwise op. It is gathered
    before any fused op is created, i.e. before the graph is mutated.
    """

    # Elementwise ops of the subgraph, in topological order.
    partitioned_ops: List[Operator]
    # All unique input tensors of the ops in partitioned_ops.
    inputs: List[Tensor]
    # All unique output tensors of the ops in partitioned_ops.
    outputs: List[Tensor]
    # Inputs with no producer, or with a producer outside the candidate op
    # group (constant numbers excluded); these become the fused op's inputs.
    external_inputs: List[Tensor]
    # Outputs that are graph outputs or have a consumer outside the candidate
    # op group; these become the fused op's outputs.
    external_outputs: List[Tensor]
def _partition_subgraphs(ops: Set[Operator]) -> Dict[str, Set[Operator]]:
    """
    Given ops of a candidate fused_elementwise graph, partition them into
    subgraphs based on output shape.

    An op is an output node of the candidate graph when one of its output
    tensors is a graph output or is consumed by an op outside ``ops``. Output
    nodes are keyed by the "|"-joined shapes of those output tensors; each key
    collects its output nodes together with all of their ancestors in ``ops``.

    Returns:
        dict of {output shape key: ops forming the subgraph for that key}.
    """
    # Partition graph of elementwise ops into subgraphs based on output shape.
    output_op_map = collections.defaultdict(set)
    for op in ops:
        # Find output nodes.
        shapes = [
            "_".join(map(str, output_tensor._attrs["shape"]))
            for output_tensor in op._attrs["outputs"]
            if output_tensor._attrs["is_output"]
            or len(output_tensor._attrs["dst_ops"] - ops) > 0
        ]
        if not shapes:
            continue
        # Find ancestors of the output node.
        # Outputs with the same shape should form the same graph.
        op_set = output_op_map["|".join(shapes)]
        op_set.update(
            anc_op for anc_op in ops if transform_utils.is_ancestor(anc_op, op)
        )
        op_set.add(op)
    return output_op_map


def _get_inputs_outputs(
    partitioned_ops: List[Operator], all_ops: Set[Operator]
) -> List[List[Tensor]]:
    """
    Given ops of a subgraph partitioned by output shape, and the ops of the
    full candidate graph, returns all inputs/outputs of the partitioned ops
    and the external inputs/outputs of the subgraph. The external ones serve
    as inputs/outputs of the fused_elementwise op.

    Args:
        partitioned_ops: ops of one subgraph (``_collect_info`` passes them in
            topological order).
        all_ops: every op of the candidate graph the subgraph was taken from.

    Returns:
        ``[inputs, outputs, external_inputs, external_outputs]``, each a list
        of unique tensors in first-seen order.
        * An input is external when it has no producer or a producer outside
          ``all_ops``, and it is not a constant number.
        * An output is external when it is a graph output or has a consumer
          outside ``all_ops``.

    Raises:
        AssertionError: if the graph bookkeeping (src_ops/dst_ops) is
            inconsistent with the computed inputs/outputs.
    """
    external_inputs, external_outputs = [], []
    tmp_inputs, tmp_outputs = [], []
    for op in partitioned_ops:
        for input_tensor in op._attrs["inputs"]:
            tmp_inputs.append(input_tensor)
            src_ops = set(input_tensor._attrs["src_ops"])
            if (len(src_ops) == 0 or len(src_ops - all_ops) > 0) and (
                not input_tensor.is_a_const_num()
            ):
                external_inputs.append(input_tensor)
            assert op in input_tensor._attrs["dst_ops"]
        for output_tensor in op._attrs["outputs"]:
            tmp_outputs.append(output_tensor)
            dst_ops = set(output_tensor._attrs["dst_ops"])
            if output_tensor._attrs["is_output"] or len(dst_ops - all_ops) > 0:
                external_outputs.append(output_tensor)
            assert len(output_tensor._attrs["src_ops"]) == 1
            assert list(output_tensor._attrs["src_ops"])[0] == op

    # dict.fromkeys takes unique tensors and preserves the ordering.
    external_inputs = list(dict.fromkeys(external_inputs))
    external_outputs = list(dict.fromkeys(external_outputs))
    tmp_inputs = list(dict.fromkeys(tmp_inputs))
    tmp_outputs = list(dict.fromkeys(tmp_outputs))

    assert set(external_inputs) == set(tmp_inputs) - set(
        tmp_outputs
    ), "external_inputs: {} is not equal to tmp_inputs: {} - tmp_outputs: {}.".format(
        external_inputs, tmp_inputs, tmp_outputs
    )
    assert (
        len(set(tmp_outputs) - set(tmp_inputs) - set(external_outputs)) == 0
    ), "tmp_outputs: {} - tmp_inputs: {} - external_outputs: {} is not empty.".format(
        tmp_outputs, tmp_inputs, external_outputs
    )
    assert (
        len(set(external_outputs) - set(tmp_outputs)) == 0
    ), "external_outputs: {} - tmp_outputs: {} is not empty.".format(
        external_outputs, tmp_outputs
    )
    return [tmp_inputs, tmp_outputs, external_inputs, external_outputs]


def _collect_info(
    output_op_map: Dict[str, Set[Operator]],
    all_ops: Set[Operator],
    sorted_graph: List[Tensor],
) -> List[FusedElementwiseInfo]:
    """
    Collects information for each fused_elementwise op:
    1. Provide op_list in topological order, so the fused_elementwise backend
       can emit operations in order.
    2. Provide inputs/outputs info of each subgraph.
    This needs to happen before the fused ops are created, i.e. before the
    graph gets changed.

    Args:
        output_op_map: subgraphs as returned by ``_partition_subgraphs``.
        all_ops: every op of the candidate graph the subgraphs came from.
        sorted_graph: tensors of the whole graph in topological order.

    Returns:
        One FusedElementwiseInfo per subgraph. Each holds the partitioned ops
        in topological order, all inputs/outputs of those ops, and their
        external inputs/outputs (which serve as inputs/outputs of the
        fused_elementwise op).

    Raises:
        AssertionError: if some op of a subgraph never has all of its inputs
            present in ``sorted_graph``.
    """
    info_list = []
    for op_set in output_op_map.values():
        # Toposort op_set into op_list, because fused_elementwise stores
        # elementwise ops in topological order. Walking the sorted tensors,
        # an op is appended as soon as all of its inputs have been seen.
        topo_set = set()
        op_list = []
        for tensor in sorted_graph:
            if not op_set:
                # Every op has been placed; the rest of the graph is irrelevant.
                break
            topo_set.add(tensor)
            ready = [
                op
                for op in op_set
                if all(arg in topo_set for arg in op._attrs["inputs"])
            ]
            op_list.extend(ready)
            op_set = op_set - set(ready)
        assert (
            not op_set
        ), "Unable to find topological order of op list for fused_elementwise!"

        # Get all inputs/outputs of elementwise ops and their external
        # inputs/outputs, which will serve as inputs/outputs of the
        # fused_elementwise op.
        inputs_outputs = _get_inputs_outputs(op_list, all_ops)
        info_list.append(FusedElementwiseInfo(op_list, *inputs_outputs))
    return info_list


def _create_fuse_ops(info_list: List[FusedElementwiseInfo]) -> None:
    """
    Creates fused ops based on the info collected by ``_collect_info``.

    For each subgraph:
    1. Remove the partitioned elementwise ops from the src_ops/dst_ops of
       every tensor they touch.
    2. Create a fused_elementwise op from the partitioned ops, whose
       inputs/outputs are the external inputs/outputs of the subgraph.
    """
    for info in info_list:
        op_set = set(info.partitioned_ops)
        for tensor in itertools.chain(info.inputs, info.outputs):
            tensor._attrs["src_ops"] = tensor._attrs["src_ops"] - op_set
            tensor._attrs["dst_ops"] = tensor._attrs["dst_ops"] - op_set
        fused_elementwise(
            info.partitioned_ops,
            info.external_inputs,
            info.external_outputs,
        )


def _detect_cycle(group: Set[Operator]) -> bool:
    """
    Given a group of ops, detect whether fusing them would form a cycle, e.g.

             ----> group_ops
            /          |
           A  <--------

    Here the outside op A feeds the group, while the group (through some path)
    also feeds A. To detect this, we collect the parents of all ops in the
    group that are not themselves in the group. A cycle exists if any op in
    the group is an ancestor of one of those outside parents.
    """
    # Outside parents do not depend on the loop below; compute them once.
    outside_parents = {
        parent
        for member in group
        for input_tensor in member._attrs["inputs"]
        for parent in input_tensor.src_ops()
    } - group
    return any(
        transform_utils.is_ancestor(op1, op2)
        for op1 in group
        for op2 in outside_parents
    )
def fuse_elementwise(sorted_graph: List[Tensor], workdir: str = None) -> List[Tensor]:
    """
    Given a sorted graph, returns a sorted graph in which fusable elementwise
    ops have been replaced by fused_elementwise ops.

    Args:
        sorted_graph: tensors of the graph in topological order.
        workdir: not used by this pass.

    Returns:
        The re-sorted and sanitized graph after fusion.
    """
    disjoint_set = SimpleDisjointSet()
    for tensor in sorted_graph:
        producers = tensor._attrs["src_ops"]
        # Only tensors with exactly one producing op are considered.
        if producers is None or len(producers) != 1:
            continue
        (producer,) = producers
        if producer._attrs["op"] != "elementwise":
            continue
        disjoint_set.add(producer, _find_fusable_elementwise_ops(producer))

    for group in disjoint_set.get_node_groups():
        # Partition the group into subgraphs by output shape, collect the info
        # needed to build the fused ops, then create them.
        _create_fuse_ops(
            _collect_info(_partition_subgraphs(group), group, sorted_graph)
        )

    return transform_utils.sanitize_sorted_graph(toposort(sorted_graph))
def process_singleton_elementwise(
    sorted_graph: List[Tensor], workdir: str = None
) -> List[Tensor]:
    """
    A dummy pass which enables codegen for any elementwise op without fusing
    it with neighbors: every elementwise op is wrapped into its own
    fused_elementwise op.

    Args:
        sorted_graph: tensors of the graph in topological order.
        workdir: not used by this pass.

    Returns:
        The re-sorted and sanitized graph.
    """
    # Unique elementwise producers, in graph order (dict keys de-duplicate
    # while preserving first-seen order). Each one forms its own single-op
    # group, so no union/merging machinery is needed.
    singleton_ops = {}
    for tensor in sorted_graph:
        src_ops = tensor._attrs["src_ops"]
        if src_ops is None or len(src_ops) != 1:
            continue
        src_op = next(iter(src_ops))
        if src_op._attrs["op"] == "elementwise":
            singleton_ops.setdefault(src_op, None)

    # All groups are collected above, before the graph is mutated below.
    for src_op in singleton_ops:
        ops = {src_op}
        # Partition subgraph based on output shape.
        output_op_map = _partition_subgraphs(ops)
        # Collect information to create fuse ops.
        info_list = _collect_info(output_op_map, ops, sorted_graph)
        # Create fuse ops.
        _create_fuse_ops(info_list)

    sorted_graph = toposort(sorted_graph)
    return transform_utils.sanitize_sorted_graph(sorted_graph)
def _fuse_layernorm_sigmoid_mul(sorted_graph: List[Tensor]) -> List[Tensor]:
    """
    Fuse layernorm -> sigmoid -> mul chains into layernorm_sigmoid_mul ops.

    For every tensor produced by exactly one layernorm op:
    * the first entry of the tensor's dst_ops must be an elementwise SIGMOID;
    * the first entry of the sigmoid output's dst_ops must be an elementwise
      MUL.
    Chains accepted by ``layernorm_sigmoid_mul.is_valid`` are all collected
    first and fused afterwards.
    """

    def is_elementwise(candidate, func) -> bool:
        return (
            candidate._attrs["op"] == "elementwise"
            and candidate._attrs["func"] == func
        )

    matches = []
    for tensor in sorted_graph:
        producers = tensor._attrs["src_ops"]
        if producers is None or len(producers) != 1:
            continue
        layer_norm = next(iter(producers))
        if layer_norm is None or layer_norm._attrs["op"] != "layernorm":
            continue

        ln_consumers = list(tensor._attrs["dst_ops"])
        if not ln_consumers:
            # layernorm as the last op in the graph
            continue
        sigmoid = ln_consumers[0]
        if not is_elementwise(sigmoid, FuncEnum.SIGMOID):
            continue

        # layernorm + sigmoid
        sigmoid_consumers = list(sigmoid._attrs["outputs"][0]._attrs["dst_ops"])
        if not sigmoid_consumers:
            continue
        mul = sigmoid_consumers[0]
        if not is_elementwise(mul, FuncEnum.MUL):
            continue

        if layernorm_sigmoid_mul.is_valid(layer_norm, sigmoid, mul):
            matches.append((layer_norm, sigmoid, mul))

    for layer_norm, sigmoid, mul in matches:
        layernorm_sigmoid_mul(layer_norm, sigmoid, mul)
    return transform_utils.sanitize_sorted_graph(sorted_graph)


def _fuse_groupnorm_sigmoid_mul(sorted_graph: List[Tensor]) -> List[Tensor]:
    """
    Fuse group_norm -> sigmoid -> mul chains into group_norm_swish using
    ``transform_simple_fusion_patterns``.
    """
    # The pattern is expressed as sample op instances.
    # NOTE(review): num_groups/num_channels here look like placeholders for
    # building the template op; confirm the matcher does not compare them.
    groupnorm_swish_pattern = (
        group_norm(num_groups=2, num_channels=4),
        elementwise(FuncEnum.SIGMOID),
        elementwise(FuncEnum.MUL),
    )
    return transform_simple_fusion_patterns(
        sorted_graph, [(groupnorm_swish_pattern, group_norm_swish)]
    )


def fuse_ops(sorted_graph: List[Tensor], workdir: str = None) -> List[Tensor]:
    """
    Run the pattern-based fusion passes in order: layernorm + sigmoid + mul,
    then groupnorm + sigmoid + mul.

    Args:
        sorted_graph: tensors of the graph in topological order.
        workdir: not used by this pass.

    Returns:
        The graph returned by the last fusion pass.
    """
    sorted_graph = _fuse_layernorm_sigmoid_mul(sorted_graph)
    sorted_graph = _fuse_groupnorm_sigmoid_mul(sorted_graph)
    return sorted_graph