Source code for aitemplate.compiler.transform.refine_graph

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
"""
Graph pass to dedup operators with same signatures.
"""
import logging
from typing import List

from aitemplate.compiler.base import Operator, Tensor

from aitemplate.utils.graph_utils import get_sorted_ops

# pylint: disable=C0103


_LOGGER = logging.getLogger(__name__)

SPECIAL_CHECK_FUNC_KEYS = {
    "inputs",
    "name",
    "depth",
    "outputs",
    "original_inputs",
    "original_outputs",
    "gemm_operand_groups",
    "original_name",
    "f_ab_alignment",
    "elementwise_ops",
    "args",
}


def same_tensor_type(t1: Tensor, t2: Tensor):
    if t1.dtype() != t2.dtype():
        return False
    if t1._attrs["value"] != t2._attrs["value"]:
        return False
    t1s = t1.shape()
    t2s = t2.shape()
    if len(t1s) != len(t2s):
        return False
    for d1, d2 in zip(t1s, t2s):
        if d1 != d2:
            return False
    return True


def check_inputs_outputs(key: str, o1: Operator, o2: Operator):
    # check inputs
    o1_args = o1._attrs[key]
    o2_args = o2._attrs[key]
    if len(o1_args) != len(o2_args):
        return False
    for t1, t2 in zip(o1_args, o2_args):
        if not same_tensor_type(t1, t2):
            return False
    return True


def check_fused_elementwise_ops(o1: Operator, o2: Operator):
    ops1 = o1._attrs["elementwise_ops"]
    ops2 = o2._attrs["elementwise_ops"]

    # Allow single input to simply it
    if len(o1._attrs["inputs"]) != 1:
        return False

    # Disallow multiple ops
    if len(ops1) != len(ops2) or len(ops1) != 1:
        return False

    return same_function_type(ops1[0], ops2[0])


def same_function_type(o1: Operator, o2: Operator):
    if o1._attrs["op"] != o2._attrs["op"]:
        return False

    if len(o1._attrs) != len(o2._attrs):
        return False
    keys = o1._attrs.keys()

    # ban group gemm ops
    if "unique_workspace" in keys:
        return False

    # check general attrs
    for key in keys:
        if key not in o2._attrs:
            return False
        if key not in SPECIAL_CHECK_FUNC_KEYS:
            if o1._attrs[key] != o2._attrs[key]:
                return False

    if not check_inputs_outputs("inputs", o1, o2):
        return False

    # for fused_elementwise ops
    if o1._attrs["op"] == "fused_elementwise" and (
        not check_fused_elementwise_ops(o1, o2)
    ):
        return False
    if "original_inputs" in keys:
        if not check_inputs_outputs("original_inputs", o1, o2):
            return False
    if "original_outputs" in keys:
        if not check_inputs_outputs("original_outputs", o1, o2):
            return False

    # for elementwise ops
    if "args" in keys:
        if not check_inputs_outputs("args", o1, o2):
            return False

    return True


[docs]def refine_graph(sorted_graph: List[Tensor]): """Graph pass to dedup operators with same signatures. Parameters ---------- sorted_graph : List[Tensor] Input graph """ sorted_ops = get_sorted_ops(sorted_graph) exist_func = [] refined_ops = 0 total_ops = len(sorted_ops) refined_ops_set = set() for func in sorted_ops: found = False for f in reversed(exist_func): if same_function_type(f, func): func._attrs["name"] = f._attrs["name"] found = True refined_ops += 1 break if not found: exist_func.append(func) if found: refined_ops_set.add(func._attrs["op"]) _LOGGER.debug(f"refined ops: {refined_ops_set}") _LOGGER.info(f"reduced unique ops from {total_ops} to {total_ops - refined_ops}")