How to add an operator to the AIT codegen
=========================================

This tutorial demonstrates how to add a new operator to the AIT codegen.
Full source code can be found at `examples/06_how_to_add_an_op/how_to_add_an_op.py`.

0. Prerequisites
----------------

We need to import the necessary Python modules:

.. code-block:: python

    from typing import Any, Dict, List

    import jinja2
    import torch

    from aitemplate import backend
    from aitemplate.backend import registry
    from aitemplate.backend.backend_spec import CUDASpec, ROCMSpec
    from aitemplate.compiler import compile_model
    from aitemplate.compiler.base import IntVar, Operator, Tensor
    from aitemplate.testing import detect_target

1. Define the operator graph node
---------------------------------

Graph nodes are usually defined at `aitemplate/compiler/ops`.

.. code-block:: python

    class add_one(Operator):
        def __init__(self):
            super().__init__()
            # required: the unique identity of the operator category
            self._attrs["op"] = "add_one"
            # we can put whatever we want into the op attrs for later use
            self._attrs["has_profiler"] = False
            self._attrs["nop"] = False

        def __call__(self, x: Tensor) -> Tensor:
            # each operator needs to keep a record of its input tensors
            self._attrs["inputs"] = [x]
            # optional: set the depth of the op based on the inputs' depth (used in DFS)
            self._set_depth()
            # infer the output shape
            output_shape = self._infer_shape(x)
            # create the output Tensor, whose source op is the current op
            output = Tensor(output_shape, src_ops={self})
            # record the current op's outputs
            self._attrs["outputs"] = [output]
            return output

        def _infer_shape(self, x) -> List[IntVar]:
            # Infer the output shape.
            # If we needed shape inference on the C++ side, we would create a
            # jinja2 template for the shape-inference function, render it into
            # Python code in the graph node, and render it into C++ code in codegen.
            return x.shape()

        def gen_function(self) -> str:
            # This function is used in codegen;
            # here we only need to redirect to the backend codegen function.
            target = backend.target.Target.current()
            func_key = f"{target.name()}.{self._attrs['op']}.gen_function"
            func = registry.get(func_key)
            return func(self._attrs)

.. note::
   - `_attrs` in `Operator` is the most important data structure for codegen.
   - `_attrs["op"]` is the identity of the operator category, used to find the
     corresponding codegen function in the backend; it must be **unique**.

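Before moving on to codegen, it can help to see the node in action. The sketch
below (assuming the imports from step 0 and the class above; `src_ops()` and
`value()` are accessors on AIT's `Tensor` and `IntImm`) builds a one-op graph
and checks the bookkeeping that `__call__` performs:

.. code-block:: python

    # A minimal sketch: build a one-op graph with the node defined above.
    x = Tensor(shape=[16, 512], dtype="float16", name="X", is_input=True)
    op = add_one()
    y = op(x)

    assert op._attrs["inputs"] == [x]  # the op remembers its input...
    assert op in y.src_ops()           # ...and the output remembers its producer
    # add_one is elementwise, so the output shape equals the input shape
    assert [d.value() for d in y.shape()] == [16, 512]
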
2. Define the necessary templates for codegen
---------------------------------------------

In AIT, there are four important templates for codegen:

- `FUNC_TEMPLATE`: the template that generates the function body of the operator and invokes the GPU kernel inside that body.
- `FUNC_SIGNATURE_TEMPLATE`: the template that generates the function signature of the operator. The signature defines the name and the arguments of the function.
- `FUNC_CALL_TEMPLATE`: the template that generates the function call of the operator. The call is used during inference to invoke the GPU kernel with the given arguments.
- `FUNC_DECL`: the template for the forward declaration of the operator function. This is usually an alias of `FUNC_SIGNATURE_TEMPLATE`.

.. code-block:: python

    FUNC_TEMPLATE = jinja2.Template(
        """
    {{header_files}}

    namespace {

    {{kernel}}

    }  // namespace

    {{func_signature}}
    {
        invoke_add_one(output, input, num_elements, stream);
    }
        """
    )

    FUNC_SIGNATURE = jinja2.Template(
        """
    void {{func_name}}(half* output,
                       const half* input,
                       const int64_t num_elements,
                       {{prefix}}Stream_t stream)
        """
    )

    FUNC_DECL = jinja2.Template(
        """
    {{func_signature}};
        """
    )

    FUNC_CALL_TEMPLATE = jinja2.Template(
        """
    {{indent}}int64_t num_elements = 1;
    {% for dim_name in dim_names %}
    {{indent}}num_elements *= {{dim_name}};
    {% endfor %}
    {{indent}}{{func_name}}(
    {{indent}}    {{output}}, {{input}}, num_elements, stream /* default stream */
    {{indent}});
        """
    )

3. Create the GPU kernels
-------------------------

In this example we use the simplest possible add-one kernel. The kernel can be
written by hand (as a programmer would usually do) or generated by other tools.

.. code-block:: python

    KERNEL_TEMPLATE = jinja2.Template(
        """
    __global__ void add_one(half* output, const half* input, const int64_t num_elements) {
        const int idx = blockIdx.x * blockDim.x + threadIdx.x;
        if (idx < num_elements) {
            output[idx] = input[idx] + half(1.0);
        }
    }

    void invoke_add_one(half* output, const half* input, int64_t num_elements, {{prefix}}Stream_t stream) {
        if (num_elements < 1024) {
            dim3 grid(1);
            dim3 block(num_elements);
            add_one<<<grid, block, 0, stream>>>(output, input, num_elements);
        } else {
            dim3 grid((num_elements + 1024 - 1) / 1024);
            dim3 block(1024);
            add_one<<<grid, block, 0, stream>>>(output, input, num_elements);
        }
    }
        """
    )

(Optional) We also provide a helper template to handle the float16 data type
difference between CUDA and ROCm.

.. code-block:: python

    FUNC_CALL_FP16_PARAM_TEMPLATE = jinja2.Template(
        """reinterpret_cast<half*>({% if is_cuda %}&({% endif %}{{name}}{% if is_cuda %}->raw()){% endif %})"""
    )

4. Define the codegen function
------------------------------

The codegen function renders the templates defined above into a valid C++
source string. It takes `func_attrs` from the graph node and fills in the
jinja2 templates.

.. code-block:: python

    def gen_function_call(func_attrs: Dict[str, Any], indent="  ", is_cuda=False) -> str:
        assert len(func_attrs["outputs"]) == 1
        assert len(func_attrs["inputs"]) == 1

        output_name = FUNC_CALL_FP16_PARAM_TEMPLATE.render(
            name=func_attrs["outputs"][0]._attrs["name"], is_cuda=is_cuda
        )
        input_name = FUNC_CALL_FP16_PARAM_TEMPLATE.render(
            name=func_attrs["inputs"][0]._attrs["name"], is_cuda=is_cuda
        )
        dim_names = [dim._attrs["name"] for dim in func_attrs["inputs"][0].shape()]

        return FUNC_CALL_TEMPLATE.render(
            func_name=func_attrs["name"],
            output=output_name,
            input=input_name,
            dim_names=dim_names,
            indent=indent,
        )


    def gen_function(func_attrs: Dict[str, Any], header_files: str, backend_spec) -> str:
        prefix = backend_spec.prefix
        return FUNC_TEMPLATE.render(
            header_files=header_files,
            kernel=KERNEL_TEMPLATE.render(prefix=prefix),
            func_signature=FUNC_SIGNATURE.render(
                func_name=func_attrs["name"], prefix=prefix
            ),
        )


    def gen_function_decl(func_attrs: Dict[str, Any], backend_spec) -> str:
        return FUNC_DECL.render(
            func_signature=FUNC_SIGNATURE.render(
                func_name=func_attrs["name"],
                prefix=backend_spec.prefix,
            ).strip()
        )

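It is easy to sanity-check these templates before registering anything: render
one by hand and inspect the C++ it produces. A quick sketch (the names
`add_one_0`, `X_dim_0`, `Y_ptr`, etc. are made up here; the compiler supplies
the real ones):

.. code-block:: python

    print(
        FUNC_CALL_TEMPLATE.render(
            func_name="add_one_0",
            output="Y_ptr",  # in practice, rendered by FUNC_CALL_FP16_PARAM_TEMPLATE
            input="X_ptr",
            dim_names=["X_dim_0", "X_dim_1"],
            indent="  ",
        )
    )
    # prints (roughly):
    #   int64_t num_elements = 1;
    #   num_elements *= X_dim_0;
    #   num_elements *= X_dim_1;
    #   add_one_0(
    #       Y_ptr, X_ptr, num_elements, stream /* default stream */
    #   );
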
5.1 Register the codegen function in the CUDA backend
-----------------------------------------------------

CUDA backend functions are usually defined at `aitemplate/backend/cuda/`.

.. code-block:: python

    CUDA_HEADER_FILES = """
    #include <cuda_fp16.h>
    """


    @registry.reg("cuda.add_one.gen_function")
    def cuda_add_one_gen_function(func_attrs: Dict[str, Any]) -> str:
        return gen_function(func_attrs, CUDA_HEADER_FILES, CUDASpec())


    @registry.reg("cuda.add_one.func_decl")
    def cuda_add_one_gen_function_decl(func_attrs: Dict[str, Any]) -> str:
        return gen_function_decl(func_attrs, CUDASpec())


    @registry.reg("cuda.add_one.func_call")
    def cuda_add_one_gen_function_call(func_attrs: Dict[str, Any], indent="  ") -> str:
        return gen_function_call(func_attrs, indent, is_cuda=True)

5.2 (Optional) Register the codegen function in the ROCm backend
----------------------------------------------------------------

ROCm backend functions are usually defined at `aitemplate/backend/rocm/`.

.. code-block:: python

    HIP_HEADER_FILES = """
    #include <hip/hip_runtime.h>
    #include <hip/hip_fp16.h>
    """


    @registry.reg("rocm.add_one.gen_function")
    def rocm_add_one_gen_function(func_attrs: Dict[str, Any]) -> str:
        return gen_function(func_attrs, HIP_HEADER_FILES, ROCMSpec())


    @registry.reg("rocm.add_one.func_decl")
    def rocm_add_one_gen_function_decl(func_attrs: Dict[str, Any]) -> str:
        return gen_function_decl(func_attrs, ROCMSpec())


    @registry.reg("rocm.add_one.func_call")
    def rocm_add_one_gen_function_call(func_attrs: Dict[str, Any], indent="  ") -> str:
        return gen_function_call(func_attrs, indent, is_cuda=False)

6. Compile and verify the results with PyTorch
----------------------------------------------

.. code-block:: python

    def create_ait_model(shapes):
        X = Tensor(
            shape=shapes,
            dtype="float16",
            name="X",
            is_input=True,
        )
        Y = add_one()(X)
        Y._attrs["is_output"] = True
        Y._attrs["name"] = "Y"
        return Y


    def verify_add_one():
        shapes = [16, 512]
        x = torch.randn(shapes).cuda().half()
        y_pt = x + 1.0

        Y = create_ait_model(shapes)
        target = detect_target()
        with compile_model(Y, target, "./tmp", "add_one") as module:
            y = torch.empty(shapes).cuda().half()
            inputs = {"X": x}
            outputs = {"Y": y}
            module.run_with_tensors(inputs, outputs)
            print(torch.allclose(y, y_pt, atol=1e-2, rtol=1e-2))

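Finally, run the check. The `__main__` guard below is a hypothetical addition
(it is not part of the example code above); the `registry.get` lookup simply
confirms that the CUDA registration is visible before compiling:

.. code-block:: python

    if __name__ == "__main__":
        # Registry keys follow the "<backend>.<op name>.<entry point>" layout;
        # this lookup fails loudly if the registration above did not run.
        registry.get("cuda.add_one.gen_function")
        verify_add_one()  # prints True when AIT matches PyTorch within tolerance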