How to add an operator to the AIT codegen
This tutorial will demonstrate how to add a new operator to the AIT codegen. Full source code can be found at examples/07_how_to_run_pt_model/how_to_run_pt_model.py.
0. Prerequisites
We need to import necessary Python modules:
from typing import Any, Dict, List
import jinja2
import torch
from aitemplate import backend
from aitemplate.backend import registry
from aitemplate.backend.backend_spec import CUDASpec, ROCMSpec
from aitemplate.compiler import compile_model
from aitemplate.compiler.base import IntVar, Operator, Tensor
from aitemplate.testing import detect_target
1. Define the operator graph node
Graph nodes are usually defined at aitemplate/compiler/ops.
class add_one(Operator):
    """Graph node for the ``add_one`` operator.

    The node keeps its input and output tensors in ``_attrs`` and delegates
    source generation to the backend function registered for the current
    target.
    """

    def __init__(self):
        super().__init__()
        # Required: unique identity of the operator category. The registry
        # key assembled in gen_function() is derived from this value.
        self._attrs["op"] = "add_one"
        # Arbitrary extra attributes may be stored in _attrs for later use.
        self._attrs["has_profiler"] = False
        self._attrs["nop"] = False

    def __call__(self, x: Tensor) -> Tensor:
        """Apply the op to ``x`` and return the newly created output tensor."""
        # Every operator keeps a record of its input tensors.
        self._attrs["inputs"] = [x]
        # Optional: set this op's depth from its inputs' depth (used in DFS),
        # which is why it runs after the inputs are recorded.
        self._set_depth()
        # The output tensor takes the inferred shape and names this op as
        # its source.
        result = Tensor(self._infer_shape(x), src_ops={self})
        # Every operator also remembers its outputs.
        self._attrs["outputs"] = [result]
        return result

    def _infer_shape(self, x) -> List[IntVar]:
        """Return the output shape, which here is simply the shape of ``x``.

        When shape inference has to happen on the C++ side, a jinja2 template
        for the shape-inference function is created instead: it is rendered
        to Python code in the graph node and to C++ code during codegen.
        """
        inferred = x.shape()
        return inferred

    def gen_function(self) -> str:
        """Generate this op's function source through the backend registry.

        The codegen function is looked up under the key
        ``"<target name>.<op>.gen_function"`` and called with ``_attrs``.
        """
        target = backend.target.Target.current()
        registry_key = f"{target.name()}.{self._attrs['op']}.gen_function"
        backend_codegen = registry.get(registry_key)
        return backend_codegen(self._attrs)
Note
_attrs in Operator is the most important data structure for codegen.
_attrs["op"] is the identity of the operator category. It is used to find the corresponding codegen function in the backend, so it must be unique.
2. Define the necessary templates for Codegen
In AIT, there are 4 important templates for codegen:
FUNC_TEMPLATE: the template for generating the function body of the operator, which invokes the GPU kernel.
FUNC_SIGNATURE: the template for generating the function signature of the operator. The signature defines the name and arguments of the function.
FUNC_CALL_TEMPLATE: the template for generating the function call of the operator. The call will be used during inference to invoke the GPU kernel with the given arguments.
FUNC_DECL: the template for the forward declaration of the operator function. This is usually just a thin wrapper around FUNC_SIGNATURE (here, the signature followed by a semicolon).
# Source of the generated operator function: the header includes, the kernel
# code wrapped in an anonymous namespace, then the function signature followed
# by a body that forwards its arguments to invoke_add_one.
# Placeholders: header_files, kernel, func_signature.
FUNC_TEMPLATE = jinja2.Template(
"""
{{header_files}}
namespace {
{{kernel}}
} // namespace
{{func_signature}}
{
invoke_add_one(output, input, num_elements, stream);
}
"""
)
# Signature of the generated function: half-precision output and input
# pointers, the element count, and a stream argument.
# Placeholders: func_name, and prefix, which is prepended to "Stream_t" to
# form the stream type (presumably "cuda" / "hip" -- confirm against the
# backend spec classes).
FUNC_SIGNATURE = jinja2.Template(
"""
void {{func_name}}(half* output,
const half* input,
const int64_t num_elements,
{{prefix}}Stream_t stream)
"""
)
# Forward declaration of the generated function: the rendered signature
# followed by a semicolon.
# Placeholder: func_signature.
FUNC_DECL = jinja2.Template(
"""
{{func_signature}};
"""
)
# Call-site snippet for the generated function. It multiplies all dimension
# names together to obtain the element count and then calls the function with
# the output pointer, the input pointer, that count, and `stream`.
# Placeholders: indent, dim_names, func_name, output, input.
#
# The whole snippet is wrapped in its own `{ ... }` block so that the local
# `num_elements` is scoped to this one call. Without the braces, two add_one
# calls rendered into the same C++ scope would declare `num_elements` twice
# and the generated code would not compile.
FUNC_CALL_TEMPLATE = jinja2.Template(
    """
{{indent}}{
{{indent}}  int64_t num_elements = 1;
{% for dim_name in dim_names %}
{{indent}}  num_elements *= {{dim_name}};
{% endfor %}
{{indent}}  {{func_name}}(
{{indent}}      {{output}}, {{input}}, num_elements, stream /* default stream */
{{indent}}  );
{{indent}}}
"""
)
3. Create the GPU kernels
In this example we use the simplest possible add-one kernel. The kernel can be written by hand (as a programmer is expected to do) or generated by other tools.
# Device kernel plus its host-side launcher.
# Placeholder: prefix, prepended to "Stream_t" to form the stream type.
#
# add_one:        each thread handles one element and writes input + 1 to
#                 output when its index is in range.
# invoke_add_one: launches a single block of num_elements threads when there
#                 are fewer than 1024 elements, otherwise
#                 ceil(num_elements / 1024) blocks of 1024 threads.
#
# Two fixes relative to the naive version:
#   * an empty tensor (num_elements <= 0) returns before launching, because a
#     block dimension of 0 is not a valid launch configuration;
#   * the element index is computed in int64_t, matching the type of
#     num_elements, instead of being truncated to a 32-bit int.
KERNEL_TEMPLATE = jinja2.Template(
    """
__global__ void add_one(half* output, const half* input, const int64_t num_elements) {
  const int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx < num_elements) {
    output[idx] = input[idx] + half(1.0);
  }
}
void invoke_add_one(half* output, const half* input, int64_t num_elements, {{prefix}}Stream_t stream) {
  if (num_elements <= 0) {
    return;
  }
  if (num_elements < 1024) {
    dim3 grid(1);
    dim3 block(num_elements);
    add_one<<<grid, block, 0, stream>>>(output, input, num_elements);
  } else {
    dim3 grid((num_elements + 1024 - 1) / 1024);
    dim3 block(1024);
    add_one<<<grid, block, 0, stream>>>(output, input, num_elements);
  }
}
"""
)
(Optional) We also provide a helper template to handle the difference between the CUDA and ROCm float16 data types.
# Renders one tensor argument as a `half*` for the function call.
# Placeholders: name, is_cuda.
# With is_cuda true, the argument inside the reinterpret_cast is
# `&(<name>->raw())`; otherwise it is `<name>` unchanged.
FUNC_CALL_FP16_PARAM_TEMPLATE = jinja2.Template(
"""reinterpret_cast<half*>(
{% if is_cuda %}&({% endif %}{{name}}{% if is_cuda %}->raw()){% endif %})"""
)
4. Define the codegen function
The codegen function renders the templates we defined into a valid C++ code string. It takes func_attrs from the graph node and fills in the jinja2 templates.
def gen_function_call(func_attrs: Dict[str, Any], indent=" ", is_cuda=False) -> str:
    """Render the C++ code that calls the generated add_one function.

    Args:
        func_attrs: the operator's attribute dict. It must contain exactly
            one tensor under "outputs" and one under "inputs", plus the
            function "name".
        indent: string handed to FUNC_CALL_TEMPLATE as its ``indent`` value.
        is_cuda: forwarded to FUNC_CALL_FP16_PARAM_TEMPLATE, which uses it to
            decide how each tensor pointer is written.

    Returns:
        The rendered FUNC_CALL_TEMPLATE text.
    """
    out_tensors = func_attrs["outputs"]
    assert len(out_tensors) == 1
    in_tensors = func_attrs["inputs"]
    assert len(in_tensors) == 1

    def as_half_ptr(tensor) -> str:
        # Both tensors go through the same fp16 parameter template.
        return FUNC_CALL_FP16_PARAM_TEMPLATE.render(
            name=tensor._attrs["name"], is_cuda=is_cuda
        )

    output_arg = as_half_ptr(out_tensors[0])
    input_arg = as_half_ptr(in_tensors[0])
    # The call site multiplies these dimension names to get the element count.
    shape_dim_names = [d._attrs["name"] for d in in_tensors[0].shape()]

    render_args = {
        "func_name": func_attrs["name"],
        "output": output_arg,
        "input": input_arg,
        "dim_names": shape_dim_names,
        "indent": indent,
    }
    return FUNC_CALL_TEMPLATE.render(**render_args)
def gen_function(func_attrs: Dict[str, Any], header_files: str, backend_spec) -> str:
    """Render the full source of the add_one function for one backend.

    Args:
        func_attrs: the operator's attribute dict; "name" supplies the
            function name.
        header_files: include lines placed at the top of the output.
        backend_spec: object whose ``prefix`` attribute is substituted into
            both the kernel and the signature templates.

    Returns:
        The rendered FUNC_TEMPLATE text.
    """
    prefix = backend_spec.prefix
    # Kernel first, then signature, then drop both into the outer template.
    kernel_src = KERNEL_TEMPLATE.render(prefix=prefix)
    signature_src = FUNC_SIGNATURE.render(
        func_name=func_attrs["name"], prefix=prefix
    )
    return FUNC_TEMPLATE.render(
        header_files=header_files,
        kernel=kernel_src,
        func_signature=signature_src,
    )
def gen_function_decl(func_attrs: Dict[str, Any], backend_spec) -> str:
    """Render the forward declaration of the add_one function.

    Args:
        func_attrs: the operator's attribute dict; "name" supplies the
            function name.
        backend_spec: object whose ``prefix`` attribute is substituted into
            the signature template.

    Returns:
        The rendered FUNC_DECL text.
    """
    rendered_signature = FUNC_SIGNATURE.render(
        func_name=func_attrs["name"],
        prefix=backend_spec.prefix,
    )
    # Surrounding whitespace is stripped from the signature before it is
    # placed in front of the semicolon in FUNC_DECL.
    return FUNC_DECL.render(func_signature=rendered_signature.strip())
5.1 Register the codegen function in CUDA backend
CUDA backend functions are usually defined at aitemplate/backend/cuda/.
# Include lines passed as header_files when generating the CUDA version of
# the add_one function.
CUDA_HEADER_FILES = """
#include <cuda_fp16.h>
"""
@registry.reg("cuda.add_one.gen_function")
def cuda_add_one_gen_function(func_attrs: Dict[str, Any]) -> str:
    """Generate the add_one function source using the CUDA headers and spec.

    Registered under "cuda.add_one.gen_function".
    """
    headers = CUDA_HEADER_FILES
    spec = CUDASpec()
    return gen_function(func_attrs, headers, spec)
@registry.reg("cuda.add_one.func_decl")
def cuda_add_one_gen_function_decl(func_attrs: Dict[str, Any]) -> str:
    """Generate the add_one forward declaration using the CUDA spec.

    Registered under "cuda.add_one.func_decl".
    """
    spec = CUDASpec()
    return gen_function_decl(func_attrs, spec)
@registry.reg("cuda.add_one.func_call")
def cuda_add_one_gen_function_call(func_attrs: Dict[str, Any], indent=" ") -> str:
    """Generate the add_one call-site code with ``is_cuda=True``.

    Registered under "cuda.add_one.func_call".
    """
    call_src = gen_function_call(func_attrs, indent, is_cuda=True)
    return call_src
5.2 (Optional) Register the codegen function to ROCm backend
ROCm backend functions are usually defined at aitemplate/backend/rocm/.
# Include lines passed as header_files when generating the ROCm version of
# the add_one function.
HIP_HEADER_FILES = """
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
"""
@registry.reg("rocm.add_one.gen_function")
def rocm_add_one_gen_function(func_attrs: Dict[str, Any]) -> str:
    """Generate the add_one function source using the HIP headers and ROCm spec.

    Registered under "rocm.add_one.gen_function".
    """
    headers = HIP_HEADER_FILES
    spec = ROCMSpec()
    return gen_function(func_attrs, headers, spec)
@registry.reg("rocm.add_one.func_decl")
def rocm_add_one_gen_function_decl(func_attrs: Dict[str, Any]) -> str:
    """Generate the add_one forward declaration using the ROCm spec.

    Registered under "rocm.add_one.func_decl".
    """
    spec = ROCMSpec()
    return gen_function_decl(func_attrs, spec)
@registry.reg("rocm.add_one.func_call")
def rocm_add_one_gen_function_call(func_attrs: Dict[str, Any], indent=" ") -> str:
    """Generate the add_one call-site code with ``is_cuda=False``.

    Registered under "rocm.add_one.func_call".
    """
    call_src = gen_function_call(func_attrs, indent, is_cuda=False)
    return call_src
6. Compile and verify the results with PyTorch
def create_ait_model(shapes):
    """Build a one-op AIT graph: float16 input "X" -> add_one -> output "Y".

    Args:
        shapes: shape passed to the input Tensor.

    Returns:
        The output Tensor, flagged as a graph output and named "Y".
    """
    graph_input = Tensor(shape=shapes, dtype="float16", name="X", is_input=True)
    op = add_one()
    graph_output = op(graph_input)
    # Mark the result as a graph output and give it the name used to fetch it.
    graph_output._attrs["is_output"] = True
    graph_output._attrs["name"] = "Y"
    return graph_output
def verify_add_one():
    """Compile the add_one model and compare its output with PyTorch.

    A random float16 tensor is run through both ``x + 1.0`` in PyTorch and the
    compiled AIT module. The comparison result is printed, as before, and is
    now also returned so callers can assert on it.

    Returns:
        True when the AIT output matches PyTorch within atol=rtol=1e-2.
    """
    shapes = [16, 512]
    x = torch.randn(shapes).cuda().half()
    y_pt = x + 1.0

    # Build the graph from the same shape as the PyTorch tensors rather than a
    # second hard-coded literal, so the two cannot drift apart. A copy is
    # passed so graph construction cannot alter the list reused below.
    Y = create_ait_model(list(shapes))
    target = detect_target()
    with compile_model(Y, target, "./tmp", "add_one") as module:
        y = torch.empty(shapes).cuda().half()
        inputs = {"X": x}
        outputs = {"Y": y}
        module.run_with_tensors(inputs, outputs)
        is_close = torch.allclose(y, y_pt, atol=1e-2, rtol=1e-2)
        print(is_close)
    return is_close