Source code for aitemplate.utils.visualization.plot

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
"""
Graph visualization tool for AITemplate
"""

import json
import os

from typing import Optional

from aitemplate import compiler
from aitemplate.utils.environ import shorten_tensor_names_for_plots
from aitemplate.utils.misc import short_str
from aitemplate.utils.visualization import op_attr_factory, pydot
from aitemplate.utils.visualization.op_attr_factory import op_to_content
from aitemplate.utils.visualization.web_template import (
    INDEX_TEMPLATE,
    MODAL_TEMPLATE,
    TABLE_TEMPLATE,
)

# Color names passed as the ``color`` attribute of the pydot nodes created in
# plot_graph. For tensors, plot_graph checks view, output and param in that
# order on top of the default, so a later match overrides an earlier one.
COLOR_SCHEME = {
    "default_tensor": "lightskyblue1",  # tensor with none of the roles below
    "view": "plum1",  # tensor whose "is_view_of" attribute is not None
    "output": "violetred",  # tensor whose "is_output" attribute is True
    "param": "x11gray",  # tensor whose "is_param" attribute is True
    "default_op": "mediumpurple1",  # same color plot_graph gives operator nodes
}


def _get_tensor_shape_str(tensor) -> str:
    """Return the shape of ``tensor`` formatted as the ``str()`` of a list.

    Each dimension returned by ``tensor.shape()`` contributes one entry:

    * an ``IntImm`` contributes its ``value()``;
    * any other dimension (an ``IntVar``, per the original author's note)
      contributes its single entry from ``_attrs["values"]``, or a
      ``(min, max)`` tuple when there are two or more values.

    Raises
    ------
    RuntimeError
        If a non-``IntImm`` dimension has an empty ``values`` list.
    """
    # Imported locally, as in the original code, rather than at module level.
    from aitemplate.compiler.base import IntImm

    dims = []
    for index, dim in enumerate(tensor.shape()):
        if isinstance(dim, IntImm):
            dims.append(dim.value())
            continue

        # dim is IntVar
        values = dim._attrs["values"]
        if len(values) == 1:
            dims.append(values[0])
        elif len(values) >= 2:
            dims.append((min(values), max(values)))
        else:
            # Same exception type as before, but say what went wrong.
            raise RuntimeError(
                f"Cannot format tensor shape: dynamic dimension {index} "
                "has an empty 'values' list."
            )
    return str(dims)


def _gen_tensor_modal(tensor) -> str:
    """Render the modal markup describing one tensor.

    A table is rendered from TABLE_TEMPLATE with the tensor's shape string,
    the name of the tensor it is a view of (the string "None" when there is
    none), its is_input / is_output / is_param flags and its dtype. That table
    becomes the content of a modal rendered from MODAL_TEMPLATE, whose id,
    label and title are derived from the tensor's name.
    """
    shape_text = _get_tensor_shape_str(tensor)
    attrs = tensor._attrs
    view_source = attrs["is_view_of"]

    # All values are stringified before they reach the table template.
    rows = {
        "shape": shape_text,
        "is_view_of": (
            view_source._attrs["name"] if view_source is not None else "None"
        ),
        "is_input": str(attrs["is_input"]),
        "is_output": str(attrs["is_output"]),
        "is_param": str(attrs["is_param"]),
        "dtype": str(attrs["dtype"]),
    }
    table_html = TABLE_TEMPLATE.render(table_data=rows)

    name = attrs["name"]
    return MODAL_TEMPLATE.render(
        modal_id=f"{name}_modal",
        modal_label=f"{name}_label",
        modal_title=name,
        modal_content=table_html,
    )


def _gen_op_modal(op) -> str:
    """Render the modal markup describing one operator.

    The operator's attributes, as returned by ``op_attr_factory.op_to_content``,
    are rendered into a table with TABLE_TEMPLATE. The table is then embedded
    in a modal rendered from MODAL_TEMPLATE, whose id, label and title are
    derived from the operator's name.
    """
    table_html = TABLE_TEMPLATE.render(
        table_data=op_attr_factory.op_to_content(op)
    )
    name = op._attrs["name"]
    return MODAL_TEMPLATE.render(
        modal_id=f"{name}_modal",
        modal_label=f"{name}_label",
        modal_title=name,
        modal_content=table_html,
    )


def _highlight_op_node(op_node, op, time_stats):
    if op in time_stats.op_durations:
        perf_op = time_stats.op_durations[op]
        scale_factor = float(perf_op) / float(time_stats.total_duration)

        if perf_op > time_stats.duration_p95:
            op_node.set("color", "maroon1")
            op_node.set("penwidth", 9)
            op_node.set("width", 1 + scale_factor * 100)
            op_node.set("height", 1 + scale_factor * 50)
        elif perf_op > time_stats.duration_p90:
            op_node.set("color", "magenta1")
            op_node.set("penwidth", 6)
            op_node.set("width", 1 + scale_factor * 100)
            op_node.set("height", 1 + scale_factor * 50)
        elif perf_op > time_stats.duration_p70:
            op_node.set("color", "mediumorchid1")
            op_node.set("penwidth", 3)
            op_node.set("width", 1 + scale_factor * 100)
            op_node.set("height", 1 + scale_factor * 50)


def plot_graph(
    tensors, file_path: str, file_with_time_profiles: Optional[str] = None
) -> None:
    """
    Plot AIT graph.

    Parameters
    ----------
    tensors : Union[Tensor, List[Tensor]]
        An output Tensor, or a list of output Tensors of AIT graph.
    file_path : str
        Output file path, currently we support the following extension:
            - html
            - format supported by graphviz
    file_with_time_profiles : Optional[str]
        Adds time for every node, if provided

    Raises
    ------
    ValueError
        If ``file_path`` has no extension, or its extension is neither
        "html" nor one of the formats listed by the pydot graph.
    """
    dot_graph = pydot.Dot(graph_type="digraph")
    _, ext = os.path.splitext(file_path)
    if ext == "":
        raise ValueError("Please provide a file extension in path to plot.")
    ext = ext[1:]  # drop the leading "."
    if ext != "html" and ext not in dot_graph.formats:
        raise ValueError(f"Unsupported extension '{ext}' to plot!")

    sorted_graph = compiler.transform.toposort(tensors)
    compiler.transform.name_graph(sorted_graph)

    # Before doing the further processing, it is needed
    # to find whether there is an Operator instance with the same
    # name like 'fused_elementwise_123' that is used
    # several times, but with different input and/or outputs.
    # In such a case, every Operator instance should get its unique
    # name.
    #
    # The following dict will be used to store such unique names,
    # such as 'fused_elementwise_123 0' and 'fused_elementwise_123 1'.
    from aitemplate.utils.json_utils import gen_unique_op_names

    op_names = gen_unique_op_names(sorted_graph)

    from aitemplate.utils.graph_utils import (
        ProfiledTimeStatistics,
        track_graph_timings,
    )

    time_stats = ProfiledTimeStatistics()
    if file_with_time_profiles is not None:
        time_stats = track_graph_timings(sorted_graph, file_with_time_profiles)

    op_nodes = {}  # Operator -> pydot node already added to the graph
    tensor_nodes = {}  # Tensor -> pydot node already added to the graph
    modal_set = []  # rendered modal snippets (html output only)
    items = []  # node names, serialized to JSON for the html page
    popover_data = {}  # node name -> short description for the html page

    def _op_node_for(op):
        """Return the graph node for ``op``, creating and registering it once."""
        if op in op_nodes:
            return op_nodes[op]

        op_name = op._attrs["name"]
        # replace op_name with a unique name, if provided
        if op_name is not None:
            op_name = op_names.get(op, op_name)

        label = op_name
        if op in time_stats.op_durations:
            label = f"{op_name}\\n{time_stats.op_durations[op]!s} ms"

        op_node = pydot.Node(
            name=op_name,
            shape="folder",
            id=op_name,
            label=label,
            color=COLOR_SCHEME["default_op"],
        )
        _highlight_op_node(op_node, op, time_stats)
        op_nodes[op] = op_node
        dot_graph.add_node(op_node)
        # NOTE(review): _gen_op_modal derives the modal id from
        # op._attrs["name"], while the node id above uses the unique name --
        # confirm the html template copes with operators that share a name.
        modal_set.append(_gen_op_modal(op))
        items.append(op_name)
        popover_data[op_name] = ", ".join(
            [f"{x}: {y}" for x, y in op_to_content(op).items()]
        )
        return op_node

    for tensor in sorted_graph:
        tensor_name = tensor._attrs["name"]
        if tensor in tensor_nodes:
            tensor_node = tensor_nodes[tensor]
        else:
            # Later checks win: param > output > view > default.
            color = COLOR_SCHEME["default_tensor"]
            if tensor._attrs["is_view_of"] is not None:
                color = COLOR_SCHEME["view"]
            if tensor._attrs["is_output"] is True:
                color = COLOR_SCHEME["output"]
            if tensor._attrs["is_param"] is True:
                color = COLOR_SCHEME["param"]

            label = tensor_name
            if shorten_tensor_names_for_plots():
                if tensor_name is not None and len(tensor_name) > 30:
                    label = short_str(tensor_name)

            # add a label with time
            label_with_time = ""
            seq_tracker = time_stats.tensor_sequential_trackers.get(tensor, None)
            if seq_tracker is not None and seq_tracker.execution_end != 0:
                label_with_time += f"{seq_tracker.execution_end:.3f} ms"
            par_tracker = time_stats.tensor_parallel_trackers.get(tensor, None)
            if par_tracker is not None and par_tracker.execution_end != 0:
                if label_with_time:
                    label_with_time += " / "
                label_with_time += f"{par_tracker.execution_end:.3f} ms"
            if label_with_time:
                # Build on `label`, not `tensor_name`, so a shortened name is
                # kept when timings are appended.
                label = f"{label}\\n{label_with_time}"

            # add a node
            tensor_node = pydot.Node(
                name=tensor_name,
                shape="note",
                id=tensor_name,
                label=label,
                color=color,
            )
            tensor_nodes[tensor] = tensor_node
            dot_graph.add_node(tensor_node)
            modal_set.append(_gen_tensor_modal(tensor))
            items.append(tensor_name)
            popover_data[tensor_name] = f"shape: {_get_tensor_shape_str(tensor)}"

        # producer op -> tensor
        for src_op in tensor.src_ops():
            dot_graph.add_edge(pydot.Edge(_op_node_for(src_op), tensor_node))

        # tensor -> consumer op
        for dst_op in tensor.dst_ops():
            dot_graph.add_edge(pydot.Edge(tensor_node, _op_node_for(dst_op)))

    file_dir = os.path.dirname(file_path)
    if file_dir:
        os.makedirs(file_dir, exist_ok=True)

    if ext == "html":
        basename = os.path.splitext(os.path.basename(file_path))[0]
        index = INDEX_TEMPLATE.render(
            dot_src=dot_graph.to_string(),
            modals="\n".join(modal_set),
            network_name=basename,
            items=json.dumps(items, indent=2),
            popover_data=json.dumps(popover_data, indent=2),
        )
        with open(file_path, "w") as fo:
            fo.write(index)
    else:
        dot_graph.write(file_path, format=ext)