Source code for aitemplate.compiler.transform.memory_planning

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
"""
Graph pass for memory planning.
"""

import bisect
import logging
from collections import defaultdict
from dataclasses import dataclass
from typing import List

from aitemplate.compiler.base import Operator, Tensor
from aitemplate.utils.environ import multistream_max_mem_parallel_ops, multistream_mode
from aitemplate.utils.graph_utils import split_simple_multistream_parallel_ops

# pylint: disable=C0103

_LOGGER = logging.getLogger(__name__)


@dataclass
class TensorUsageRecord:
    """
    A named tuple to keep a tensor usage record, where

    tensor: this tensor

    first_op_idx: the index of the first op that uses this tensor as its input
                  or output

    last_op_idx: the index of the last op that uses this tensor as its input or
                 output

    size: the size of this tensor
    """

    tensor: Tensor
    first_op_idx: int
    last_op_idx: int
    size: int

    def __iter__(self):
        return iter([self.tensor, self.first_op_idx, self.last_op_idx, self.size])


def _find_original_tensor(tensor: Tensor):
    """Find the original tensor of a tensor view recursively."""
    view = tensor._attrs["is_view_of"]
    if not view:
        return tensor
    return _find_original_tensor(view)


def _make_tensor_usage_records(sorted_ops: List[Operator]) -> List[TensorUsageRecord]:
    num_of_ops = len(sorted_ops)
    tensor_records = defaultdict(
        lambda: TensorUsageRecord(
            tensor=None, first_op_idx=num_of_ops, last_op_idx=-1, size=None
        )
    )
    for op_idx, op in enumerate(sorted_ops):
        for tensor in op._attrs["inputs"] + op._attrs["outputs"]:
            # Skip weights and inputs since we don't overwrite them.
            # Note that it might be OK to overwrite inputs, but let's be
            # consertative for now and not surprise users. We could always
            # make a flag to do that later if it's needed.
            if tensor._attrs["is_param"]:
                continue
            name = tensor._attrs["name"]
            this_tensor = tensor_records[name].tensor
            if this_tensor is None:
                tensor_records[name].tensor = tensor
            else:
                # make sure we didn't screw up anything
                assert (
                    tensor == this_tensor
                ), f"existing tensor: {this_tensor}, new tensor: {tensor}, op: {op}"

            first_op_idx = tensor_records[name].first_op_idx
            last_op_idx = tensor_records[name].last_op_idx
            tensor_records[name].first_op_idx = min(first_op_idx, op_idx)
            tensor_records[name].last_op_idx = max(last_op_idx, op_idx)
            # An output tensor's lifetime extends to the last op.
            if tensor._attrs["is_output"]:
                tensor_records[name].last_op_idx = num_of_ops - 1

            size = tensor_records[name].size
            tensor_size = tensor.size_bytes(alignment=64)
            if size is None:
                tensor_records[name].size = tensor_size
            else:
                # make sure we didn't screw up anything
                assert size == tensor_size

    # tensor views extend the lifetime of the original tensors
    tensor_views = []
    for name, tensor_record in tensor_records.items():
        this_tensor = tensor_record.tensor
        if this_tensor._attrs["is_view_of"]:
            orig_tensor = _find_original_tensor(this_tensor)
            # view of input
            if orig_tensor._attrs["is_param"]:
                continue
            orig_tensor_name = orig_tensor._attrs["name"]
            assert orig_tensor_name in tensor_records
            tensor_records[orig_tensor_name].last_op_idx = max(
                tensor_records[orig_tensor_name].last_op_idx, tensor_record.last_op_idx
            )
            tensor_views.append(name)

    # remove tensor views from tensor_records
    for name in tensor_views:
        del tensor_records[name]

    # sanity checks
    # make sure we have valid indices and sizes
    records = tensor_records.values()
    for tensor, first_op_idx, last_op_idx, size in records:
        assert tensor is not None
        assert 0 <= first_op_idx < num_of_ops
        assert 0 <= last_op_idx < num_of_ops
        assert first_op_idx <= last_op_idx
        assert size is not None

    return list(records)


def assign_offsets_to_views_and_outputs(sorted_graph: List[Tensor]) -> None:
    """Propagate offsets determined by the memory planning algorithm to views.

    Parameters
    ----------
    sorted_graph : List[Tensor]
        The graph, modified in-place
    """
    for node in sorted_graph:
        if node._attrs["is_view_of"]:
            node._attrs["offset"] = node._attrs["is_view_of"]._attrs["offset"]


[docs]@dataclass class Workspace: shared_size: int unique_size: int def total_size(self) -> int: return self.shared_size + self.unique_size
def _compute_workspace(sorted_graph: List[Tensor]) -> Workspace: """ Compute the workspace for the model, which can be used as scratch memory by ops. This pass examines two attributes on every function in the graph: - workspace: The amount of memory in bytes to be used as shared scratch memory. Here, "shared" means that other ops are allowed to write to this memory. - unique_workspace: The amount of memory in bytes to be used as exclusive scratch memory. If set, this pass will assign the op a "unique_workspace_offset". This can be used at codegen time to set a pointer to the region of exclusive shared memory. The returned Workspace has two attributes: - shared_size: The total memory needed for all op's shared scratch memory (i.e. the maximum of all workspace attributes) - unique_size: The total memory needed for all unique scratch memory (i.e. the sum of all unique_workspace attributes) During codegen, the workspace gets set up like this: [--unique 1--][--unique 2--]...[--unique N--][--shared--] """ unique_workspace_size = 0 max_workspace = 0 for node in sorted_graph: for func in node._attrs["src_ops"]: if "workspace" in func._attrs: max_workspace = max(max_workspace, func._attrs["workspace"]) if ( "unique_workspace" in func._attrs and "unique_workspace_offset" not in func._attrs ): func._attrs["unique_workspace_offset"] = unique_workspace_size unique_workspace_size += func._attrs["unique_workspace"] return Workspace(max_workspace, unique_workspace_size) def _greedy_by_size_memory_planning( sorted_graph: List[Tensor], tensor_usage_records: List[TensorUsageRecord] ): """ based on the greedy-by-size algorithm for offset calculation described in the following paper: Yury Pisarchyk, Juhyun Lee, Efficient Memory Management for Deep Neural Net Inference, https://arxiv.org/abs/2001.03288 """ # sort tensor usage records in non-increasing order by their sizes sorted_tensor_usage_records = sorted( tensor_usage_records, key=lambda r: r.size, reverse=True ) max_blob = 0 # For tensors that have been assigned, we keep their tensor usage records # in increasing order by memory offsets sorted_assigned_records = [] for tensor_record in sorted_tensor_usage_records: tensor, first_op_idx, last_op_idx, size = tensor_record prev_offset = 0 best_offset = None smallest_gap = pow(2, 63) - 1 # Iterate through tensors that have been allocated. # For those whose usage intervals intersect with that of current # tensor, we try to find the smallest valid memory gap between such two # allocated tensors, which is big enough to hold current tensor. # If such a gap is found, we will place current tensor in the gap. for a_record in sorted_assigned_records: a_tensor, a_first_op_idx, a_last_op_idx, a_size = a_record max_first_op_idx = max(first_op_idx, a_first_op_idx) min_last_op_idx = min(last_op_idx, a_last_op_idx) # current tensor overlaps with this assigned tensor if max_first_op_idx <= min_last_op_idx: a_offset = a_tensor._attrs["offset"] gap = a_offset - prev_offset if size <= gap < smallest_gap: smallest_gap = gap best_offset = prev_offset prev_offset = max(prev_offset, a_offset + a_size) # If we can't find a valid memory gap between two allocated tensors, # we put current tensor to the rightmost tensor whose usage interval # intersects with that of the current tensor. if best_offset is None: best_offset = prev_offset tensor._attrs["offset"] = best_offset max_blob = max(max_blob, best_offset + size) # bisect from Python <=3.9 doesn't have the key parameter sorted_offsets = [r.tensor._attrs["offset"] for r in sorted_assigned_records] in_pos = bisect.bisect_right( sorted_offsets, tensor_record.tensor._attrs["offset"] ) sorted_assigned_records.insert(in_pos, tensor_record) # now we assign blobs for weights and inputs constant_offset = 0 for node in sorted_graph: if ( node._attrs["data"] is not None or node._attrs["constant_folding_output_idx"] is not None ): node._attrs["offset"] = constant_offset constant_offset += node.size_bytes(alignment=64) # assign offsets to tensor views # this step must happen after weights and inputs are assigned so that views # of inputs are properly handled assign_offsets_to_views_and_outputs(sorted_graph) workspace = _compute_workspace(sorted_graph) # make sure we've covered the entire graph return (max_blob, constant_offset, workspace) def greedy_by_size_memory_planning(sorted_graph: List[Tensor]): # noqa: C901 """ based on the greedy-by-size algorithm for offset calculation described in the following paper: Yury Pisarchyk, Juhyun Lee, Efficient Memory Management for Deep Neural Net Inference, https://arxiv.org/abs/2001.03288 """ sorted_ops = [] for node in sorted_graph: sorted_ops.extend(node.src_ops()) tensor_usage_records = _make_tensor_usage_records(sorted_ops) return _greedy_by_size_memory_planning(sorted_graph, tensor_usage_records) def naive_memory_planning(sorted_graph: List[Tensor]): max_blob = 0 offset = 0 constant_offset = 0 for node in sorted_graph: if ( node._attrs["data"] is not None or node._attrs["constant_folding_output_idx"] is not None ): node._attrs["offset"] = constant_offset constant_offset += node.size_bytes(alignment=64) elif not node._attrs["is_view_of"]: node._attrs["offset"] = offset tensor_size = node.size_bytes(alignment=64) offset += tensor_size max_blob += tensor_size # workspace workspace = _compute_workspace(sorted_graph) assign_offsets_to_views_and_outputs(sorted_graph) return (max_blob, constant_offset, workspace) def _make_tensor_usage_records_simple_multistream( par_ops_seq: List[List[Operator]], ) -> List[TensorUsageRecord]: """ Generalized version of _make_tensor_usage_records() which assumes that several ops may be executed on every step. Simple multistream algo iteratively tracks sets of operators that can be run in parallel independently on each iteration. par_ops_seq contains lists of operators that can be run in parallel on every algorithm iteration. Technically, the regular _make_tensor_usage_records() version is similar to the following one: def _make_tensor_usage_records(sorted_ops): par_ops_seq = [sorted_ops] return _make_tensor_usage_records_simple_multistream(par_ops_seq) This version is kept as a separate one, because multistreaming feature is still somewhat experimental. """ num_of_ops = len(par_ops_seq) tensor_records = defaultdict( lambda: TensorUsageRecord( tensor=None, first_op_idx=num_of_ops, last_op_idx=-1, size=None ) ) for op_idx, par_ops in enumerate(par_ops_seq): for op in par_ops: for tensor in op._attrs["inputs"] + op._attrs["outputs"]: # Skip weights and inputs since we don't overwrite them. # Note that it might be OK to overwrite inputs, but let's be # consertative for now and not surprise users. We could always # make a flag to do that later if it's needed. if tensor._attrs["is_param"]: continue name = tensor._attrs["name"] this_tensor = tensor_records[name].tensor if this_tensor is None: tensor_records[name].tensor = tensor else: # make sure we didn't screw up anything assert ( tensor == this_tensor ), f"existing tensor: {this_tensor}, new tensor: {tensor}, op: {op}" first_op_idx = tensor_records[name].first_op_idx last_op_idx = tensor_records[name].last_op_idx tensor_records[name].first_op_idx = min(first_op_idx, op_idx) tensor_records[name].last_op_idx = max(last_op_idx, op_idx) # An output tensor's lifetime extends to the last op. if tensor._attrs["is_output"]: tensor_records[name].last_op_idx = num_of_ops - 1 size = tensor_records[name].size tensor_size = tensor.size_bytes(alignment=64) if size is None: tensor_records[name].size = tensor_size else: # make sure we didn't screw up anything assert size == tensor_size # tensor views extend the lifetime of the original tensors tensor_views = [] for name, tensor_record in tensor_records.items(): this_tensor = tensor_record.tensor if this_tensor._attrs["is_view_of"]: orig_tensor = _find_original_tensor(this_tensor) # view of input if orig_tensor._attrs["is_param"]: continue orig_tensor_name = orig_tensor._attrs["name"] assert orig_tensor_name in tensor_records tensor_records[orig_tensor_name].last_op_idx = max( tensor_records[orig_tensor_name].last_op_idx, tensor_record.last_op_idx ) tensor_views.append(name) # remove tensor views from tensor_records for name in tensor_views: del tensor_records[name] # sanity checks # make sure we have valid indices and sizes records = tensor_records.values() for tensor, first_op_idx, last_op_idx, size in records: assert tensor is not None assert 0 <= first_op_idx < num_of_ops assert 0 <= last_op_idx < num_of_ops assert first_op_idx <= last_op_idx assert size is not None return list(records)
[docs]def simple_multistream_memory_planning(sorted_graph: List[Tensor]): """ A specialized case for simple multi-stream execution. It uses more or slightly more GPU memory than greedy_by_size_memory_planner, depending on the input graph, but still significantly less than naive_memory_planning. """ from aitemplate.utils.graph_utils import track_graph_timings # track the sequence time_stats = track_graph_timings(sorted_graph, {}) # sort all operators by parallel execution order ops_by_order = defaultdict(list) for op, tracking in time_stats.op_parallel_trackers.items(): ops_by_order[tracking.execution_order].append(op) # convert Dict[int, List[Operator]] into List[List[Operator]] max_parallel_ops = multistream_max_mem_parallel_ops() par_ops_seq = split_simple_multistream_parallel_ops(ops_by_order, max_parallel_ops) tensor_usage_records = _make_tensor_usage_records_simple_multistream(par_ops_seq) return _greedy_by_size_memory_planning(sorted_graph, tensor_usage_records)
def proxy_memory_planning(sorted_graph: List[Tensor]): run_mode = multistream_mode() if run_mode == 0: # no multistream max_blob, constant_offset, workspace = greedy_by_size_memory_planning( sorted_graph ) elif run_mode == 1: # simple multistream max_blob, constant_offset, workspace = simple_multistream_memory_planning( sorted_graph ) else: # unsupported raise Exception(f"Unsupported multistream mode ({run_mode})") # print some statistics _LOGGER.info( f"Workspace shared_size={workspace.shared_size} unique_size={workspace.unique_size}" ) _LOGGER.info(f"max_blob={max_blob} constant_offset={constant_offset}") # done return (max_blob, constant_offset, workspace) # memory_planning = greedy_by_size_memory_planning # memory_planning = naive_memory_planning memory_planning = proxy_memory_planning