# Source code for aitemplate.compiler.transform.optimize_graph

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
"""
Applies graph transformations.
"""

from typing import List

from aitemplate.compiler.base import Tensor
from aitemplate.compiler.transform.apply_padding import apply_padding
from aitemplate.compiler.transform.dedup_make_jagged_ops import dedup_make_jagged_ops
from aitemplate.compiler.transform.fuse_bmm_permute import fuse_bmm_permute
from aitemplate.compiler.transform.fuse_conv_elementwise import fuse_conv_elementwise
from aitemplate.compiler.transform.fuse_duplicate_fused_elementwise import (
    fuse_duplicate_fused_elementwise,
)
from aitemplate.compiler.transform.fuse_expand_bmm import fuse_expand_bmm
from aitemplate.compiler.transform.fuse_group_ops import fuse_group_ops
from aitemplate.compiler.transform.fuse_mm_elementwise import fuse_mm_elementwise
from aitemplate.compiler.transform.fuse_mm_reshape_permute import (
    fuse_mm_reshape_permute,
)
from aitemplate.compiler.transform.fuse_ops import (
    fuse_elementwise,
    fuse_ops,
    process_singleton_elementwise,
)
from aitemplate.compiler.transform.fuse_parallel_gemms import (
    fuse_parallel_gemms,
    fuse_single_source_parallel_gemms,
)
from aitemplate.compiler.transform.fuse_permute_bmm_and_gemm import (
    fuse_permute_bmm_and_gemm,
)
from aitemplate.compiler.transform.move_view_ops import move_view_op_before_concat
from aitemplate.compiler.transform.remove_elementwise_no_ops import (
    remove_elementwise_no_ops,
)
from aitemplate.compiler.transform.split_large_concat_ops import split_large_concat_ops
from aitemplate.compiler.transform.split_large_slice_scatter_ops import (
    split_large_slice_scatter_ops,
)
from aitemplate.compiler.transform.split_large_split_ops import split_large_split_ops
from aitemplate.compiler.transform.transform_memory_ops import transform_memory_ops
from aitemplate.compiler.transform.transform_merge_view_ops import merge_view_ops
from aitemplate.compiler.transform.transform_odd_alignment import (
    transform_odd_alignment,
)
from aitemplate.compiler.transform.transform_permutations import eliminate_permutations
from aitemplate.compiler.transform.transform_permute_to_reshape import (
    transform_permute_to_reshape,
)
from aitemplate.compiler.transform.transform_special_ops import transform_special_ops
from aitemplate.compiler.transform.transform_strided_ops import transform_strided_ops

from aitemplate.utils import graph_utils


def optimize_graph(
    sorted_graph: List[Tensor], workdir: str, optimize=True
) -> List[Tensor]:
    """Applies a fixed sequence of graph transformation passes.

    Every pass is called as ``pass_fn(sorted_graph, workdir)`` and its return
    value becomes the input of the next pass. After each pass the current
    graph is handed to ``graph_utils.dump_graph_debug_str_to_file`` with the
    tag ``"<NN>-<pass name>"``, where ``NN`` is the zero-padded position of the
    pass in the pipeline.

    With ``optimize=True`` the passes run in this order (some of them run
    more than once):

    - remove_elementwise_no_ops
    - dedup_make_jagged_ops
    - fuse_permute_bmm_and_gemm
    - fuse_bmm_permute
    - fuse_expand_bmm
    - transform_odd_alignment
    - fuse_conv_elementwise
    - fuse_single_source_parallel_gemms
    - fuse_mm_elementwise
    - fuse_mm_reshape_permute
    - move_view_op_before_concat
    - merge_view_ops
    - transform_memory_ops
    - fuse_ops
    - fuse_elementwise
    - fuse_parallel_gemms
    - fuse_group_ops
    - transform_special_ops
    - apply_padding
    - move_view_op_before_concat
    - transform_memory_ops
    - transform_strided_ops
    - split_large_slice_scatter_ops
    - split_large_concat_ops
    - split_large_split_ops
    - transform_permute_to_reshape
    - transform_memory_ops
    - eliminate_permutations
    - fuse_duplicate_fused_elementwise

    With ``optimize=False`` only a reduced pipeline runs. It converts
    elementwise ops to singleton fused_elementwise ops and applies the
    padding/splitting passes the model still needs in order to be executable:

    - process_singleton_elementwise
    - apply_padding
    - split_large_slice_scatter_ops
    - split_large_concat_ops
    - split_large_split_ops

    Parameters
    ----------
    sorted_graph : List[Tensor]
        Input graph
    workdir : str
        Working directory; forwarded to every pass and to the debug dump.
    optimize : bool, optional
        If True (default), run the full optimization pipeline. If False, run
        only the reduced pipeline listed above.

    Returns
    -------
    List[Tensor]
        The graph returned by the last pass.
    """
    if optimize:
        funcs = [
            remove_elementwise_no_ops,
            dedup_make_jagged_ops,
            fuse_permute_bmm_and_gemm,
            fuse_bmm_permute,
            fuse_expand_bmm,
            transform_odd_alignment,
            fuse_conv_elementwise,
            fuse_single_source_parallel_gemms,
            fuse_mm_elementwise,
            fuse_mm_reshape_permute,
            # make sure we run move_view_op_before_concat before transform_memory_ops
            move_view_op_before_concat,
            merge_view_ops,
            transform_memory_ops,
            fuse_ops,
            fuse_elementwise,
            # need to run before transform_strided_ops to fuse strided ops + concat
            # and transform_memory_ops to fuse split + concat
            fuse_parallel_gemms,
            fuse_group_ops,
            # This needs to be run after fuse_ops() to avoid handling elementwise
            # op directly. After fuse_ops, there are only FusedElementwise ops.
            transform_special_ops,
            apply_padding,
            # apply_padding may introduce new concats that can be fused
            move_view_op_before_concat,
            transform_memory_ops,
            transform_strided_ops,
            split_large_slice_scatter_ops,
            split_large_concat_ops,
            split_large_split_ops,
            transform_permute_to_reshape,
            transform_memory_ops,
            eliminate_permutations,
            # fuse_duplicate_fused_elementwise must run after elementwise fusion and
            # after passes that modify/replace a fused_elementwise's input/output accessor.
            fuse_duplicate_fused_elementwise,
        ]
    else:
        # 1 - Convert elementwise ops to singleton fused_elementwise ops
        # 2 - Padding also needs to be done for the model to be executable.
        funcs = [
            process_singleton_elementwise,
            apply_padding,
            split_large_slice_scatter_ops,
            split_large_concat_ops,
            split_large_split_ops,
        ]

    for i, func in enumerate(funcs):
        sorted_graph = func(sorted_graph, workdir)
        # Debug tag is "<two-digit pass index>-<pass name>", e.g.
        # "00-remove_elementwise_no_ops", so dumps sort in pipeline order.
        graph_utils.dump_graph_debug_str_to_file(
            sorted_graph, workdir, f"{i:02}-{func.__name__}"
        )
    return sorted_graph