# Source code for aitemplate.compiler.transform.constant_folding

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import logging
import os
from typing import Dict, List, Tuple

from aitemplate import backend, compiler

from aitemplate.compiler.base import IntVarTensor, Tensor
from aitemplate.compiler.transform.memory_planning import Workspace
from aitemplate.compiler.transform.transform_utils import replace_tensor
from aitemplate.utils import graph_utils


# Module-level logger; used below to report when there is nothing to fold.
_LOGGER = logging.getLogger(__name__)


def _create_dummy_constant_folder():
    """
    Generate the source of an empty constant folder model.

    Used when the graph has nothing to fold: the generated code still needs a
    constant folder definition so that everything compiles. The model has no
    graph, no outputs, zero-sized blobs and a zero-sized workspace.
    """
    empty_workspace = Workspace(0, 0)
    generator = backend.codegen.ModelContainerGenerator(
        max_blob_size=0,
        max_constant_blob_size=0,
        workspace=empty_workspace,
        constants_data_file=None,
        graph=[],
        output_tensors=[],
        model_name=backend.codegen.CONSTANT_FOLDER_MODEL_NAME,
    )
    model_source = generator.generate_model()
    return model_source


def _make_op_names_unique(graph: List[Tensor]) -> Dict[str, str]:
    """
    Rename every op in the constant folding subgraph to avoid ODR issues.

    Two ops can end up with the same name and implementation (this can really
    happen, e.g. in the proposal op). Suffixing the subgraph's op names with
    ``_constant_folding`` keeps their generated code distinct.

    Each op is renamed at most once, even if it produces several tensors in
    ``graph``.

    Returns:
        A mapping from each new name to the op's previous name. It is suitable
        for undoing the renaming with ``_rename_ops``.
    """
    renamed: Dict[str, str] = {}
    producers = (op for t in graph for op in t._attrs["src_ops"])
    for op in producers:
        current = op._attrs["name"]
        # An op whose current name is already a key was renamed earlier in this loop.
        if current in renamed:
            continue
        fresh = f"{current}_constant_folding"
        renamed[fresh] = current
        op._attrs["name"] = fresh
    return renamed


def _rename_ops(graph: List[Tensor], new_name_to_old: Dict[str, str]) -> None:
    """
    Restore op names changed by ``_make_op_names_unique``.

    Every src op of every tensor in ``graph`` whose current name is a key of
    ``new_name_to_old`` gets that key's value as its name. Other ops are left
    untouched.
    """
    producers = (op for t in graph for op in t._attrs["src_ops"])
    for op in producers:
        current = op._attrs["name"]
        if current in new_name_to_old:
            op._attrs["name"] = new_name_to_old[current]


def _non_output_from_tensor(tensor: Tensor) -> Tensor:
    """
    Copy ``tensor`` into a new Tensor whose attributes can be changed freely.

    The containers holding the src/dst ops are shallow-copied, so they still
    hold the same op objects. ``data`` is shared by reference.
    ``is_output`` is not copied; it is left at the Tensor constructor's default
    (presumably False — confirm in Tensor.__init__).
    """
    attrs = tensor._attrs
    copied = Tensor(
        shape=attrs["shape"],
        name=attrs["name"],
        src_ops=attrs["src_ops"].copy(),
        dst_ops=attrs["dst_ops"].copy(),
        dtype=attrs["dtype"],
        is_view_of=attrs["is_view_of"],
        is_internal_constant=attrs["is_internal_constant"],
        original_name=attrs["original_name"],
    )
    # These are not constructor arguments, so carry them over by hand.
    for key in ("is_param", "data", "external_tensor"):
        copied._attrs[key] = attrs[key]
    return copied


def _output_from_tensor(tensor: Tensor) -> Tensor:
    """Copy ``tensor`` like ``_non_output_from_tensor`` and mark the copy as an output."""
    result = _non_output_from_tensor(tensor)
    result._attrs["is_output"] = True
    return result


def _fix_op_inputs_outputs(
    subgraph: List[Tensor], name_to_new_tensor: Dict[str, Tensor]
) -> None:
    """
    Point the inputs/outputs of the subgraph's ops at the tensors in a name map.

    Why this hack exists:

    1) Building the constant folding subgraph out of *new* tensors is the
       clearest approach, because their attributes can be changed without
       touching the original graph.
    2) The ops, however, are shared with the original graph. Memory planning
       walks the graph through the ops' ``inputs``/``outputs``, so those lists
       must reference the new tensors while the subgraph is processed.

    The caller therefore keeps two name -> tensor maps: one for the subgraph
    copies and one for the original tensors. Before memory planning for
    constant folding it calls:
      _fix_op_inputs_outputs(subgraph, name_to_constant_folding_tensor)

    Afterwards it restores the original wiring with:
      _fix_op_inputs_outputs(subgraph, name_to_original_tensor)

    Deep-copying the src and dst ops when creating the new tensors would make
    the restoration step unnecessary, but that is not implemented and is not
    trivial.

    Every input/output tensor name of the affected ops must be a key of
    ``name_to_new_tensor``; otherwise ``KeyError`` is raised.
    """

    def _remap(tensors):
        return [name_to_new_tensor[t._attrs["name"]] for t in tensors]

    for op in graph_utils.get_sorted_ops(subgraph):
        op._attrs["inputs"] = _remap(op._attrs["inputs"])
        op._attrs["outputs"] = _remap(op._attrs["outputs"])


def _extract_foldable_subgraph(
    sorted_graph: List[Tensor],
) -> Tuple[List[Tensor], Dict[str, Tensor], List[Tensor]]:
    """
    Extract the constant-foldable part of ``sorted_graph``.

    Graph inputs and tensors marked ``skip_constant_folding`` are never
    foldable. Any other tensor is foldable if:
    * it has bound data or is a parameter, or
    * it is not an ``IntVarTensor`` and every input of every op producing it
      is foldable.

    The foldable tensors are then filtered, and the survivors are *copied*, so
    the copies' attributes can be changed without affecting the original
    tensors:
    * A non-parameter foldable tensor that is consumed by a non-foldable op,
      or is an output of the original graph, becomes an output of the
      subgraph. The final graph needs its folded value.
    * Any other foldable tensor that is consumed by at least one foldable op
      is kept as a non-output tensor. Parameters in this group are also
      returned as the constant folding inputs.
    * All remaining foldable tensors are dropped.

    Note: the original tensors are not modified, but the ops of the subgraph
    are shared with the original graph. Their inputs/outputs are re-wired to
    the copies by ``_fix_op_inputs_outputs``. Callers must restore the wiring
    with ``_fix_op_inputs_outputs(subgraph, name_to_old_tensor)``.

    Returns:
        ``(subgraph, name_to_old_tensor, constant_folding_inputs)``:
        * ``subgraph``: the copied tensors, in ``sorted_graph`` order.
        * ``name_to_old_tensor``: maps each subgraph tensor's name to the
          original tensor.
        * ``constant_folding_inputs``: the copied parameter tensors consumed
          by foldable ops.
    """
    foldable_node_names = set()
    foldable_ops = set()
    subgraph = []

    for tensor in sorted_graph:
        if tensor._attrs["is_input"] or tensor._attrs["skip_constant_folding"]:
            continue

        name = tensor._attrs["name"]
        if tensor._attrs["data"] is not None or tensor._attrs["is_param"]:
            # Bound constants and parameters are foldable by definition.
            foldable_node_names.add(name)
            subgraph.append(tensor)
            continue
        if isinstance(tensor, IntVarTensor):
            # IntVarTensors without bound data are never folded.
            continue

        # sorted_graph is assumed to be topologically sorted, so every producer
        # input has already been classified by the time we get here.
        foldable = all(
            inp._attrs["name"] in foldable_node_names
            for op in tensor._attrs["src_ops"]
            for inp in op._attrs["inputs"]
        )

        if foldable:
            foldable_node_names.add(name)
            subgraph.append(tensor)
            foldable_ops.update(tensor._attrs["src_ops"])

    def _is_used_by_non_foldable_op(tensor: Tensor) -> bool:
        return any(op not in foldable_ops for op in tensor._attrs["dst_ops"])

    def _is_used_by_foldable_op(tensor: Tensor) -> bool:
        return any(op in foldable_ops for op in tensor._attrs["dst_ops"])

    # Now figure out which tensors can be marked as outputs.
    filtered_subgraph = []
    name_to_new_tensor = {}
    name_to_old_tensor = {}
    constant_folding_inputs = []

    for tensor in subgraph:
        name = tensor._attrs["name"]
        new_tensor = None

        if not tensor._attrs["is_param"] and (
            _is_used_by_non_foldable_op(tensor) or tensor._attrs["is_output"]
        ):
            # Tensor is required outside of the subgraph, make it an output.
            # Parameters don't need to be marked as outputs in the
            # subgraph, we already know their values.
            new_tensor = _output_from_tensor(tensor)

        elif _is_used_by_foldable_op(tensor):
            # No need to append constants that are not used by any foldable ops.
            new_tensor = _non_output_from_tensor(tensor)
            if new_tensor._attrs["is_param"]:
                constant_folding_inputs.append(new_tensor)

        if new_tensor is not None:
            name_to_new_tensor[name] = new_tensor
            name_to_old_tensor[name] = tensor
            filtered_subgraph.append(new_tensor)

    # Wire the (shared) ops to the copies; the caller must undo this.
    _fix_op_inputs_outputs(filtered_subgraph, name_to_new_tensor)
    return filtered_subgraph, name_to_old_tensor, constant_folding_inputs


def _constant_folding_impl(
    sorted_graph: List[Tensor],
    workdir: str,
    model_name: str,
) -> Tuple[Dict[str, Tensor], List[Tuple[str, str]], List[Tensor]]:
    """
    Generate the constant folder for ``sorted_graph`` and describe its results.

    The foldable subgraph is extracted and memory-planned, and its op sources
    are generated. ``constant_folder-generated.h`` is then written into
    ``os.path.join(workdir, model_name)``. When nothing needs folding, a dummy
    constant folder header is written instead.

    While this runs, the ops of the foldable subgraph (shared with
    ``sorted_graph``) are re-wired to copied tensors and renamed. Both changes
    are always undone before returning, including when an exception
    propagates, so the caller's graph is never left half-modified.

    Args:
        sorted_graph: the sorted graph to fold.
        workdir: root directory for generated code.
        model_name: model name; generated files go in ``workdir/model_name``.

    Returns:
        ``(new_tensors, file_pairs, constant_folding_inputs)``:
        * ``new_tensors``: maps the name of every non-parameter subgraph
          tensor to a fresh Tensor. The fresh Tensor is an output only if the
          original tensor was an output of ``sorted_graph``. It carries
          ``constant_folding_output_idx`` when it is an output of the constant
          folder.
        * ``file_pairs``: the value returned by ``gen_function_src`` for the
          subgraph.
        * ``constant_folding_inputs``: the parameter tensors consumed by the
          subgraph.

        All three are empty when nothing is folded.
    """
    model_dir = os.path.join(workdir, model_name)
    header_path = os.path.join(model_dir, "constant_folder-generated.h")

    # Collect the set of output names before we do any transformations. We'll need
    # this if we end up turning outputs into constants. _extract_foldable_subgraph
    # marks every folded tensor needed *outside the subgraph* as an output of the
    # subgraph copies, so their "is_output" no longer tells us which tensors are
    # real graph outputs.
    original_output_tensors = {
        tensor._attrs["name"] for tensor in sorted_graph if tensor._attrs["is_output"]
    }

    (
        subgraph,
        name_to_old_tensor,
        constant_folding_inputs,
    ) = _extract_foldable_subgraph(sorted_graph)

    # From here on the subgraph's ops point at the copied tensors. The finally
    # clause restores the original wiring and op names on every exit path.
    new_name_to_old: Dict[str, str] = {}
    try:
        output_tensors = [t for t in subgraph if t._attrs["is_output"]]
        if not output_tensors:
            _LOGGER.info("No constants to fold, skipping constant folding.")
            # Write a dummy constant folder so everything still compiles.
            with open(header_path, "w") as f:
                f.write(_create_dummy_constant_folder())
            return {}, [], []

        blob, constant_blob, workspace = compiler.transform.memory_planning(subgraph)
        # Temporarily give the subgraph ops distinct names to avoid ODR clashes.
        new_name_to_old = _make_op_names_unique(subgraph)
        file_pairs = backend.codegen.gen_function_src(subgraph, workdir, model_name)
        model_container_generator = backend.codegen.ModelContainerGenerator(
            blob,
            constant_blob,
            workspace,
            constants_data_file=None,
            graph=subgraph,
            output_tensors=output_tensors,
            model_name=backend.codegen.CONSTANT_FOLDER_MODEL_NAME,
            model_dir=model_dir,
        )
        model_container_generator.append_all_tensors()
        constant_folding_model_def = model_container_generator.generate_model()
        with open(header_path, "w") as f:
            f.write(constant_folding_model_def)
    finally:
        _fix_op_inputs_outputs(subgraph, name_to_old_tensor)
        _rename_ops(subgraph, new_name_to_old)

    output_name_to_idx = model_container_generator.output_name_to_idx
    new_tensors = {}
    for tensor in subgraph:
        if tensor._attrs["is_param"]:
            # Parameter values are already known; they are not replaced.
            continue
        name = tensor._attrs["name"]
        new_tensor = Tensor(
            shape=tensor._attrs["shape"],
            name=name,
            dtype=tensor._attrs["dtype"],
            is_output=name in original_output_tensors,
        )
        if name in output_name_to_idx:
            new_tensor._attrs["constant_folding_output_idx"] = output_name_to_idx[name]
        new_tensors[name] = new_tensor

    return new_tensors, file_pairs, constant_folding_inputs


def constant_folding(
    sorted_graph: List[Tensor],
    workdir: str,
    model_name: str,
) -> Tuple[List[Tensor], List[Tuple[str, str]], List[Tensor]]:
    """
    Fold and propagate constants.

    This pass looks for tensors whose values depend only on bound constants
    and parameters, so they can be determined ahead of time. It generates a
    constant folder model that computes them and writes its header to
    ``workdir/model_name/constant_folder-generated.h``. This pass only
    generates source; it does not compile or run the folder itself.

    Each folded tensor in ``sorted_graph`` is then replaced by a new tensor
    without producing ops. Replacements that the constant folder outputs carry
    ``constant_folding_output_idx``. Ops that are no longer used are removed,
    and the graph is sanitized.

    Errors raised while generating the constant folder propagate to the
    caller. The ops of ``sorted_graph`` are restored to their original wiring
    and names before the error escapes.

    Args:
        sorted_graph: the sorted graph to fold. It is updated in place before
            being sanitized.
        workdir: root directory for generated code.
        model_name: model name; generated files go in ``workdir/model_name``.

    Returns:
        ``(graph, file_pairs, constant_folding_inputs)``:
        * ``graph``: the sanitized, folded graph.
        * ``file_pairs``: the generated op sources for the constant folder.
        * ``constant_folding_inputs``: the parameter tensors it consumes.
    """
    new_constants, file_pairs, constant_folding_inputs = _constant_folding_impl(
        sorted_graph, workdir, model_name
    )

    # Replace ops with their folded values, keeping the graph order intact.
    for idx, tensor in enumerate(sorted_graph):
        name = tensor._attrs["name"]
        if name in new_constants:
            new_tensor = new_constants[name]
            replace_tensor(tensor, new_tensor)
            sorted_graph[idx] = new_tensor

    # Eliminate constants (and the ops producing them) that are no longer used.
    compiler.transform.remove_unused_ops(sorted_graph)
    return (
        compiler.transform.transform_utils.sanitize_sorted_graph(sorted_graph),
        file_pairs,
        constant_folding_inputs,
    )