# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import hashlib
import logging
import os
import random
import re
import secrets
import shlex
import shutil
import tempfile
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
from pathlib import Path
from typing import Callable, List, Optional, Tuple
from aitemplate.backend.target import Target
from aitemplate.utils import environ as aitemplate_env
from aitemplate.utils.io import file_age, touch
_LOGGER = logging.getLogger(__name__)
# File extensions to be considered source files
source_extensions = {
"cpp",
"h",
"cu",
"cuh",
"c",
"hpp",
"hxx",
"inl",
"py",
"cxx",
"cc",
"version",
"binhash",
"hash",
}
source_filenames = {
# needs to be lowercase, because everything is lowercased before comparison
# Filenames in here are considered source files, even if their extension would
# suggest they are cache artifacts
"makefile"
}
source_filename_prefixes = ["makefile"]
# File extensions of files to be considered cache artifacts (unless they are considered source files)
# note: we're not caching .obj files anymore as these are not strictly necessary to keep.
cache_extensions = {"so", "dll", "exe", ""}
skip_cache_flag = False  # Global flag that cache implementations should check to determine whether
# the cache is enabled or not. Used by the SkipBuildCache context manager
class SkipBuildCache:
def __init__(self, context_skip_cache_flag: bool = True):
"""
Context manager to temporarily disable the build cache within an execution context.
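
Example (a minimal sketch; ``do_compile`` stands in for any code path whose
cache implementation consults should_skip_build_cache)::

    with SkipBuildCache():
        do_compile()  # build cache lookups/stores are skipped inside this block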
"""
self.context_skip_cache_flag = context_skip_cache_flag
def __enter__(self):
global skip_cache_flag
self.old_skip_cache_flag = skip_cache_flag
skip_cache_flag = self.context_skip_cache_flag
def __exit__(self, *args, **kwargs):
global skip_cache_flag
skip_cache_flag = self.old_skip_cache_flag
def should_skip_build_cache():
"""
This function should be called by cache implementations to determine whether the cache should be skipped or not
"""
global skip_cache_flag
if skip_cache_flag:
return True
skip_percentage = aitemplate_env.ait_build_cache_skip_percentage()
if skip_percentage is not None:
skip_percentage = int(skip_percentage)
assert (
skip_percentage >= 0 and skip_percentage <= 100
), f"Skip percentage has to be in the range [0,100]. Actual value: {skip_percentage}"
if skip_percentage == 100:
return True
if skip_percentage == 0:
return False
rndi = random.randint(0, 99)
if rndi < skip_percentage:
return True
return False
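# Note on the percentage-based skipping above: the value returned by
# aitemplate_env.ait_build_cache_skip_percentage() is interpreted as an integer in
# [0, 100]; for example, a value of 10 makes roughly 1 in 10 cache lookups be
# skipped at random, while 0 never skips and 100 always skips.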
def filename_norm_split(filename: str) -> Tuple[str, str]:
"""
Splits a filename into basename and extension,
lowercasing both results so they can be looked up
in a case-insensitive manner.
Args:
filename (str): Filename/Path to split
Returns:
Tuple[str,str]: file basename, file extension
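
Example::

    filename_norm_split("/tmp/Model.SO")  # -> ("model.so", "so")
    filename_norm_split("Makefile")       # -> ("makefile", "")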
"""
file_basename = os.path.basename(filename).lower()
file_parts = file_basename.split(".")
if len(file_parts) > 1:
file_ext = file_parts[-1]
else:
file_ext = ""
return file_basename, file_ext
def is_source(filename: str) -> bool:
"""
Simple filter function; returns True if the passed filename is considered
to be a source file for the purpose of caching (source files are used to build the cache key)
Args:
filename (str): File path as a string
Returns:
bool: Whether the filename is a source file
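
Example::

    is_source("gemm_kernel.cu")  # True  ("cu" is a source extension)
    is_source("Makefile")        # True  (matches source_filenames)
    is_source("test.so")         # False (cache artifact extension)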
"""
file_basename, file_ext = filename_norm_split(filename)
return (
(file_basename in source_filenames)
or (file_ext in source_extensions)
or any(file_basename.startswith(p) for p in source_filename_prefixes)
)
def is_cache_artifact(filename: str) -> bool:
"""
Simple filter function; returns True if the passed filename is considered
to be a cacheable artifact for the purpose of caching
(not used to build the cache key, but stored in the cache)
Args:
filename (str): File path as a string
Returns:
bool: Whether the filename is a cache artifact
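
Example::

    is_cache_artifact("test.so")    # True
    is_cache_artifact("gemm.cpp")   # False (source file)
    is_cache_artifact("model.obj")  # False ("obj" is not in cache_extensions)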
"""
file_basename, file_ext = filename_norm_split(filename)
return not is_source(filename) and file_ext in cache_extensions
def is_bin_file(filename: str) -> bool:
"""
Simple filter function; returns True if the passed filename is considered
to be a bin file which needs to be taken into account when creating
a cache key, but may be deleted after an initial build.
bin files are hashed, and their hashes are kept in a small separate file
for future use when building the cache key, so the hash is not lost even if the binary
file is deleted.
Args:
filename (str): File path as a string
Returns:
bool: Whether the filename is a binary file in the above sense
"""
return filename.lower().endswith(".bin")
def create_dir_hash(
cmds: List[str],
build_dir: str,
filter_func: Callable[[str], bool] = is_source,
debug=False,
content_replacer: Optional[Callable[[str], Optional[bytes]]] = None,
) -> str:
"""Create a hash of the (source file) contents of a build directory, used for
creating a cache key of an entire directory along with the build commands.
Args:
cmds (List[str]): Build commands to be incorporated in hash key computation
build_dir (str): Path to build directory (not part of the hash)
filter_func (Callable[[str], bool], optional): Filter function which determines whether a given file is considered a source file or not. Defaults to is_source.
debug (bool, optional): Whether to write a 'cache_key.log' file into the build directory, so that cache misses can be debugged more easily. Defaults to False.
content_replacer (Callable[[str], Optional[bytes]], optional): Optional function that may replace the content of a file for hashing purposes. If None, or if it returns None for a given file,
that file's content is hashed unchanged.
Returns:
str: SHA256 Hash of the build directory contents in the form of a hexdigest string.
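
Example (illustrative; the command list and directory are placeholders)::

    dir_hash = create_dir_hash(["make -j8 -f Makefile"], "./tmp/my_model", debug=True)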
"""
hash_log = None
try:
if not os.path.isdir(build_dir):
return "empty_dir"
if debug:
hash_log = open( # noqa: P201 - this is actually closed properly in the finally clause below
os.path.join(build_dir, "cache_key.log"), mode="a", encoding="utf8"
)
hash_log.write(f"Building dir hash of {build_dir}\n")
basepath = Path(build_dir)
files = [p.relative_to(basepath) for p in basepath.rglob("*") if not p.is_dir()]
hash_object = hashlib.sha256()
for cmd in cmds:
_cmd = cmd.replace(
build_dir, "${BUILD_DIR}"
) # Make sure we can cache regardless of the build directory location.
hash_object.update(_cmd.encode("utf-8"))
if debug:
hash_log.write(f"\tCOMMAND: {_cmd} -> {hash_object.hexdigest()}\n")
for fpath in sorted(files):
if not filter_func(str(fpath)):
continue
hash_object.update(str(fpath).encode("utf-8"))
fullpath = str(basepath / fpath)
replaced_content = None
if content_replacer is not None:
replaced_content = content_replacer(fullpath)
if replaced_content is not None:
hash_object.update(replaced_content)
else:
with open(fullpath, "rb") as f:
# read the file in chunks of 32 KiB
# in order to support large files (e.g. constants.obj)
while True:
chunk = f.read(1024 * 32)
if not chunk:
break
hash_object.update(chunk)
if debug:
hash_log.write(f"\t{str(fpath)} -> {hash_object.hexdigest()}\n")
if debug:
hash_log.write(
f"Final hash of {build_dir} is {hash_object.hexdigest().lower()}\n"
)
return hash_object.hexdigest().lower()
finally:
if hash_log:
hash_log.close()
def write_binhash_file(
build_dir,
binhash_filename="constants.hash",
filter_func: Callable[[str], bool] = is_bin_file,
):
"""Hash all binary input files, so we don't have to keep them ( Usecase: constants.obj / constants.bin )
Args:
build_dir (str): Path to build directory
binhash_filename (str, optional): File to be written within build_dir, defaults to "constants.hash".
filter_func (Callable[[str], bool], optional): Filter function to determine which files to hash. Defaults to is_bin_file.
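
Example (illustrative; the directory is a placeholder)::

    write_binhash_file("./tmp/my_model")  # writes ./tmp/my_model/constants.hash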
"""
binhash = create_dir_hash([binhash_filename], build_dir, filter_func=filter_func)
with open(os.path.join(build_dir, binhash_filename), "w", encoding="utf-8") as f:
f.write(binhash)
class BuildCache(ABC):
"""
Abstract base class for build cache implementations
"""
@abstractmethod
def retrieve_build_cache(
self,
cmds: List[str],
build_dir: str,
from_sources_filter_func: Callable[[str], bool] = is_source,
) -> Tuple[bool, Optional[str]]:
"""
Retrieves the build cache artifacts for the given build directory,
so that ideally no compilation needs to take place.
Args:
cmds (List[str]): Build commands; these will be part of the hash used to calculate a lookup key
build_dir (str): Build directory. The source files, Makefile and some other files will be hashed and used to
determine the build cache key.
from_sources_filter_func (Callable[[str], bool], optional): Filter function, which may be used to determine which files are considered source files. Defaults to is_source.
Returns:
Tuple[bool, Optional[str]]: A tuple indicating whether the build cache was successfully retrieved, and a cache key (which should be passed on to store_build_cache on rebuild)
"""
...
@abstractmethod
def store_build_cache(
self,
cmds: List[str],
build_dir: str,
cache_key: str,
filter_func: Callable[[str], bool] = is_cache_artifact,
) -> bool:
"""
Store the build cache artifacts
Args:
cmds (List[str]): Build commands; these will be part of the hash used to calculate a lookup key
build_dir (str): Path to build directory to retrieve build artifacts from
cache_key (str): Cache key, as returned from retrieve_build_cache
filter_func (Callable[[str], bool], optional): Filter function, which may be used to determine which files are considered cacheable artifacts. Defaults to is_cache_artifact.
Returns:
bool: Whether the artifacts were successfully stored
"""
...
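# Illustrative caller flow (a sketch, not the exact AITemplate compile pipeline;
# run_build() is a hypothetical stand-in for the actual build step):
#
#     cache: BuildCache = FileBasedBuildCache("./ait_build_cache")
#     hit, cache_key = cache.retrieve_build_cache(cmds, build_dir)
#     if not hit:
#         run_build(cmds, build_dir)
#         if cache_key is not None:
#             cache.store_build_cache(cmds, build_dir, cache_key)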
def maybe_cleanup(
self, lru_retention_hours: int = 72, cleanup_max_age_seconds: int = 3600
):
"""
Maybe clean up the build cache if more than `cleanup_max_age_seconds` seconds have passed since the last cleanup
Args:
lru_retention_hours (int, optional): How many hours should unused elements be retained in the cache? Defaults to 72.
cleanup_max_age_seconds (int, optional): Cleanup interval in seconds. Defaults to 3600.
"""
pass
def cleanup(self, retention_hours: int = 72):
"""Do a cache cleanup.
Args:
retention_hours (int, optional): How many hours should unused elements be retained in the cache? Defaults to 72.
"""
pass
def makefile_normalizer(
self, path, memoize_replacements=True, debug=False
) -> Optional[bytes]:
"""
Normalizes the content of the makefile for hashing purposes (nothing else!),
so that it can be compared to other Makefiles
generated by different users on different systems.
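
For example, include directories that live under the system temp directory are
replaced by a hash of their source contents (computed with create_dir_hash), and
user-specific path prefixes in front of "fb_include" are stripped, so Makefiles
generated in different temp locations can still produce the same cache key.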
"""
p = Path(path)
if not p.name.lower().startswith("makefile"):
return None
makefile_content_orig = p.read_bytes()
target: Optional[Target] = None
try:
target = Target.current()
except RuntimeError:
# No current target, returning Makefile content unchanged
return makefile_content_orig
if target is None:
return makefile_content_orig
if not hasattr(target, "_compile_options"): #
return makefile_content_orig
if not hasattr(self, "_include_path_hash_cache"):
self._include_path_hash_cache = {}
makefile_content = makefile_content_orig.decode("utf-8")
compile_options = list(shlex.split(target._compile_options))
tmpdir = tempfile.gettempdir()
replacements = {}
for i in range(len(compile_options)):
if compile_options[i] == "-I":
if i < len(compile_options) - 1:
inc_path = compile_options[i + 1]
elif compile_options[i].startswith("-I"):
inc_path = compile_options[i][2:]
else:
continue
# We are creating hashes of all include directories in a temp dir
if inc_path.startswith(tmpdir):
if memoize_replacements and inc_path in self._include_path_hash_cache:
inc_path_hash = self._include_path_hash_cache[inc_path]
else:
inc_path_hash = create_dir_hash([], inc_path, is_source)
if memoize_replacements:
self._include_path_hash_cache[inc_path] = inc_path_hash
replacements[inc_path] = inc_path_hash
for search, replace in replacements.items():
makefile_content = makefile_content.replace(search, replace)
makefile_content = re.sub(
r"[^/\\]+[/\\]fb_include", "fb_include", makefile_content
)
makefile_bytes = makefile_content.encode("utf-8")
if debug:
(p.parent / (p.name + ".normalized")).write_bytes(makefile_bytes)
return makefile_bytes
class NoBuildCache(BuildCache):
def __init__(self):
"""
Dummy build cache implementation which does nothing.
For method docstrings, see parent class.
"""
_LOGGER.info("Build cache disabled")
def retrieve_build_cache(
self,
cmds: List[str],
build_dir: str,
from_sources_filter_func: Callable[[str], bool] = is_source,
) -> Tuple[bool, Optional[str]]:
return False, None
def store_build_cache(
self,
cmds: List[str],
build_dir: str,
cache_key: str,
filter_func: Callable[[str], bool] = is_cache_artifact,
) -> bool:
return False
class FileBasedBuildCache(BuildCache):
def __init__(
self,
cache_dir,
lru_retention_hours=72,
cleanup_max_age_seconds=3600,
debug=True,
):
"""Filesystem based build cache.
For method docstrings, see parent class.
Args:
cache_dir (str): Directory below which cache data is stored. Should be an empty, temporary directory with enough space to hold the cache contents. The cache will create and delete files below this directory!
lru_retention_hours (int, optional): Retention time for *unused* cache entries. Defaults to 72.
cleanup_max_age_seconds (int, optional): Minimum time between cache cleanups in seconds. After this time, a new cleanup gets triggered on next cache retrieval. Defaults to 3600.
debug (bool, optional): Whether to enable debugging cache key creation ( see debug parameter of create_dir_hash). Defaults to True. May be left at True, as it is usually helpful and does not hurt performance.
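
Example (illustrative; the cache directory is a placeholder)::

    cache = FileBasedBuildCache("/var/tmp/ait_build_cache", lru_retention_hours=24)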
"""
self.cache_dir = cache_dir
self.lru_retention_hours = lru_retention_hours
self.cleanup_max_age_seconds = cleanup_max_age_seconds
self.debug = debug
_LOGGER.info(
f"Using file-based build cache, cache directory = {self.cache_dir}"
)
def retrieve_build_cache(
self,
cmds: List[str],
build_dir: str,
from_sources_filter_func: Callable[[str], bool] = is_source,
) -> Tuple[bool, Optional[str]]:
"""See docstring of implemented method interface in parent class"""
if should_skip_build_cache():
_LOGGER.info(f"CACHE: Skipped build cache for {build_dir}")
return False, None
self.maybe_cleanup(self.lru_retention_hours, self.cleanup_max_age_seconds)
cache_dir = self.cache_dir
dir_hash = create_dir_hash(
cmds,
build_dir,
filter_func=from_sources_filter_func,
debug=self.debug,
content_replacer=lambda path: self.makefile_normalizer(
path, memoize_replacements=True
),
)
key_cache_dir = os.path.join(cache_dir, dir_hash)
if os.path.exists(key_cache_dir):
_LOGGER.info(f"CACHE: Using cached build results for {build_dir}")
target_basepath = Path(build_dir)
src_basepath = Path(key_cache_dir)
copy_files = [
p.relative_to(src_basepath)
for p in src_basepath.rglob("*")
if not p.is_dir()
]
for filepath in copy_files:
target_path = target_basepath / filepath
target_parent = target_path.parent
src_path = src_basepath / filepath
if target_parent != target_basepath:
os.makedirs(str(target_parent), exist_ok=True)
shutil.copy(
str(src_path),
str(target_path),
follow_symlinks=True,
) # Using shutil.copy intentionally instead of copy2, so the file modification time is updated, and file owner
# is not copied. When you retrieve the file from cache, it is yours.
_LOGGER.debug(f"CACHE: retrieved {filepath}")
# make sure the last modified timestamp is updated, so we can
# evict cache directories which are too old using a separate script
os.utime(key_cache_dir)
return True, dir_hash
_LOGGER.info(f"CACHE: No results found for {build_dir}")
return False, dir_hash
def store_build_cache(
self,
cmds: List[str],
build_dir: str,
cache_key: str,
filter_func: Callable[[str], bool] = is_cache_artifact,
) -> bool:
"""See docstring of implemented method interface in parent class"""
cache_dir = self.cache_dir
key_cache_dir = os.path.join(cache_dir, cache_key)
# We create a temporary directory first, so we can do an
# atomic update later to prevent race conditions
# in a distributed / parallel build setting
random_str = secrets.token_hex(16)
# the temp_cache_dir will be renamed to key_cache_dir
# atomically later. It needs to be on same file system
# for atomic rename, so we put it into the same folder.
temp_cache_dir = key_cache_dir + f".{random_str}.tmp"
try:
os.makedirs(temp_cache_dir, exist_ok=False)
except OSError:
_LOGGER.warning(
f"CACHE: Failed to create tempdir {temp_cache_dir}. Cannot write cache entries."
)
return False
basepath = Path(build_dir)
target_basepath = Path(temp_cache_dir)
copy_files = [
p.relative_to(basepath) for p in basepath.rglob("*") if not p.is_dir()
]
for filepath in copy_files:
src_path = basepath / filepath
if not filter_func(str(filepath)):
continue
target_path = target_basepath / filepath
target_parent = target_path.parent
if target_parent != target_basepath:
os.makedirs(str(target_parent), exist_ok=True)
shutil.copy2(
str(src_path),
str(target_path),
follow_symlinks=True,
) # Use copy2, so the file metadata (incl. last modified time) is preserved
_LOGGER.info(f"CACHE: storing {filepath} into {key_cache_dir}: ")
try:
os.rename(
temp_cache_dir, key_cache_dir
) # Atomic update to prevent race condition
return True
except OSError:
_LOGGER.info(
f"CACHE: update race conflict - {key_cache_dir} already exists. (Note: No error! This can be expected to happen occasionally.))"
)
shutil.rmtree(temp_cache_dir, ignore_errors=True)
return False
def maybe_cleanup(
self, lru_retention_hours: int = 72, cleanup_max_age_seconds: int = 3600
):
"""See docstring of implemented method interface in parent class"""
last_cleaned_seconds = file_age(os.path.join(self.cache_dir, ".last_cleaned"))
if last_cleaned_seconds > cleanup_max_age_seconds:
self.cleanup(lru_retention_hours)
def cleanup(self, lru_retention_hours: int = 72):
"""See docstring of implemented method interface in parent class"""
_LOGGER.info(
f"CACHE: Cleaning up build cache below {self.cache_dir}. Folders last used more than {lru_retention_hours} hours ago will be deleted."
)
touch(os.path.join(self.cache_dir, ".last_cleaned"))
if os.path.isdir(self.cache_dir):
now = datetime.now()
age_limit = timedelta(hours=lru_retention_hours)
for dirpath in os.scandir(self.cache_dir):
if os.path.isdir(dirpath):
# Get the modification time of the directory and convert it to a datetime object
mtime = os.path.getmtime(dirpath)
modification_time = datetime.fromtimestamp(mtime)
# Check if the directory is older than N hours
if now - modification_time > age_limit:
_LOGGER.info(f"CACHE: Deleting {dirpath}")
shutil.rmtree(dirpath)