# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
ROCm target specialization.
"""
# pylint: disable=W0702,W0707,W0611,C0415
import json
import logging
import os
import re
import shutil
import sys
from typing import List
from aitemplate.backend import registry
from aitemplate.backend.target import (
AIT_STATIC_FILES_PATH,
COMPOSABLE_KERNEL_PATH,
Target,
)
from aitemplate.utils import environ
from aitemplate.utils.misc import is_linux
# pylint: disable=W0613
_LOGGER = logging.getLogger(__name__)
class ROCM(Target):
    """ROCM target.

    Specializes ``Target`` for ROCm: it assembles the ``hipcc`` command line
    (flags, include paths, rocrand linkage) and, while used as a context
    manager, makes the composable kernel (CK) Python library importable and
    loads the CK operators for the selected architecture.

    Parameters
    ----------
    Target : Target
        All attributes needed for ROCM.
    """

    def __init__(
        self,
        template_path=COMPOSABLE_KERNEL_PATH,
        arch="GFX908",
        ait_static_files_path=AIT_STATIC_FILES_PATH,
        **kwargs,
    ):
        """Initialize ROCM target.

        Parameters
        ----------
        template_path : str, optional
            Path to composable kernel library, by default
            "${repo_root}/3rdparty/composable_kernel".
        arch : str, optional
            Supported ROCM architecture, by default "GFX908".
        ait_static_files_path : str
            Absolute path to the AIT static/ directory
        **kwargs
            Extra options; stored on the instance as ``_kwargs``.

        Raises
        ------
        RuntimeError
            If ``arch`` is not one of the supported architectures (raised
            while the compile options are built).
        """
        super().__init__(ait_static_files_path)
        # Numeric target-type tag; its meaning is defined by Target
        # (not visible in this file).
        self._target_type = 2
        self._template_path = template_path
        self._arch = arch
        self._kwargs = kwargs
        # Built last: _build_compile_options reads the attributes set above.
        self._compile_options = self._build_compile_options()

    def _pkg_path(self):
        """Return the ROCm installation root.

        Returns
        -------
        str
            Value of the ``ROCM_PATH`` environment variable, or "/opt/rocm"
            when it is not set.
        """
        return os.environ.get("ROCM_PATH", "/opt/rocm")

    def _get_ck_paths(self) -> List[str]:
        """Return the composable kernel include directories.

        Returns
        -------
        List[str]
            The CK root followed by its include sub-directories, all derived
            from ``self._template_path``.
        """
        root = self._template_path
        sub_dirs = (
            "include/",
            "external/include/half/",
            "library/include/",
            "profiler/include/",
        )
        return [os.path.join(root)] + [os.path.join(root, sub) for sub in sub_dirs]

    def get_include_directories(self) -> List[str]:
        """Return the include directories needed to compile against CK."""
        return self._get_ck_paths()

    def _build_compile_options(self):
        """Build compilation options, including flags, libraries and includes.

        Returns
        -------
        str
            All compilation options joined by single spaces.

        Raises
        ------
        RuntimeError
            Unsupported GPU Arch.
        """
        ck_paths = self._get_ck_paths()
        pkg_path = self._pkg_path()
        options = [
            environ.get_compiler_opt_level(),
            "-fPIC",
            "-fvisibility=hidden",
            "-std=c++17",
            "-w",
            "-DCK_TIME_KERNEL=0",
            "-Xclang -mlink-builtin-bitcode -Xclang {}/amdgcn/bitcode/oclc_abi_version_400.bc".format(
                pkg_path
            ),
        ]
        # Only these exact spellings are accepted for each architecture.
        if self._arch in {"GFX908", "gfx908"}:
            options.append("-DCK_AMD_GPU_GFX908")
            options.append("--offload-arch=gfx908")
        elif self._arch in {"GFX90a", "gfx90a"}:
            options.append("-DCK_AMD_GPU_GFX90A")
            options.append("--offload-arch=gfx90a")
        else:
            raise RuntimeError("Unsupported GPU Arch")
        options.extend("-I" + path for path in ck_paths)
        options.append("-I" + os.path.join(self.static_files_path, "include"))
        options.append("-L" + os.path.join(pkg_path, "rocrand/lib/"))
        options.append("-lrocrand")
        if self._ndebug == 1:
            options.append("-DNDEBUG")
        return " ".join(options)

    def _gen_ck_lib_pkg(self):
        """Make the composable kernel python library importable.

        If ``ck_lib`` cannot be imported, it is generated through the
        ``rocm.make_ck_lib`` registry function, its destination directory is
        inserted into ``sys.path``, and that directory is remembered in
        ``self.lib_folder`` so ``__exit__`` can remove it. When ``ck_lib`` is
        already importable, ``self.lib_folder`` stays ``None``.

        Raises
        ------
        RuntimeError
            Failed to create ck library.
        """
        self.lib_folder = None
        try:
            import ck_lib  # noqa: F401
        except Exception:
            # `Exception` rather than `BaseException`: KeyboardInterrupt and
            # SystemExit must propagate instead of triggering a library build
            # or being rewrapped as RuntimeError.
            try:
                cur_path = os.path.dirname(os.path.realpath(__file__))
                ck_lib_path = os.path.normpath(
                    os.path.join(cur_path, "..", "..", "utils", "mk_ck_lib")
                )
                f_make_lib = registry.get("rocm.make_ck_lib")
                dst_path = f_make_lib(ck_lib_path)
                sys.path.insert(1, dst_path)
            except Exception as err:
                raise RuntimeError("Failed to create ck library") from err
            self.lib_folder = dst_path

    def __enter__(self):
        """Generate the ck library and generate ck operations."""
        super().__enter__()
        # Generate library.
        self._gen_ck_lib_pkg()
        # Choose the right ops to launch.
        f_gen_ops = registry.get("rocm.gen_ck_ops")
        self._operators = f_gen_ops(self._arch)

    def __exit__(self, ptype, value, trace):
        """Delete the ck library."""
        super().__exit__(ptype, value, trace)
        # `lib_folder` is only assigned by _gen_ck_lib_pkg; tolerate an exit
        # that happens without a completed enter.
        lib_folder = getattr(self, "lib_folder", None)
        if lib_folder and os.path.exists(lib_folder):
            shutil.rmtree(lib_folder)

    def cc(self):
        """Return the compiler executable name."""
        return "hipcc"

    def compile_cmd(self, executable=False):
        """Compile commands.

        Parameters
        ----------
        executable : bool, optional
            Flag of whether to generate executable or obj, by default False.

        Returns
        -------
        str
            Full command template for compilation; the literal ``{target}``
            and ``{src}`` placeholders are left for the caller to fill in.
        """
        if executable:
            tail = "-o {target} {src}"
        else:
            tail = "-x hip -c -o {target} {src}"
        return " ".join([self.cc(), self._compile_options, tail])

    def src_extension(self):
        """Return the file extension used for generated sources."""
        return ".cpp"

    def dev_select_flag(self):
        """Return the name of the environment variable used to select devices."""
        return "HIP_VISIBLE_DEVICES"

    def select_minimal_algo(self, algo_names: List[str]):
        """Return the algorithm name with the smallest sort key.

        Parameters
        ----------
        algo_names : List[str]
            Candidate CK op names.

        Returns
        -------
        str
            The element of ``algo_names`` whose key tuple is minimal.

        Raises
        ------
        RuntimeError
            If a name does not contain exactly one ``_<digits>`` group, or
            contains neither "Gemm" nor "Conv".
        """

        def comp_func(name):
            compute_args = re.findall(r"_(\d+)_*", name)
            if len(compute_args) != 1:
                raise RuntimeError("Invalid ck op name")
            # Iterating the matched string yields one int per *digit*
            # ("256" -> [2, 5, 6]).
            # NOTE(review): presumably the whole number was intended --
            # confirm before changing, since it affects which algo is picked.
            args = [int(x) for x in compute_args[0]]
            if "Gemm" in name:
                if "GemmPadding" in name:
                    args.insert(0, 0)
                if "GemmDefault" in name:
                    args.insert(0, 1)
            elif "Conv" in name:
                args.insert(0, 0 if "ConvFwdDefault" in name else 1)
            else:
                raise RuntimeError("Unknown CK ops.")
            return tuple(args)

        return min(algo_names, key=comp_func)
class FBROCM(ROCM):
    """ROCM target for the FB environment.

    Same as ``ROCM`` except that the compiler binary, linker tools and base
    compiler arguments are read from a JSON file located via ``parutil``,
    and the composable kernel path is redirected to ``fb/3rdparty``.

    Parameters
    ----------
    Target : Target
        All attributes needed for ROCM.
    """

    def __init__(
        self,
        template_path=COMPOSABLE_KERNEL_PATH,
        arch="GFX90a",
        ait_static_files_path=AIT_STATIC_FILES_PATH,
        **kwargs,
    ):
        """Initialize ROCM target.

        Parameters
        ----------
        template_path : str, optional
            Path to composable kernel library, by default
            "${repo_root}/3rdparty/composable_kernel". Every "3rdparty" in it
            is replaced by "fb/3rdparty".
        arch : str, optional
            Supported ROCM architecture, by default "GFX90a".
        ait_static_files_path : str
            Absolute path to the AIT static/ directory
        **kwargs
            Extra options forwarded to ``ROCM.__init__``.
        """
        from libfb.py import parutil

        self._template_path = template_path.replace("3rdparty", "fb/3rdparty")
        convert_hippcc_json = parutil.get_file_path(
            os.path.join("aitemplate/testing", "convert_hipcc_cmd")
        )
        _LOGGER.info("Load the hipcc compile option from %s", convert_hippcc_json)
        with open(convert_hippcc_json, "r", encoding="utf-8") as hipcc_options_json:
            self.hipcc_options_json = json.load(hipcc_options_json)
        # The JSON must be loaded before calling the parent initializer:
        # ROCM.__init__ invokes _build_compile_options(), which reads
        # self.hipcc_options_json.
        super().__init__(
            template_path=self._template_path,
            arch=arch,
            # Forward the caller's value; previously it was accepted but
            # dropped, so the parent always used its default.
            ait_static_files_path=ait_static_files_path,
            **kwargs,
        )

    def _build_compile_options(self):
        """Build compilation options, including flags, libraries and includes.

        Starts from the "args" list of the loaded hipcc JSON and appends the
        fixed flags, the architecture flags, the CK include directories and
        the rocrand link flag.

        Returns
        -------
        str
            All compilation options joined by single spaces.

        Raises
        ------
        RuntimeError
            Unsupported GPU Arch.
        """
        ck_paths = self._get_ck_paths()
        options = self.hipcc_options_json["args"] + [
            environ.get_compiler_opt_level(),
            "-fPIC",
            "-fvisibility=hidden",
            "-std=c++17",
            "-w",
            "-DCK_TIME_KERNEL=0",
            "--hip-version=5.2.0",
        ]
        if self._arch in {"GFX908", "gfx908"}:
            options.append("-DCK_AMD_GPU_GFX908")
            options.append("--cuda-gpu-arch=gfx908")
        elif self._arch in {"GFX90a", "gfx90a"}:
            options.append("-DCK_AMD_GPU_GFX90A")
            options.append("--cuda-gpu-arch=gfx90a")
        else:
            raise RuntimeError("Unsupported GPU Arch")
        # Include directories are added once; emitting the same -I flags a
        # second time adds nothing to the search path.
        options.extend("-I" + path for path in ck_paths)
        options.append("-lrocrand")
        return " ".join(options)

    def binary_compile_cmd(self):
        """Return the command template that embeds a binary file in an object.

        There is no ld by default in the prod env. Instead, we use ld from the gvfs path.

        Returns
        -------
        str
            Command template built from the "ld" (and, on Linux, "objcopy")
            entries of the hipcc JSON; the literal ``{target}`` and ``{src}``
            placeholders are left for the caller to fill in.
        """
        ld = self.hipcc_options_json["ld"]
        objcopy = self.hipcc_options_json["objcopy"]
        cmd = " ".join([ld, "-r -b binary -o {target} {src}"])
        # Support models with >2GB constants on Linux only
        if is_linux():
            cmd += (
                f" && {objcopy} --rename-section"
                " .data=.lrodata,alloc,load,readonly,data,contents"
                " {target} {target}"
            )
        return cmd

    def cc(self):
        """Return the hipcc binary path taken from the hipcc JSON."""
        return self.hipcc_options_json["hipcc_bin"]

    def compile_options(self):
        """Return the compile options string built at initialization."""
        return self._compile_options
@registry.reg("fb.rocm.create_target")
def create_target_fb(arch, **kwargs):
    """Create an ``FBROCM`` target; registered as "fb.rocm.create_target".

    Parameters
    ----------
    arch : str
        ROCM architecture passed to ``FBROCM``.
    **kwargs
        Extra keyword arguments forwarded to ``FBROCM``.

    Returns
    -------
    FBROCM
        The newly constructed target.
    """
    target = FBROCM(arch=arch, **kwargs)
    return target
@registry.reg("rocm.create_target")
def create_target(template_path, arch, **kwargs):
    """Create a ``ROCM`` target; registered as "rocm.create_target".

    Parameters
    ----------
    template_path : str
        Path to the composable kernel library passed to ``ROCM``.
    arch : str
        ROCM architecture passed to ``ROCM``.
    **kwargs
        Extra keyword arguments forwarded to ``ROCM``.

    Returns
    -------
    ROCM
        The newly constructed target.
    """
    target = ROCM(template_path=template_path, arch=arch, **kwargs)
    return target