Source code for aitemplate.utils.graph_utils

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import json
import logging
import os
from collections import deque
from pathlib import Path
from typing import Any, Dict, List, Set, Union

from aitemplate.utils.misc import is_debug
from aitemplate.utils.visualization import plot_graph


_LOGGER = logging.getLogger(__name__)


def get_sorted_ops(tensors) -> List[Any]:
    """
    Produces the exact execution sequence of operators.
    This matches backend/codegen.py, ModelContainerGenerator.append_all_tensors()
    """
    from aitemplate.compiler.base import Tensor

    visited = set()
    sorted_ops = []
    if isinstance(tensors, Tensor):
        tensors = [tensors]
    for tensor in tensors:
        for src_op in tensor.src_ops():
            if src_op in visited:
                continue
            visited.add(src_op)
            sorted_ops.append(src_op)
    return sorted_ops
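

# Usage sketch (illustrative only): get_sorted_ops is typically called on the
# output tensors of an AIT graph. The helper below is hypothetical and only
# shows the expected call pattern; `graph_output_tensors` is assumed to be the
# list of output Tensors of a built graph.
def _example_print_sorted_ops(graph_output_tensors) -> None:
    # Ops come back in the same order that backend codegen would emit them.
    for op in get_sorted_ops(graph_output_tensors):
        # "op" and "original_name" are attribute keys used elsewhere in AIT;
        # .get() is used defensively here.
        print(op._attrs.get("op"), op._attrs.get("original_name"))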
def sorted_graph_debug_str(tensors) -> str:
    from aitemplate.compiler.base import Tensor

    if isinstance(tensors, Tensor):
        tensors = [tensors]
    tensor_str = "\n\n".join([str(tensor) for tensor in tensors])
    op_str = "\n\n".join([str(op) for op in get_sorted_ops(tensors)])
    return "Tensors: {}\n\nOperators: {}\n\n".format(tensor_str, op_str)


def sorted_graph_debug_json(tensors) -> str:
    from aitemplate.compiler.base import Tensor
    from aitemplate.utils.json_utils import gen_unique_op_names, GraphJsonEncoder

    if isinstance(tensors, Tensor):
        tensors = [tensors]
    json_dict = {}
    json_dict["Tensors"] = tensors
    json_dict["Operators"] = get_sorted_ops(tensors)
    op_names = gen_unique_op_names(tensors)
    encoder = GraphJsonEncoder(op_names, indent=2)
    return encoder.encode(json_dict)


def sorted_graph_pseudo_code(tensors, with_shape=True) -> str:
    from aitemplate.compiler.base import Tensor

    if isinstance(tensors, Tensor):
        tensors = [tensors]
    op_str = "\n".join([op.pseudo_code(with_shape) for op in get_sorted_ops(tensors)])
    return op_str


def sorted_op_pseudo_code(ops, with_shape=True) -> str:
    from aitemplate.compiler.base import Operator

    if isinstance(ops, Operator):
        ops = [ops]
    op_str = "\n".join([op.pseudo_code(with_shape) for op in ops])
    return op_str


def dump_graph_debug_str_to_file(tensors, workdir, name, file_with_time_profiles=None):
    if is_debug():
        # Dump graph and pseudo code for debug only
        prefix = os.path.join(workdir, name)
        graph_path = prefix + "_graph.txt"
        graph_json_path = prefix + "_graph.json"
        pseudo_code_path = prefix + "_pseudo_code.txt"
        graph_visual_path = prefix + "_graph_vis.html"

        with open(graph_path, "w") as f:
            f.write(sorted_graph_debug_str(tensors))
        _LOGGER.debug(f"Dumped {name} graph to {graph_path}")

        with open(graph_json_path, "w") as f:
            f.write(sorted_graph_debug_json(tensors))
        _LOGGER.debug(f"Dumped {name} graph to {graph_json_path}")

        with open(pseudo_code_path, "w") as f:
            f.write(sorted_graph_pseudo_code(tensors))
        _LOGGER.debug(f"Dumped {name} pseudo code to {pseudo_code_path}")

        plot_graph(tensors, graph_visual_path, file_with_time_profiles)
        _LOGGER.debug(f"Dumped {name} visualization to {graph_visual_path}")


class TimestampTracking:
    def __init__(
        self, execution_start: float = 0, duration: float = 0, execution_order: int = 0
    ):
        self.execution_order = execution_order
        self.execution_start = execution_start
        self.duration = duration

    @property
    def execution_end(self):
        return self.execution_start + self.duration


class ProfiledTimeStatistics:
    def __init__(self):
        # Dict[Operator, float]
        self.op_durations = {}
        # Dict[Operator, TimestampTracking]
        self.op_parallel_trackers = {}
        # Dict[Operator, TimestampTracking]
        self.op_sequential_trackers = {}
        # Dict[Tensor, TimestampTracking]
        self.tensor_parallel_trackers = {}
        # Dict[Tensor, TimestampTracking]
        self.tensor_sequential_trackers = {}
        # 0.7 percentile of op times
        self.duration_p70 = 0.0
        # 0.9 percentile of op times
        self.duration_p90 = 0.0
        # 0.95 percentile of op times
        self.duration_p95 = 0.0
        # max time spent among operators
        self.duration_max = 0.0
        # total time spent by operators
        self.total_duration = 0.0
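

# Minimal sketch of TimestampTracking semantics: execution_end is derived from
# execution_start + duration. The numbers below are made up for illustration.
def _example_timestamp_tracking() -> None:
    t = TimestampTracking(execution_start=1.5, duration=0.25, execution_order=3)
    assert t.execution_end == 1.75  # 1.5 + 0.25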


def _load_op_durations_from_file(input: Union[str, Path]) -> Dict[str, float]:
    """
    Loads benchmarking results produced with a profiler from a .json file.
    """
    if isinstance(input, str):
        input_path = Path(input)
    elif isinstance(input, Path):
        input_path = input
    else:
        raise ValueError("str or Path is needed as an input argument")

    # load the file with the profile
    with input_path.open("r") as f:
        perf_per_op_str = f.read()

    # parse file
    perf_per_op_str_dict = json.loads(perf_per_op_str)

    op_durations: Dict[str, float] = {}
    for op_name, op_data in perf_per_op_str_dict.items():
        op_durations[op_name] = op_data["ms_per_iter"]

    # done
    return op_durations
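

# Illustrative sketch of the profiler file layout that
# _load_op_durations_from_file expects: a JSON object mapping each op name to
# a record that carries at least an "ms_per_iter" field. The op names and
# timings below are made up; only the shape of the data matters.
def _example_load_op_durations(tmp_dir: Path) -> Dict[str, float]:
    profile = {
        "gemm_rcr_bias_1": {"ms_per_iter": 0.42},
        "concatenate_3": {"ms_per_iter": 0.07},
    }
    profile_path = tmp_dir / "op_profile.json"
    profile_path.write_text(json.dumps(profile))
    # returns {"gemm_rcr_bias_1": 0.42, "concatenate_3": 0.07}
    return _load_op_durations_from_file(profile_path)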


def track_graph_timings(
    tensors, inputv: Union[str, Path, Dict[str, float]]
) -> ProfiledTimeStatistics:
    """
    Traverses the graph of tensors and uses the statistics from the profiler
    to evaluate execution times in case of sequential execution (1 stream)
    and parallel execution (unlimited number of streams).

    The parallel execution tracking works in the following way.
    1. Input tensors and constant tensors are marked as processed.
    2. Other tensors are marked as unprocessed.
    3. All operators are marked as unprocessed.
    4. Repeat:
       4.1. Search for unprocessed operators whose input tensors are marked
            as processed, "execute" them, and mark the corresponding output
            tensors as processed.
       4.2. Stop if the number of operators processed on step 4.1 is zero.
    5. If the total number of unprocessed operators is not zero, then the
       graph is invalid.

    Parameters
    ----------
    tensors : List[Tensor]
        a list of output Tensors of an AIT graph
    inputv : Union[str, Path, Dict[str, float]]
        str or Path: a path to a .json file with the results generated by a
            profiling procedure
        Dict[str, float]: time costs of operators
            (key is op._attrs["original_name"])
    """
    from aitemplate.compiler.base import Operator, Tensor

    output = ProfiledTimeStatistics()

    # the exact sequence of non-constant tensors that need to be evaluated
    # within a single execution stream.
    unprocessed_tensors: List[Tensor] = []

    # sequence_of_ops contains the exact execution sequence of ops
    # within a single execution stream.
    # Similar to the get_sorted_ops() call above.
    sequence_of_ops: List[Operator] = []

    visited_ops: Set[Operator] = set()
    for tensor in tensors:
        src_ops = tensor.src_ops()
        if len(src_ops) == 0:
            # This tensor depends on no operator.
            # So, add the final statistics for it.
            output.tensor_parallel_trackers[tensor] = TimestampTracking()
            output.tensor_sequential_trackers[tensor] = TimestampTracking()
        else:
            for op in src_ops:
                if op not in visited_ops:
                    visited_ops.add(op)
                    sequence_of_ops.append(op)

            # this tensor needs to be evaluated
            unprocessed_tensors.append(tensor)

    # ok, we've got ops. Load the file with the profile.
    op_durations: Dict[str, float] = {}
    if isinstance(inputv, str) or isinstance(inputv, Path):
        # str or Path
        op_durations = _load_op_durations_from_file(inputv)
    elif (
        isinstance(inputv, dict)
        and all(isinstance(x, str) for x in inputv.keys())
        and all(isinstance(x, float) for x in inputv.values())
    ):
        # this is Dict[str, float]
        op_durations = inputv
    else:
        raise ValueError("Invalid type of inputv")

    # map timings to ops
    for op in visited_ops:
        # the profiler records the results under the original_name
        op_name = op._attrs["original_name"]

        # look up the recorded duration, if the op has a name
        if op_name is not None:
            if op_name not in op_durations:
                # op_name was not found in the profiler report
                output.op_durations[op] = 0
            else:
                time_cost = op_durations[op_name]
                output.op_durations[op] = time_cost
        else:
            # op_name is None; treat the op as having zero cost
            output.op_durations[op] = 0

    # compute statistics
    sorted_op_durations = sorted(op_durations.values())
    if len(sorted_op_durations) > 0:
        output.duration_p70 = sorted_op_durations[int(len(sorted_op_durations) * 0.7)]
        output.duration_p90 = sorted_op_durations[int(len(sorted_op_durations) * 0.9)]
        output.duration_p95 = sorted_op_durations[int(len(sorted_op_durations) * 0.95)]
        output.duration_max = sorted_op_durations[-1]
        output.total_duration = sum(sorted_op_durations)

    # proceed with sequential execution:
    unprocessed_seq_ops = deque(sequence_of_ops)
    unprocessed_seq_tensors = deque(unprocessed_tensors)

    global_timestamp = 0.0
    execution_step = 0
    while len(unprocessed_seq_ops) > 0 or len(unprocessed_seq_tensors) > 0:
        # process operators
        n_local_processed_ops = 0
        for op in unprocessed_seq_ops:
            depends_on = op._attrs["inputs"]

            # are all prereqs complete?
            can_proceed = all(
                tensor in output.tensor_sequential_trackers for tensor in depends_on
            )
            if can_proceed:
                # yes. This operator is ready to be executed.
                execution_step += 1

                op_duration = output.op_durations[op]
                output.op_sequential_trackers[op] = TimestampTracking(
                    execution_start=global_timestamp,
                    duration=op_duration,
                    execution_order=execution_step,
                )

                # modify global clock
                global_timestamp += op_duration

                n_local_processed_ops += 1
            else:
                # cannot go ahead, some tensors need to be marked as processed
                break

        for _ in range(0, n_local_processed_ops):
            unprocessed_seq_ops.popleft()

        # process tensors
        n_local_processed_tensors = 0
        for tensor in unprocessed_seq_tensors:
            depends_on = tensor.src_ops()

            # are all prereqs complete?
            can_proceed = all(op in output.op_sequential_trackers for op in depends_on)
            if can_proceed:
                # yes. The tensor computation is finished.
                max_execution_end = max(
                    output.op_sequential_trackers[op].execution_end
                    for op in depends_on
                )
                max_execution_order = max(
                    output.op_sequential_trackers[op].execution_order
                    for op in depends_on
                )

                output.tensor_sequential_trackers[tensor] = TimestampTracking(
                    execution_start=max_execution_end,
                    duration=0.0,
                    execution_order=max_execution_order,
                )

                n_local_processed_tensors += 1
            else:
                # cannot proceed, some ops need to be run first
                break

        for _ in range(0, n_local_processed_tensors):
            unprocessed_seq_tensors.popleft()

        # are we done?
        if n_local_processed_ops == 0 and n_local_processed_tensors == 0:
            # Yes, no operators or tensors were processed on the current step.
            # This does not imply that all operators and tensors were processed.
            # Basically, this is an early-termination check that indicates that
            # there is some invalid profiler / graph data.
            # So, we're trying to avoid infinite loops.
            break
    # proceed with parallel execution
    unprocessed_par_ops = set(sequence_of_ops)
    unprocessed_par_tensors = set(unprocessed_tensors)

    execution_step = 0
    while len(unprocessed_par_ops) > 0 or len(unprocessed_par_tensors) > 0:
        # process operators
        new_processed_ops: Set[Operator] = set()
        for op in unprocessed_par_ops:
            depends_on = op._attrs["inputs"]

            # are all prereqs complete?
            can_proceed = all(
                tensor in output.tensor_parallel_trackers for tensor in depends_on
            )
            if can_proceed:
                # yes. This operator is ready to be executed.
                op_duration = output.op_durations[op]

                if not depends_on:
                    # a case of an operator that depends on no tensors
                    max_execution_end = 0
                else:
                    # regular case
                    max_execution_end = max(
                        output.tensor_parallel_trackers[tensor].execution_end
                        for tensor in depends_on
                    )

                output.op_parallel_trackers[op] = TimestampTracking(
                    execution_start=max_execution_end,
                    duration=op_duration,
                    execution_order=execution_step,
                )

                new_processed_ops.add(op)

        # ok, there were some processed operators
        if len(new_processed_ops) > 0:
            for op in new_processed_ops:
                unprocessed_par_ops.remove(op)
            execution_step += 1

        # process tensors
        new_processed_tensors: Set[Tensor] = set()
        for tensor in unprocessed_par_tensors:
            depends_on = tensor.src_ops()

            # are all prereqs complete?
            can_proceed = all(op in output.op_parallel_trackers for op in depends_on)
            if can_proceed:
                # yes. The tensor computation is finished.
                max_execution_end = max(
                    output.op_parallel_trackers[op].execution_end for op in depends_on
                )
                max_execution_order = max(
                    output.op_parallel_trackers[op].execution_order
                    for op in depends_on
                )

                output.tensor_parallel_trackers[tensor] = TimestampTracking(
                    execution_start=max_execution_end,
                    duration=0.0,
                    execution_order=max_execution_order,
                )

                new_processed_tensors.add(tensor)

        for tensor in new_processed_tensors:
            unprocessed_par_tensors.remove(tensor)

        # are we done?
        if len(new_processed_ops) == 0 and len(new_processed_tensors) == 0:
            # Same story: we're trying to avoid infinite loops.
            break

    # done
    return output
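

# Usage sketch for track_graph_timings (illustrative): `output_tensors` is
# assumed to be the output Tensors of an already-built AIT graph, and the
# per-op timings are passed directly as a Dict[str, float] keyed by
# op._attrs["original_name"] (values must be floats). The timings may equally
# come from a profiler .json file, see _load_op_durations_from_file above.
def _example_track_graph_timings(output_tensors, op_ms: Dict[str, float]) -> None:
    stats = track_graph_timings(output_tensors, op_ms)
    # Total time when every op runs back-to-back on a single stream ...
    sequential_ms = max(
        (t.execution_end for t in stats.op_sequential_trackers.values()), default=0.0
    )
    # ... versus the critical-path time with an unlimited number of streams.
    parallel_ms = max(
        (t.execution_end for t in stats.op_parallel_trackers.values()), default=0.0
    )
    _LOGGER.debug("sequential=%.3f ms, parallel=%.3f ms", sequential_ms, parallel_ms)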


def split_simple_multistream_parallel_ops(ops_by_order, max_parallel_ops: int):
    """
    Make sure that no more than max_parallel_ops operators are run in parallel.

    Say, on the first step op1, op2 and op3 can be executed in parallel.
    On the second one, it is op4 and op5.
    On the third one, it is op6, op7, op8, op9.
    Then, ops_by_order is something like
        {
            1: [op1, op2, op3],
            2: [op4, op5],
            3: [op6, op7, op8, op9],
        }

    Given max_parallel_ops=2, the output will be:
        [[op1, op2], [op3], [op4, op5], [op6, op7], [op8, op9]]

    Parameters
    ----------
    ops_by_order : Dict[int, List[Operator]]
        A dictionary, its keys represent the execution order
        and its values represent operators that are executed in parallel.
    max_parallel_ops : int
        Number of operators that are allowed to be run in parallel

    Output : List[List[Operator]]
        transformed sequence of operators to execute.
    """
    assert max_parallel_ops > 0

    # todo: a better splitting algorithm can be implemented,
    #   the one that splits operators into max_parallel_ops buckets
    #   so that the amount of needed memory is about the same.
    #   use priority_queue for this and iteratively add to the
    #   bucket that has the lowest 'assigned' memory.

    output = []
    execution_orders = sorted(ops_by_order.keys())
    for execution_order in execution_orders:
        ops = ops_by_order[execution_order]

        ops_parallel = []
        for op in ops:
            ops_parallel.append(op)
            if len(ops_parallel) >= max_parallel_ops:
                output.append(ops_parallel)
                ops_parallel = []

        if len(ops_parallel) > 0:
            output.append(ops_parallel)

    # done
    return output
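

# Minimal sketch of split_simple_multistream_parallel_ops, mirroring the
# docstring example above. The function is agnostic to the element type, so
# plain strings stand in for Operators here.
def _example_split_parallel_ops() -> None:
    ops_by_order = {
        1: ["op1", "op2", "op3"],
        2: ["op4", "op5"],
        3: ["op6", "op7", "op8", "op9"],
    }
    chunks = split_simple_multistream_parallel_ops(ops_by_order, max_parallel_ops=2)
    assert chunks == [
        ["op1", "op2"],
        ["op3"],
        ["op4", "op5"],
        ["op6", "op7"],
        ["op8", "op9"],
    ]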