# Source code for aitemplate.backend.target

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
"""
Target object for AITemplate.
"""
import logging
import os
import pathlib
import shutil
import tempfile
from enum import IntEnum
from typing import Any, Dict, List, Optional, Tuple

from aitemplate.backend import registry
from aitemplate.backend.profiler_cache import ProfileCacheDB
from aitemplate.utils.misc import is_linux


_LOGGER = logging.getLogger(__name__)

# Directory that contains this module; every path below is derived from it.
_MYPATH = os.path.dirname(os.path.realpath(__file__))

# Candidate location of 3rdparty/ one level above this module (named after the
# wheel layout). It takes precedence whenever it exists on disk; otherwise the
# location three levels up is used.
_WHEEL_3RDPARTY_PATH = os.path.normpath(os.path.join(_MYPATH, "..", "3rdparty"))
_3RDPARTY_PATH = (
    _WHEEL_3RDPARTY_PATH
    if os.path.exists(_WHEEL_3RDPARTY_PATH)
    else os.path.normpath(os.path.join(_MYPATH, "..", "..", "..", "3rdparty"))
)

# static/ is a sibling of 3rdparty/; note that this path is not normalized.
AIT_STATIC_FILES_PATH = os.path.join(_3RDPARTY_PATH, "../static")

# Third-party library directories inside 3rdparty/.
CUTLASS_PATH = os.path.join(_3RDPARTY_PATH, "cutlass")
COMPOSABLE_KERNEL_PATH = os.path.join(_3RDPARTY_PATH, "composable_kernel")
CUB_PATH = os.path.join(_3RDPARTY_PATH, "cub")

# The Target currently entered as a context manager, or None when none is.
CURRENT_TARGET = None


class TargetType(IntEnum):
    """Integer identifiers for the supported target types.

    The member name (``cuda`` / ``rocm``) is what ``TargetType(value).name``
    yields for a given integer value.
    """

    # CUDA target.
    cuda = 1
    # ROCm target.
    rocm = 2
class Target:
    """Base class describing a compilation target.

    A target is used as a context manager: entering it publishes the instance
    through the module-level ``CURRENT_TARGET`` (readable via
    :meth:`Target.current`) and loads the profile cache; exiting clears both.
    Compiler-specific methods raise ``NotImplementedError`` here and are meant
    to be overridden by subclasses.
    """

    def __init__(self, static_files_path: str):
        """
        Parameters
        ----------
        static_files_path : str
            Absolute path to the AIT static/ directory
        """
        # -1 marks "no concrete target type"; __enter__ requires a value > 0.
        self._target_type = -1
        self._template_path = ""
        self._compile_cmd = ""
        self._cache_path = ""
        self._profile_cache = None
        self.static_files_path = static_files_path
        # AIT_NDEBUG defaults to "1"; a value that is not an integer yields 0.
        ndebug_str = os.getenv("AIT_NDEBUG", "1")
        try:
            self._ndebug = int(ndebug_str)
        except ValueError:
            self._ndebug = 0

    def __enter__(self):
        """Enter the target context manager.

        This will set CURRENT_TARGET to this target and load its profile
        cache.

        Returns
        -------
        Target
            This target, so that ``with target as t`` binds the target.

        Raises
        ------
        RuntimeError
            If CURRENT_TARGET is already set, this will raise a RuntimeError.
        """
        global CURRENT_TARGET
        # Validate before touching the profile cache so that a rejected enter
        # does not open or replace a cache as a side effect.
        if CURRENT_TARGET is not None:
            raise RuntimeError("Target has been set.")
        assert self._target_type > 0
        self._load_profile_cache()
        CURRENT_TARGET = self
        return self

    def __exit__(self, ptype, value, trace):
        """Exit the target context manager.

        Drops the loaded profile cache and resets CURRENT_TARGET to None.
        Nothing is returned, so an exception raised inside the ``with`` body
        is not suppressed.
        """
        global CURRENT_TARGET
        self._profile_cache = None
        CURRENT_TARGET = None

    @staticmethod
    def current():
        """Obtain the current target.

        Returns
        -------
        Target
            return the current target object.

        Raises
        ------
        RuntimeError
            If no target is set, this will raise a RuntimeError.
        """
        if CURRENT_TARGET is None:
            raise RuntimeError("Target is not set yet.")
        return CURRENT_TARGET

    def template_path(self) -> str:
        """Return CUTLASS/CK path for this target.

        Returns
        -------
        str
            Absolute path to the CUTLASS/CK template directory.
        """
        return self._template_path

    def get_custom_libs(self, absolute_dir, filename) -> str:
        """Read a file and return its content.

        Parameters
        ----------
        absolute_dir : str
            Directory that contains the file.
        filename : str
            Name of the file inside ``absolute_dir``.

        Returns
        -------
        str
            The whole text content of the file.
        """
        filename = os.path.join(absolute_dir, filename)
        with open(filename) as f:
            return f.read()

    def name(self) -> str:
        """Return the name of the target.

        Returns
        -------
        str
            The name of the target.
        """
        return TargetType(self._target_type).name

    def cc(self):
        """Compiler for this target.

        Raises
        ------
        NotImplementedError
            Need to be implemented by subclass.
        """
        raise NotImplementedError

    def make(self):
        """Return the path of ``make`` found on PATH, or ``"make"`` if none."""
        make_path = shutil.which("make")
        return make_path if make_path is not None else "make"

    def cmake(self):
        """Return the path of ``cmake`` found on PATH, or ``"cmake"`` if none."""
        cmake_path = shutil.which("cmake")
        return cmake_path if cmake_path is not None else "cmake"

    def compile_cmd(self, executable: bool = False):
        """Compile command string template for this target.

        Parameters
        ----------
        executable : bool, optional
            Whether the command will compile an executable object
            by default False

        Raises
        ------
        NotImplementedError
            Need to be implemented by subclass.
        """
        raise NotImplementedError

    def binary_compile_cmd(self):
        """
        A command that turns a raw binary file into an object file that
        can be linked into the executable.

        Returns
        -------
        str
            Command template with ``{target}`` and ``{src}`` placeholders.
        """
        cmd = "ld -r -b binary -o {target} {src}"
        # Support models with >2GB constants on Linux only
        if is_linux():
            cmd += (
                " && objcopy --rename-section"
                " .data=.lrodata,alloc,load,readonly,data,contents"
                " {target} {target}"
            )
        return cmd

    def compile_options(self) -> str:
        """Options for compiling the target.

        Returns
        -------
        str
            Empty string in the base class.
        """
        return ""

    def src_extension(self):
        """Source file extension for this target.

        Returns
        -------
        str
            Source file extension for this target.

        Raises
        ------
        NotImplementedError
            Need to be implemented by subclass.
        """
        raise NotImplementedError

    def dev_select_flag(self):
        """Environment variable to select the device.

        Returns
        -------
        str
            Environment variable to select the device.

        Raises
        ------
        NotImplementedError
            Need to be implemented by subclass.
        """
        raise NotImplementedError

    def apply_op_rules(self, op_def):
        """Apply special rules to change template op definition

        The base class applies no rule and returns the input unchanged.

        Parameters
        ----------
        op_def : str
            Operator definition code string

        Returns
        -------
        str
            Modified op definition code string
        """
        return op_def

    def select_minimal_algo(self, algo_names: List[str]):
        """Select the minimal algorithm from the list of algorithms.

        This is used in CI to speed up the test without running
        actually profiling.

        Parameters
        ----------
        algo_names : List[str]
            All the available algorithm names for selection.

        Raises
        ------
        NotImplementedError
            Need to be implemented by subclass.
        """
        raise NotImplementedError

    def trick_ci_env(self) -> bool:
        """Check if we want to trick in_ci_env to make it False.

        This is used in workers where we do not have control of CI_FLAG

        Returns
        -------
        bool
            Whether to trick ci env (env TRICK_CI_ENV=1).
        """
        return os.environ.get("TRICK_CI_ENV", None) == "1"

    def in_ci_env(self) -> bool:
        """Check if the current environment is CI.

        Returns
        -------
        bool
            Returns True if env CI_FLAG=CIRCLECI and TRICK_CI_ENV is not 1.
        """
        return os.environ.get("CI_FLAG", None) == "CIRCLECI" and not self.trick_ci_env()

    def disable_profiler_codegen(self) -> bool:
        """Whether to disable profiler codegen.

        disable profiler codegen completely in CI to speed up long running
        unittest

        Returns
        -------
        bool
            True if env DISABLE_PROFILER_CODEGEN=1 and profiling is not
            forced through FORCE_PROFILE.
        """
        return (
            os.environ.get("DISABLE_PROFILER_CODEGEN", None) == "1"
            and not self.force_profile()
        )

    def force_profile(self) -> bool:
        """Whether to force profile.

        Force profiling regardless in_ci_env, disable_profiler_codegen

        Returns
        -------
        bool
            Whether to force profile (env FORCE_PROFILE=1).
        """
        return os.environ.get("FORCE_PROFILE", None) == "1"

    def use_dummy_profiling_results(self) -> bool:
        """Whether to use dummy profiling results to speed up runs.

        Returns
        -------
        bool
            True in a CI environment unless profiling is forced.
        """
        return self.in_ci_env() and not self.force_profile()

    def _get_cache_file_name(self) -> str:
        """Get the cache file name for this target.

        Returns
        -------
        str
            The cache file name for this target.
        """
        # TODO: Add device name
        cache_file = "{dev_type}.db".format(dev_type=TargetType(self._target_type).name)
        return cache_file

    def _prepare_profile_cache_path(self) -> Optional[str]:
        """Prepare local profile cache for this target.

        The cache directory is env CACHE_DIR when set and non-empty,
        otherwise ``~/.aitemplate``. If the directory cannot be created, a
        fresh temporary directory is used instead. When env
        FLUSH_PROFILE_CACHE is anything other than "0", an existing cache
        file is deleted.

        Returns
        -------
        Optional[str]
            Path to the cache file, or None when dummy profiling results
            are used.
        """
        if self.use_dummy_profiling_results():
            _LOGGER.info("Escape loading profile cache when using dummy profiling")
            return None

        cache_file = self._get_cache_file_name()
        # An unset or empty CACHE_DIR falls back to the home directory.
        prefix = os.environ.get("CACHE_DIR", None) or os.path.join(
            pathlib.Path.home(), ".aitemplate"
        )
        try:
            os.makedirs(prefix, exist_ok=True)
        except OSError as error:
            _LOGGER.info(f"Cannot mkdir at {prefix} due to issue {error}")
            prefix = os.path.join(tempfile.mkdtemp(prefix="aitemplate_"), ".aitemplate")
            os.makedirs(prefix, exist_ok=True)
            _LOGGER.info(f"mkdir at {prefix} instead")

        cache_path = os.path.join(prefix, cache_file)
        flush_flag = os.environ.get("FLUSH_PROFILE_CACHE", "0")
        if flush_flag != "0":
            # Flushing a cache that was never written is not an error.
            try:
                os.remove(cache_path)
            except FileNotFoundError:
                pass
        return cache_path

    def _load_profile_cache(self):
        """Load local profile cache for this target.

        Leaves the cache unloaded when no cache path is available (dummy
        profiling results).
        """
        self._cache_path = self._prepare_profile_cache_path()
        if self._cache_path is None:
            return
        _LOGGER.info(f"Loading profile cache from: {self._cache_path}")
        self._profile_cache = ProfileCacheDB(
            TargetType(self._target_type).name, path=self._cache_path
        )

    def get_profile_cache_path(self):
        """Get local profile cache path for this target."""
        return self._cache_path

    def get_profile_cache_version(self, op_class: str) -> int:
        """Get the current profile cache version for the op_class.

        Parameters
        ----------
        op_class : str
            Op class name: gemm, conv or conv3d.

        Returns
        -------
        int
            cache version.

        Raises
        ------
        NotImplementedError
            If op class is not supported, raise error.
        """
        # TODO: support normalization
        if op_class == "gemm":
            return self._profile_cache.gemm_cache_version
        if op_class == "conv":
            return self._profile_cache.conv_cache_version
        if op_class == "conv3d":
            return self._profile_cache.conv3d_cache_version
        raise NotImplementedError

    def query_profile_cache(
        self, op_class: str, args: Dict[str, Any]
    ) -> Tuple[str, int]:
        """Query the profile cache for the given op class and args.

        Parameters
        ----------
        op_class : str
            Op class name. gemm, conv, conv3d or normalization
        args : Dict[str, Any]
            Op arguments.

        Returns
        -------
        Tuple[str, int]
            Queried best profiling results.

        Raises
        ------
        NotImplementedError
            If op class is not supported, raise error.
        """
        if op_class == "gemm":
            return self._profile_cache.query_gemm(args)
        if op_class == "conv":
            return self._profile_cache.query_conv(args)
        if op_class == "conv3d":
            return self._profile_cache.query_conv3d(args)
        if op_class == "normalization":
            return self._profile_cache.query_normalization(args)
        raise NotImplementedError

    def insert_profile_cache(self, op_class: str, args: Dict[str, Any]):
        """Insert the profile cache for the given op class and args.

        Parameters
        ----------
        op_class : str
            Op class name. gemm, conv, conv3d or normalization
        args : Dict[str, Any]
            Record to insert.

        Raises
        ------
        NotImplementedError
            If op class is not supported, raise error.
        """
        if op_class == "gemm":
            self._profile_cache.insert_gemm(args)
        elif op_class == "conv":
            self._profile_cache.insert_conv(args)
        elif op_class == "conv3d":
            self._profile_cache.insert_conv3d(args)
        elif op_class == "normalization":
            self._profile_cache.insert_normalization(args)
        else:
            raise NotImplementedError

    def copy_headers_and_csrc_to_workdir(self, workdir: str) -> List[str]:
        """
        Copy over all the files in include/ and csrc/ to some working directory.
        Skips files that are not marked with .cpp/.h

        Copied .cpp files are renamed to use this target's source extension;
        headers keep their name.

        Parameters
        ----------
        workdir : str
            The path to copy to

        Returns
        -------
        List[str]
            The copied source files (to be built later). Headers are copied
            but not listed.
        """
        sources = []
        csrc = os.path.join(self.static_files_path, "csrc")
        for fname in os.listdir(csrc):
            stem, ext = os.path.splitext(fname)
            if ext != ".cpp":
                continue
            # TODO: Remove this file when the linker error gets fixed in rocm backend.
            # All files in csrc should be shared between the ROCM and CUDA backends.
            if fname == "rocm_hack.cpp" and self.name() != "rocm":
                continue
            fname_dst_cpp = os.path.join(workdir, f"{stem}{self.src_extension()}")
            shutil.copyfile(os.path.join(csrc, fname), fname_dst_cpp)
            sources.append(fname_dst_cpp)

        include = os.path.join(self.static_files_path, "include")
        for fname in os.listdir(include):
            if os.path.splitext(fname)[1] != ".h":
                continue
            shutil.copyfile(os.path.join(include, fname), os.path.join(workdir, fname))
        return sources

    @classmethod
    def remote_logger(cls, record: Dict[str, Any]) -> None:
        """
        Upload the record remotely to some logging table.

        The base class does nothing.

        Parameters
        ----------
        record : Dict[str, Any]
            The dictionary storing the record
        """
        return

    def get_include_directories(self) -> List[str]:
        """
        Returns a list of include directories for a compiler.

        Raises
        ------
        NotImplementedError
            Need to be implemented by subclass.
        """
        raise NotImplementedError

    def get_host_compiler_options(self) -> List[str]:
        """
        Returns a list of options for the host compiler.

        Raises
        ------
        NotImplementedError
            Need to be implemented by subclass.
        """
        raise NotImplementedError

    def get_device_compiler_options(self) -> List[str]:
        """
        Returns a list of options for the device compiler.

        Raises
        ------
        NotImplementedError
            Need to be implemented by subclass.
        """
        raise NotImplementedError

    def postprocess_build_dir(self, build_dir: str) -> None:
        """
        Postprocess a build directory, allows final modification of the build directory before building.

        The base class does nothing.
        """
        pass
def CUDA(template_path: str = CUTLASS_PATH, arch: str = "80", **kwargs):
    """Create a CUDA target.

    Looks up the factory registered under ``"cuda.create_target"`` in the
    backend registry and calls it.

    Parameters
    ----------
    template_path : str, optional
        Template directory passed to the factory, by default CUTLASS_PATH
    arch : str, optional
        Architecture string passed to the factory, by default "80"
    **kwargs
        Extra keyword arguments forwarded to the factory unchanged.

    Returns
    -------
    Whatever the registered factory returns.
    """
    return registry.get("cuda.create_target")(template_path, arch, **kwargs)
def ROCM(template_path: str = COMPOSABLE_KERNEL_PATH, arch: str = "gfx908", **kwargs):
    """Create a ROCM target.

    Looks up the factory registered under ``"rocm.create_target"`` in the
    backend registry and calls it.

    Parameters
    ----------
    template_path : str, optional
        Template directory passed to the factory, by default
        COMPOSABLE_KERNEL_PATH
    arch : str, optional
        Architecture string passed to the factory, by default "gfx908"
    **kwargs
        Extra keyword arguments forwarded to the factory unchanged.

    Returns
    -------
    Whatever the registered factory returns.
    """
    return registry.get("rocm.create_target")(template_path, arch, **kwargs)