How to add an operator to the AIT codegen

This tutorial demonstrates how to add a new operator to the AIT codegen. The full source code can be found at examples/06_how_to_add_an_op/how_to_add_an_op.py.

0. Prerequisites

We need to import necessary Python modules:

from typing import Any, Dict, List

import jinja2
import torch

from aitemplate import backend
from aitemplate.backend import registry
from aitemplate.backend.backend_spec import CUDASpec, ROCMSpec
from aitemplate.compiler import compile_model
from aitemplate.compiler.base import IntVar, Operator, Tensor
from aitemplate.testing import detect_target

1. Define the operator graph node

Graph nodes are usually defined at aitemplate/compiler/ops.

class add_one(Operator):
  """Graph node for the elementwise operator ``y = x + 1``.

  The node itself only records graph metadata in ``self._attrs``. The C++
  source is produced by the backend function that is registered under
  ``"<target name>.add_one.gen_function"``.
  """

  def __init__(self):
    super().__init__()
    attrs = self._attrs
    # Required: the unique identity of this operator category. The backend
    # registry key used by gen_function() is built from it.
    attrs["op"] = "add_one"
    # Arbitrary extra metadata can be stored here for later passes.
    attrs["has_profiler"] = False
    attrs["nop"] = False

  def __call__(self, x: Tensor) -> Tensor:
    """Attach ``x`` as this node's input and return the freshly created output."""
    # Every operator must remember its input tensors.
    self._attrs["inputs"] = [x]
    # Optional: derive this op's depth from its inputs' depth (used in DFS).
    self._set_depth()
    # The new tensor's source op is this node.
    result = Tensor(self._infer_shape(x), src_ops={self})
    # Every operator must also remember its outputs.
    self._attrs["outputs"] = [result]
    return result

  def _infer_shape(self, x) -> List[IntVar]:
    """Return the output shape; for add_one it is simply the input's shape.

    Shape inference is done in Python here. If it had to be done on the C++
    side instead, one would write a jinja2 template for the shape function,
    render it to Python code in the graph node, and render the same template
    into C++ code during codegen.
    """
    return x.shape()

  def gen_function(self) -> str:
    """Emit this op's C++ function by delegating to the backend codegen."""
    current = backend.target.Target.current()
    key = f"{current.name()}.{self._attrs['op']}.gen_function"
    return registry.get(key)(self._attrs)

Note

  • _attrs in Operator is the most important data structure for codegen.

  • _attrs["op"] is the identity of the operator category. The backend uses it to find the corresponding codegen function, so it must be unique.

2. Define the necessary templates for Codegen

In AIT, there are 4 important templates for codegen:

  • FUNC_TEMPLATE: the template for generating the function body of the operator. The body invokes the GPU kernel.

  • FUNC_SIGNATURE: the template for generating the function signature of the operator. The signature defines the name and arguments of the function.

  • FUNC_CALL_TEMPLATE: the template for generating the function call of the operator. The call is used during inference to invoke the GPU kernel with the given arguments.

  • FUNC_DECL: the template for the forward declaration of the operator function. This is usually just the rendered FUNC_SIGNATURE followed by a semicolon.

# Full C++ definition of the generated operator function.
# Placeholders:
#   header_files   - backend-specific #include lines
#   kernel         - rendered kernel source; it is placed in an anonymous
#                    namespace so its symbols get internal linkage
#   func_signature - rendered FUNC_SIGNATURE for this op
# The body only forwards its arguments to invoke_add_one(), which must be
# defined by the kernel source.
FUNC_TEMPLATE = jinja2.Template(
    """
{{header_files}}
namespace {
{{kernel}}
}  // namespace
{{func_signature}}
{
    invoke_add_one(output, input, num_elements, stream);
}
    """
)

# C++ signature shared by the definition (FUNC_TEMPLATE) and the forward
# declaration (FUNC_DECL).
#   func_name - per-op function name (taken from func_attrs["name"])
#   prefix    - backend API prefix used to spell the stream type
#               ({{prefix}}Stream_t); presumably "cuda"/"hip" - confirm
#               against the backend spec
FUNC_SIGNATURE = jinja2.Template(
    """
void {{func_name}}(half* output,
          const half* input,
          const int64_t num_elements,
          {{prefix}}Stream_t stream)
    """
)

# Forward declaration: a rendered (and, by gen_function_decl, stripped)
# signature followed by a semicolon.
FUNC_DECL = jinja2.Template(
    """
    {{func_signature}};
    """
)


# C++ statements that invoke the generated function at inference time.
#   indent    - prefix for every emitted line
#   dim_names - C++ names of the input's dimensions; their product is the
#               element count passed to the kernel launcher
#   func_name - generated function name
#   output / input - pointer expressions (see FUNC_CALL_FP16_PARAM_TEMPLATE)
# The statements are wrapped in their own `{ }` scope. Without it, a graph
# containing two add_one ops would declare `int64_t num_elements` twice in
# the same C++ scope and fail to compile.
# NOTE(review): assumes the surrounding generated code has a variable named
# `stream` in scope - confirm against the model codegen.
FUNC_CALL_TEMPLATE = jinja2.Template(
    """
{{indent}}{
{{indent}}  int64_t num_elements = 1;
{% for dim_name in dim_names %}
{{indent}}  num_elements *= {{dim_name}};
{% endfor %}
{{indent}}  {{func_name}}(
{{indent}}     {{output}}, {{input}}, num_elements, stream /* default stream */
{{indent}}  );
{{indent}}}
    """
)

3. Create the GPU kernels

In this example we use the simplest possible add-one kernel. The kernel can be written by hand (as done here) or generated by other tools.

# GPU kernel plus its host-side launcher, written so the same source compiles
# for CUDA and ROCm/HIP.
#   prefix - backend API prefix used to spell the stream type
#            ({{prefix}}Stream_t)
# Launch configuration: one thread per element and at most 1024 threads per
# block.
# Fixes relative to the naive version:
#   * the element index is 64-bit, so blockIdx.x * blockDim.x cannot wrap
#     for tensors with more than 2^31 elements;
#   * empty tensors return early, because launching with a zero-sized block
#     is an invalid configuration.
KERNEL_TEMPLATE = jinja2.Template(
    """
__global__ void add_one(half* output, const half* input, const int64_t num_elements) {
  // 64-bit arithmetic: the 32-bit product blockIdx.x * blockDim.x can overflow.
  const int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx < num_elements) {
    output[idx] = input[idx] + half(1.0);
  }
}
void invoke_add_one(half* output, const half* input, int64_t num_elements, {{prefix}}Stream_t stream) {
  if (num_elements <= 0) {
    // Nothing to compute; a zero-sized block would be an invalid launch.
    return;
  }
  if (num_elements < 1024) {
    dim3 grid(1);
    dim3 block(num_elements);
    add_one<<<grid, block, 0, stream>>>(output, input, num_elements);
  } else {
    dim3 grid((num_elements + 1024 - 1) / 1024);
    dim3 block(1024);
    add_one<<<grid, block, 0, stream>>>(output, input, num_elements);
  }
}
    """
)

(Optional) We also provide a helper template that casts a tensor's buffer to a half* pointer. It handles the difference in how CUDA and ROCm represent float16 tensors.

# Renders a `half*` pointer expression for the tensor variable `name`.
#   is_cuda=True  -> reinterpret_cast<half*>(&(name->raw()))
#   is_cuda=False -> reinterpret_cast<half*>(name)
FUNC_CALL_FP16_PARAM_TEMPLATE = jinja2.Template(
    """reinterpret_cast<half*>(
    {% if is_cuda %}&({% endif %}{{name}}{% if is_cuda %}->raw()){% endif %})"""
)

4. Define the codegen function

The codegen function renders the templates we defined into a valid C++ code string. It takes func_attrs from the graph node and fills in the jinja2 templates.

def gen_function_call(func_attrs: Dict[str, Any], indent="  ", is_cuda=False) -> str:
  """Render the C++ statements that call the generated add_one function.

  Args:
    func_attrs: the operator's ``_attrs``. It must contain exactly one input
      and one output tensor, plus the generated function ``name``.
    indent: prefix for each emitted C++ line.
    is_cuda: chooses the CUDA pointer form (``&(name->raw())``) instead of
      using the tensor variable name directly.

  Returns:
    C++ source that multiplies the input's dimensions into ``num_elements``
    and then invokes the function.
  """
  outs = func_attrs["outputs"]
  assert len(outs) == 1
  ins = func_attrs["inputs"]
  assert len(ins) == 1

  def as_half_ptr(tensor) -> str:
    # Cast the tensor's buffer to half*, accounting for CUDA vs ROCm.
    return FUNC_CALL_FP16_PARAM_TEMPLATE.render(
        name=tensor._attrs["name"], is_cuda=is_cuda
    )

  src = ins[0]
  out_expr = as_half_ptr(outs[0])
  in_expr = as_half_ptr(src)
  # The element count is the product of all input dimensions.
  dims = [d._attrs["name"] for d in src.shape()]

  return FUNC_CALL_TEMPLATE.render(
    func_name=func_attrs["name"],
    output=out_expr,
    input=in_expr,
    dim_names=dims,
    indent=indent,
  )


def gen_function(func_attrs: Dict[str, Any], header_files: str, backend_spec) -> str:
  """Render the complete C++ definition of the add_one operator function.

  Args:
    func_attrs: the operator's ``_attrs``; ``func_attrs["name"]`` becomes
      the C++ function name.
    header_files: backend-specific include lines placed at the top.
    backend_spec: object whose ``prefix`` spells backend API types, such as
      the stream type.
  """
  api_prefix = backend_spec.prefix
  kernel_src = KERNEL_TEMPLATE.render(prefix=api_prefix)
  signature = FUNC_SIGNATURE.render(func_name=func_attrs["name"], prefix=api_prefix)
  return FUNC_TEMPLATE.render(
    header_files=header_files,
    kernel=kernel_src,
    func_signature=signature,
  )


def gen_function_decl(func_attrs: Dict[str, Any], backend_spec) -> str:
  """Render the forward declaration (signature + ``;``) of the op function."""
  signature = FUNC_SIGNATURE.render(
      func_name=func_attrs["name"],
      prefix=backend_spec.prefix,
  )
  # Strip the template's surrounding newlines/indentation before the ';'.
  return FUNC_DECL.render(func_signature=signature.strip())

5.1 Register the codegen functions in the CUDA backend

CUDA backend functions are usually defined at aitemplate/backend/cuda/.

# Includes prepended to the generated CUDA source for this op.
CUDA_HEADER_FILES = """
#include <cuda_fp16.h>
"""


@registry.reg("cuda.add_one.gen_function")
def cuda_add_one_gen_function(func_attrs: Dict[str, Any]) -> str:
  """CUDA codegen hook: the full function definition."""
  spec = CUDASpec()
  return gen_function(func_attrs, header_files=CUDA_HEADER_FILES, backend_spec=spec)


@registry.reg("cuda.add_one.func_decl")
def cuda_add_one_gen_function_decl(func_attrs: Dict[str, Any]) -> str:
  """CUDA codegen hook: the forward declaration."""
  spec = CUDASpec()
  return gen_function_decl(func_attrs, backend_spec=spec)


@registry.reg("cuda.add_one.func_call")
def cuda_add_one_gen_function_call(func_attrs: Dict[str, Any], indent="  ") -> str:
  """CUDA codegen hook: the call site (uses the CUDA pointer form)."""
  return gen_function_call(func_attrs, indent=indent, is_cuda=True)

5.2 (Optional) Register the codegen functions in the ROCm backend

ROCm backend functions are usually defined at aitemplate/backend/rocm/.

# Includes prepended to the generated HIP source for this op.
HIP_HEADER_FILES = """
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
"""


@registry.reg("rocm.add_one.gen_function")
def rocm_add_one_gen_function(func_attrs: Dict[str, Any]) -> str:
  """ROCm codegen hook: the full function definition."""
  spec = ROCMSpec()
  return gen_function(func_attrs, header_files=HIP_HEADER_FILES, backend_spec=spec)


@registry.reg("rocm.add_one.func_decl")
def rocm_add_one_gen_function_decl(func_attrs: Dict[str, Any]) -> str:
  """ROCm codegen hook: the forward declaration."""
  spec = ROCMSpec()
  return gen_function_decl(func_attrs, backend_spec=spec)


@registry.reg("rocm.add_one.func_call")
def rocm_add_one_gen_function_call(func_attrs: Dict[str, Any], indent="  ") -> str:
  """ROCm codegen hook: the call site (tensor names are used directly)."""
  return gen_function_call(func_attrs, indent=indent, is_cuda=False)

6. Compile and verify the results with PyTorch

def create_ait_model(shapes):
  """Build the one-op graph ``Y = add_one(X)`` and return the output ``Y``.

  ``X`` is a float16 graph input named "X" with shape ``shapes``. ``Y`` is
  marked as the graph output and named "Y".
  """
  x_in = Tensor(shape=shapes, dtype="float16", name="X", is_input=True)
  y_out = add_one()(x_in)
  y_out._attrs["is_output"] = True
  y_out._attrs["name"] = "Y"
  return y_out


def verify_add_one():
  """Compile the add_one model and check it against PyTorch's ``x + 1``.

  Prints whether the AIT output matches the reference within fp16
  tolerance, and also returns that boolean so callers can check it.
  """
  shapes = [16, 512]
  # Create tensors directly as fp16 on the GPU instead of building fp32 CPU
  # tensors and converting them.
  x = torch.randn(shapes, dtype=torch.float16, device="cuda")
  y_pt = x + 1.0

  # Build the graph from the same `shapes` so the model and the test tensors
  # cannot drift apart. A copy is passed so the graph never shares the list.
  Y = create_ait_model(list(shapes))
  target = detect_target()
  with compile_model(Y, target, "./tmp", "add_one") as module:
    y = torch.empty(shapes, dtype=torch.float16, device="cuda")
    inputs = {"X": x}
    outputs = {"Y": y}
    module.run_with_tensors(inputs, outputs)
    matched = torch.allclose(y, y_pt, atol=1e-2, rtol=1e-2)
    print(matched)
  return matched