How to run inference on a PyTorch model with AIT

This tutorial demonstrates how to run inference on a PyTorch model with AIT. The full source code can be found at examples/07_how_to_run_pt_model/how_to_run_pt_model.py.

0. Prerequisites

We first import the necessary Python modules:

import torch

from aitemplate.compiler import compile_model
from aitemplate.frontend import nn, Tensor
from aitemplate.testing import detect_target
from aitemplate.testing.benchmark_pt import benchmark_torch_function
from aitemplate.utils.graph_utils import sorted_graph_pseudo_code

1. Define a PyTorch module

Here we define a PyTorch module of the kind commonly seen in Transformer models:

class PTSimpleModel(torch.nn.Module):
  def __init__(self, hidden, eps: float = 1e-5):
    super().__init__()
    self.dense1 = torch.nn.Linear(hidden, 4 * hidden)
    self.act1 = torch.nn.functional.gelu
    self.dense2 = torch.nn.Linear(4 * hidden, hidden)
    self.layernorm = torch.nn.LayerNorm(hidden, eps=eps)

  def forward(self, input):
    hidden_states = self.dense1(input)
    hidden_states = self.act1(hidden_states)
    hidden_states = self.dense2(hidden_states)
    hidden_states = hidden_states + input
    hidden_states = self.layernorm(hidden_states)
    return hidden_states

2. Define an AIT module

We can define a similar AIT module as follows:

class AITSimpleModel(nn.Module):
  def __init__(self, hidden, eps: float = 1e-5):
    super().__init__()
    self.dense1 = nn.Linear(hidden, 4 * hidden, specialization="fast_gelu")
    self.dense2 = nn.Linear(4 * hidden, hidden)
    self.layernorm = nn.LayerNorm(hidden, eps=eps)

  def forward(self, input):
    hidden_states = self.dense1(input)
    hidden_states = self.dense2(hidden_states)
    hidden_states = hidden_states + input
    hidden_states = self.layernorm(hidden_states)
    return hidden_states

Warning

The nn.Module API in AIT looks similar to PyTorch's, but it is not the same.

The fundamental difference is that an AIT module is a container for building a graph, while a PyTorch module is a container that stores parameters for eager execution. As a result, each AIT module's forward method can only be called once; the graph is built during that call. If you want to share parameters, use compiler.ops instead, which plays a role similar to torch.nn.functional in PyTorch.

AITemplate supports automatic fusion of a linear layer followed by other operators. However, in many cases, especially for quick iterations, we use manual specialization to specify the fused operator. For example, specialization="fast_gelu" fuses the linear layer with the fast_gelu operator.
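For instance, the parameter-sharing point above can be expressed directly with compiler.ops in a functional style. The following is a minimal, hypothetical sketch: it assumes ops is importable from aitemplate.compiler and that the gemm_rcr_bias op (computing x @ W^T + b) is available in your AIT version; the tensor names are made up for illustration.

from aitemplate.compiler import ops
from aitemplate.frontend import Tensor

hidden = 512
# Parameters are plain AIT Tensors; naming them lets compile_model bind
# real weights to them later through the constants dict.
W = Tensor(shape=[hidden, hidden], name="shared_w", dtype="float16")
B = Tensor(shape=[hidden], name="shared_b", dtype="float16")

def shared_linear(x):
  # Reusing the same W and B Tensors in every call shares one set of
  # parameters across all uses in the graph.
  return ops.gemm_rcr_bias()(x, W, B)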

3. Define a helper function to map PyTorch parameters to AIT parameters

In AIT, all names must follow the C variable naming standard, because the names will be used in the codegen process.

def map_pt_params(ait_model, pt_model):
  ait_model.name_parameter_tensor()
  pt_params = dict(pt_model.named_parameters())
  mapped_pt_params = {}
  for name, _ in ait_model.named_parameters():
    ait_name = name.replace(".", "_")
    assert name in pt_params
    mapped_pt_params[ait_name] = pt_params[name]
  return mapped_pt_params

Warning

  • Unlike PyTorch, it is required to call the ait_model.name_parameter_tensor() method to give each parameter a name that maps directly to its PyTorch counterpart.

  • Because all names in AIT must follow the C variable naming standard, you can simply replace . with _, or use a regular expression to make sure each name is valid (see the helper sketched after this list).

  • For networks with a conv + bn subgraph, we currently do not provide an automatic pass to fold it. Please refer to our ResNet and Detectron2 examples to see how we handle CNN layout transforms and BatchNorm folding.
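As an illustration of the C naming rule, a small helper can rewrite an arbitrary PyTorch parameter name into a valid C identifier. This is a hypothetical sketch; sanitize_name is not part of AIT:

import re

def sanitize_name(name: str) -> str:
  # Replace every character that is not legal in a C identifier with "_",
  # e.g. "dense1.weight" -> "dense1_weight".
  return re.sub(r"[^0-9a-zA-Z_]", "_", name)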

4. Create PyTorch module, inputs/outputs

batch_size = 1024
hidden = 512
# create pt model
pt_model = PTSimpleModel(hidden).cuda().half()

# create pt input
x = torch.randn([batch_size, hidden]).cuda().half()

# run pt model
pt_model.eval()
y_pt = pt_model(x)

5. Create AIT module, inputs/outputs

batch_size = 1024
hidden = 512
# create AIT model
ait_model = AITSimpleModel(hidden)
# create AIT input Tensor
X = Tensor(
      shape=[batch_size, hidden],
      name="X",
      dtype="float16",
      is_input=True,
)
# run AIT module to generate output tensor
Y = ait_model(X)
# mark the output tensor
Y._attrs["is_output"] = True
Y._attrs["name"] = "Y"

Warning

  • Similar to MetaTensor, LazyTensor, and many other lazy-evaluation frameworks, AIT's Tensor records the computation graph; the graph is built when the output Tensor is compiled.

  • For an input tensor, the attribute is_input=True must be set.

  • For an output tensor, the attribute Y._attrs["is_output"] = True must be set.

  • For input and output tensors, it is best to provide the name attribute so they can be referenced by name at runtime.
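If you want to inspect the graph recorded so far, the sorted_graph_pseudo_code helper imported in step 0 can print a readable summary. This is a minimal sketch, assuming toposort is available under aitemplate.compiler.transform in your AIT version:

from aitemplate.compiler.transform import toposort

# Topologically sort the graph that ends at Y, then print its pseudo-code.
graph = toposort(Y)
print(sorted_graph_pseudo_code(graph))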

6. Compile the AIT module into a runtime and verify the output

# map pt weights to ait
weights = map_pt_params(ait_model, pt_model)

# codegen
target = detect_target()
with compile_model(
    Y, target, "./tmp", "simple_model_demo", constants=weights
) as module:
  # create storage for output tensor
  y = torch.empty([batch_size, hidden]).cuda().half()

  # inputs and outputs dict
  inputs = {"X": x}
  outputs = {"Y": y}

  # run
  module.run_with_tensors(inputs, outputs, graph_mode=True)

  # verify output is correct
  print(torch.allclose(y, y_pt, atol=1e-2, rtol=1e-2))

  # benchmark ait and pt
  count = 1000
  ait_t, _, _ = module.benchmark_with_tensors(
      inputs, outputs, graph_mode=True, count=count
  )
  print(f"AITemplate time: {ait_t} ms/iter")

  pt_t = benchmark_torch_function(count, pt_model.forward, x)
  print(f"PyTorch eager time: {pt_t} ms/iter")

Here, AIT automatically fuses the GELU and the elementwise addition into the TensorCore/MatrixCore GEMM operation. On an RTX 3080, AIT is about 1.15x faster than PyTorch eager in this example.

Note

  • In this example, we fold the parameters (weights) into the AIT runtime, so the final dynamic library contains them as parameters.

  • If we do not provide the parameters at compile time (for example, because the total parameter size exceeds 2GB), we can always call the set_constant function at runtime. Please check the runtime API for details.
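A hypothetical sketch of that deferred-weights flow is shown below. The exact runtime method and its signature (here module.set_constant(name, tensor)) are assumptions based on the note above; please verify them against the runtime API of your AIT version.

# Compile without baking the weights in (no constants= argument).
module = compile_model(Y, target, "./tmp", "simple_model_demo")

# Later, at runtime, bind each named parameter to its PyTorch weight.
# NOTE: method name and signature assumed; check the runtime API.
for ait_name, pt_tensor in map_pt_params(ait_model, pt_model).items():
  module.set_constant(ait_name, pt_tensor)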