How to visualize an AIT model
Visualization is important for understanding the behavior of a model optimization. In AIT, we modify the codegen a little bit, from generating CUDA/HIP C++ code to HTML/Javascript code, then we can generate a visualization of the model.
The following code will generate a visualization of our first example.
1. Define the AIT Model
from aitemplate import compiler
from aitemplate.frontend import nn, Tensor
from aitemplate.testing import detect_target
from aitemplate.utils.visualization import plot_graph
class AITSimpleModel(nn.Module):
def __init__(self, hidden, eps: float = 1e-5):
super().__init__()
self.dense1 = nn.Linear(hidden, 4 * hidden, specialization="fast_gelu")
self.dense2 = nn.Linear(4 * hidden, hidden)
self.layernorm = nn.LayerNorm(hidden, eps=eps)
def forward(self, input):
hidden_states = self.dense1(input)
hidden_states = self.dense2(hidden_states)
hidden_states = hidden_states + input
hidden_states = self.layernorm(hidden_states)
return hidden_states
def gen_ait_model():
batch_size = 512
hidden = 1024
ait_model = AITSimpleModel(hidden)
ait_model.name_parameter_tensor()
X = Tensor(
shape=[batch_size, hidden],
name="X",
dtype="float16",
is_input=True,
)
Y = ait_model(X)
Y._attrs["is_output"] = True
Y._attrs["name"] = "Y"
return Y
output_tensor = gen_ait_model()
2. Apply optimizations on the AIT Model
def apply_optimizations(tensors):
target = detect_target()
# first, convert output tensors to graph
with target:
graph = compiler.transform.toposort(tensors)
# second, provide names to the graph
compiler.transform.name_graph(graph)
compiler.transform.mark_param_tensor(graph)
compiler.transform.mark_special_views(graph)
# we can apply optimizations to the graph, or test single optimization pass on the graph
graph = compiler.transform.optimize_graph(graph, "./tmp")
return graph
graph = apply_optimizations(output_tensor)
3. Generate visualization
# Plot the graph
plot_graph(graph, file_path="ait_model.html")
The visualization will be generated in the “ait_model.html” file. This file can be opened in Chrome without any web server.