# Source code for aitemplate.backend.profiler_cache

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
"""
SQLite backend for conv/gemm profiling cache
"""
import enum
import logging
import sqlite3

from typing import Any, Dict, Optional, Tuple

import jinja2

# pylint: disable=W0613


_LOGGER = logging.getLogger(__name__)


class CacheMode(enum.Enum):
    r"""Enum for cache mode

    Profiling cache can be stored locally or remotely.
    For LOCAL mode, the cache is stored in a SQLite database.
    For REMOTE mode, the profiled results can be queried with a RESTful API.

    REMOTE mode is not implemented yet.
    """

    LOCAL = 1
    REMOTE = 2
# ---------------------------------------------------------------------------
# SQL templates.
#
# Every template is rendered with ``dev`` (the target name, used as the table
# name prefix).  The gemm / conv / conv3d templates additionally take
# ``version`` (the cache version, used as the table name suffix); the
# normalization table is not versioned.  The remaining placeholders are the
# column values of the record being queried or inserted.
# ---------------------------------------------------------------------------

GEMM_INIT_TEMPLATE = jinja2.Template(
    """
CREATE TABLE IF NOT EXISTS {{dev}}_gemm_{{version}} (
  id INTEGER PRIMARY KEY AUTOINCREMENT,
  exec_entry VARCHAR(8192) NOT NULL,
  exec_entry_sha1 VARCHAR(64) NOT NULL,
  dtype_a INTEGER NOT NULL,
  dtype_b INTEGER NOT NULL,
  dtype_c INTEGER NOT NULL,
  dtype_acc INTEGER NOT NULL,
  major_a INTEGER NOT NULL,
  major_b INTEGER NOT NULL,
  major_c INTEGER NOT NULL,
  op_type VARCHAR(512) NOT NULL,
  epilogue VARCHAR(512) NOT NULL,
  device VARCHAR(16) NOT NULL,
  algo VARCHAR(512) NOT NULL,
  workspace INTEGER DEFAULT 0,
  duration FLOAT DEFAULT -1,
  split_k INTEGER DEFAULT 1,
  pshape VARCHAR(64) NOT NULL,
  template_ver INTEGER NOT NULL DEFAULT 290,
  created_at DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL
);
"""
)

GEMM_QUERY_TEMPLATE = jinja2.Template(
    """
SELECT algo, workspace, split_k
FROM {{dev}}_gemm_{{version}}
WHERE
dtype_a={{dtype_a}} AND
dtype_b={{dtype_b}} AND
dtype_c={{dtype_c}} AND
dtype_acc={{dtype_acc}} AND
major_a={{major_a}} AND
major_b={{major_b}} AND
major_c={{major_c}} AND
op_type='{{op_type}}' AND
device='{{device}}' AND
epilogue={{epilogue}} AND
pshape='{{pshape}}' AND
exec_entry_sha1='{{exec_entry_sha1}}';
"""
)

GEMM_INSERT_TEMPLATE = jinja2.Template(
    """
INSERT INTO {{dev}}_gemm_{{version}} (
    exec_entry,
    exec_entry_sha1,
    dtype_a,
    dtype_b,
    dtype_c,
    dtype_acc,
    major_a,
    major_b,
    major_c,
    op_type,
    epilogue,
    device,
    algo,
    workspace,
    split_k,
    pshape
)
VALUES (
    '{{exec_entry}}',
    '{{exec_entry_sha1}}',
    {{dtype_a}},
    {{dtype_b}},
    {{dtype_c}},
    {{dtype_acc}},
    {{major_a}},
    {{major_b}},
    {{major_c}},
    '{{op_type}}',
    {{epilogue}},
    '{{device}}',
    '{{algo}}',
    {{workspace}},
    {{split_k}},
    '{{pshape}}'
);
"""
)

CONV_INIT_TEMPLATE = jinja2.Template(
    """
CREATE TABLE IF NOT EXISTS {{dev}}_conv_{{version}} (
  id INTEGER PRIMARY KEY AUTOINCREMENT,
  exec_entry VARCHAR(8192) NOT NULL,
  exec_entry_sha1 VARCHAR(64) NOT NULL,
  dtype_a INTEGER NOT NULL,
  dtype_b INTEGER NOT NULL,
  dtype_c INTEGER NOT NULL,
  dtype_acc INTEGER NOT NULL,
  major_a INTEGER NOT NULL,
  major_b INTEGER NOT NULL,
  major_c INTEGER NOT NULL,
  kh INTEGER NOT NULL,
  kw INTEGER NOT NULL,
  co INTEGER NOT NULL,
  strideh INTEGER NOT NULL,
  padh INTEGER NOT NULL,
  dilateh INTEGER NOT NULL,
  stridew INTEGER NOT NULL,
  padw INTEGER NOT NULL,
  dilatew INTEGER NOT NULL,
  op_type VARCHAR(512) NOT NULL,
  epilogue VARCHAR(512) NOT NULL,
  device VARCHAR(16) NOT NULL,
  algo VARCHAR(512) NOT NULL,
  workspace INTEGER DEFAULT 0,
  duration FLOAT DEFAULT -1,
  split_k INTEGER DEFAULT 1,
  template_ver INTEGER NOT NULL DEFAULT 290,
  created_at DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL
);
"""
)

CONV_QUERY_TEMPLATE = jinja2.Template(
    """
SELECT algo, workspace
FROM {{dev}}_conv_{{version}}
WHERE
dtype_a={{dtype_a}} AND
dtype_b={{dtype_b}} AND
dtype_c={{dtype_c}} AND
dtype_acc={{dtype_acc}} AND
major_a={{major_a}} AND
major_b={{major_b}} AND
major_c={{major_c}} AND
kh={{kh}} AND
kw={{kw}} AND
co={{co}} AND
strideh={{strideh}} AND
padh={{padh}} AND
dilateh={{dilateh}} AND
stridew={{stridew}} AND
padw={{padw}} AND
dilatew={{dilatew}} AND
op_type='{{op_type}}' AND
device='{{device}}' AND
epilogue={{epilogue}} AND
split_k={{split_k}} AND
exec_entry_sha1='{{exec_entry_sha1}}';
"""
)

CONV_INSERT_TEMPLATE = jinja2.Template(
    """
INSERT INTO {{dev}}_conv_{{version}} (
    exec_entry,
    exec_entry_sha1,
    dtype_a,
    dtype_b,
    dtype_c,
    dtype_acc,
    major_a,
    major_b,
    major_c,
    kh,
    kw,
    co,
    strideh,
    padh,
    dilateh,
    stridew,
    padw,
    dilatew,
    op_type,
    epilogue,
    device,
    algo,
    workspace,
    split_k
)
VALUES (
    '{{exec_entry}}',
    '{{exec_entry_sha1}}',
    {{dtype_a}},
    {{dtype_b}},
    {{dtype_c}},
    {{dtype_acc}},
    {{major_a}},
    {{major_b}},
    {{major_c}},
    {{kh}},
    {{kw}},
    {{co}},
    {{strideh}},
    {{padh}},
    {{dilateh}},
    {{stridew}},
    {{padw}},
    {{dilatew}},
    '{{op_type}}',
    {{epilogue}},
    '{{device}}',
    '{{algo}}',
    {{workspace}},
    {{split_k}}
);
"""
)

CONV3D_INIT_TEMPLATE = jinja2.Template(
    """
CREATE TABLE IF NOT EXISTS {{dev}}_conv3d_{{version}} (
  id INTEGER PRIMARY KEY AUTOINCREMENT,
  exec_entry VARCHAR(8192) NOT NULL,
  exec_entry_sha1 VARCHAR(64) NOT NULL,
  dtype_a INTEGER NOT NULL,
  dtype_b INTEGER NOT NULL,
  dtype_c INTEGER NOT NULL,
  dtype_acc INTEGER NOT NULL,
  major_a INTEGER NOT NULL,
  major_b INTEGER NOT NULL,
  major_c INTEGER NOT NULL,
  kd INTEGER NOT NULL,
  kh INTEGER NOT NULL,
  kw INTEGER NOT NULL,
  co INTEGER NOT NULL,
  stride_d INTEGER NOT NULL,
  stride_h INTEGER NOT NULL,
  stride_w INTEGER NOT NULL,
  pad_d INTEGER NOT NULL,
  pad_h INTEGER NOT NULL,
  pad_w INTEGER NOT NULL,
  dilate_d INTEGER NOT NULL,
  dilate_h INTEGER NOT NULL,
  dilate_w INTEGER NOT NULL,
  op_type VARCHAR(512) NOT NULL,
  epilogue VARCHAR(512) NOT NULL,
  device VARCHAR(16) NOT NULL,
  algo VARCHAR(512) NOT NULL,
  workspace INTEGER DEFAULT 0,
  duration FLOAT DEFAULT -1,
  split_k INTEGER DEFAULT 1,
  template_ver INTEGER NOT NULL DEFAULT 290,
  created_at DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL
);
"""
)

CONV3D_QUERY_TEMPLATE = jinja2.Template(
    """
SELECT algo, workspace
FROM {{dev}}_conv3d_{{version}}
WHERE
dtype_a={{dtype_a}} AND
dtype_b={{dtype_b}} AND
dtype_c={{dtype_c}} AND
dtype_acc={{dtype_acc}} AND
major_a={{major_a}} AND
major_b={{major_b}} AND
major_c={{major_c}} AND
kd={{kd}} AND
kh={{kh}} AND
kw={{kw}} AND
co={{co}} AND
stride_d={{stride_d}} AND
stride_h={{stride_h}} AND
stride_w={{stride_w}} AND
pad_d={{pad_d}} AND
pad_h={{pad_h}} AND
pad_w={{pad_w}} AND
dilate_d={{dilate_d}} AND
dilate_h={{dilate_h}} AND
dilate_w={{dilate_w}} AND
op_type='{{op_type}}' AND
device='{{device}}' AND
epilogue={{epilogue}} AND
split_k={{split_k}} AND
exec_entry_sha1='{{exec_entry_sha1}}';
"""
)

CONV3D_INSERT_TEMPLATE = jinja2.Template(
    """
INSERT INTO {{dev}}_conv3d_{{version}} (
    exec_entry,
    exec_entry_sha1,
    dtype_a,
    dtype_b,
    dtype_c,
    dtype_acc,
    major_a,
    major_b,
    major_c,
    kd,
    kh,
    kw,
    co,
    stride_d,
    stride_h,
    stride_w,
    pad_d,
    pad_h,
    pad_w,
    dilate_d,
    dilate_h,
    dilate_w,
    op_type,
    epilogue,
    device,
    algo,
    workspace,
    split_k
)
VALUES (
    '{{exec_entry}}',
    '{{exec_entry_sha1}}',
    {{dtype_a}},
    {{dtype_b}},
    {{dtype_c}},
    {{dtype_acc}},
    {{major_a}},
    {{major_b}},
    {{major_c}},
    {{kd}},
    {{kh}},
    {{kw}},
    {{co}},
    {{stride_d}},
    {{stride_h}},
    {{stride_w}},
    {{pad_d}},
    {{pad_h}},
    {{pad_w}},
    {{dilate_d}},
    {{dilate_h}},
    {{dilate_w}},
    '{{op_type}}',
    {{epilogue}},
    '{{device}}',
    '{{algo}}',
    {{workspace}},
    {{split_k}}
);
"""
)

NORM_INIT_TEMPLATE = jinja2.Template(
    """
CREATE TABLE IF NOT EXISTS {{dev}}_normalization (
  id INTEGER PRIMARY KEY AUTOINCREMENT,
  exec_entry VARCHAR(8192) NOT NULL,
  exec_entry_sha1 VARCHAR(64) NOT NULL,
  dtype_in INTEGER NOT NULL,
  dtype_acc INTEGER NOT NULL,
  dtype_out INTEGER NOT NULL,
  rank INTEGER NOT NULL,
  op_type VARCHAR(512) NOT NULL,
  device VARCHAR(16) NOT NULL,
  algo VARCHAR(512) NOT NULL,
  workspace INTEGER DEFAULT 0,
  duration FLOAT DEFAULT -1,
  template_ver INTEGER NOT NULL DEFAULT 290,
  created_at DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL
);
"""
)

NORM_QUERY_TEMPLATE = jinja2.Template(
    """
SELECT algo, workspace
FROM {{dev}}_normalization
WHERE
dtype_in={{dtype_in}} AND
dtype_out={{dtype_out}} AND
dtype_acc={{dtype_acc}} AND
rank={{rank}} AND
op_type='{{op_type}}' AND
device='{{device}}' AND
exec_entry_sha1='{{exec_entry_sha1}}';
"""
)

NORM_INSERT_TEMPLATE = jinja2.Template(
    """
INSERT INTO {{dev}}_normalization (
    exec_entry,
    exec_entry_sha1,
    dtype_in,
    dtype_out,
    dtype_acc,
    rank,
    op_type,
    device,
    algo,
    workspace
)
VALUES (
    '{{exec_entry}}',
    '{{exec_entry_sha1}}',
    {{dtype_in}},
    {{dtype_out}},
    {{dtype_acc}},
    {{rank}},
    '{{op_type}}',
    '{{device}}',
    '{{algo}}',
    {{workspace}}
);
"""
)

# Looks up a single table by its exact name in SQLite's schema table.
CHECK_TABLE_EXISTENCE_TEMPLATE = jinja2.Template(
    """
SELECT name FROM sqlite_master WHERE type='table' AND name='{{table_name}}';
"""
)

# Lists the names of all tables in the database.
QUERY_ALL_TABLES_TEMPLATE = jinja2.Template(
    """
SELECT name FROM sqlite_master WHERE type='table';
"""
)

# Single version number used as the table-name suffix of the versioned
# (gemm / conv / conv3d) cache tables.
__AIT_CACHE_VERSION__ = 3


def ait_cache_version() -> int:
    """Return the current profiling cache version number."""
    return __AIT_CACHE_VERSION__
class ProfileCacheDB:
    r"""Local SQLite profile cache database.

    Each op category (gemm, conv, conv3d, normalization) has its own table
    whose name is prefixed with the target name.  Inserts are not committed
    immediately: they set an internal flag, and the pending transaction is
    committed before the next query or when the object is destroyed.
    """

    # Names of the entries of a record dict that identify a record.  They are
    # the values passed to the corresponding *_QUERY_TEMPLATE when checking
    # for an existing record before inserting.  Every listed key must be
    # present in the record (a missing key raises KeyError).
    _DTYPE_MAJOR_FIELDS = (
        "dtype_a",
        "dtype_b",
        "dtype_c",
        "dtype_acc",
        "major_a",
        "major_b",
        "major_c",
    )
    _GEMM_KEY_FIELDS = _DTYPE_MAJOR_FIELDS + (
        "op_type",
        "device",
        "epilogue",
        "split_k",
        "pshape",
        "exec_entry_sha1",
    )
    _CONV_KEY_FIELDS = _DTYPE_MAJOR_FIELDS + (
        "kh",
        "kw",
        "co",
        "strideh",
        "padh",
        "dilateh",
        "stridew",
        "padw",
        "dilatew",
        "op_type",
        "device",
        "epilogue",
        "split_k",
        "exec_entry_sha1",
    )
    _CONV3D_KEY_FIELDS = _DTYPE_MAJOR_FIELDS + (
        "kd",
        "kh",
        "kw",
        "co",
        "stride_d",
        "stride_h",
        "stride_w",
        "pad_d",
        "pad_h",
        "pad_w",
        "dilate_d",
        "dilate_h",
        "dilate_w",
        "op_type",
        "device",
        "epilogue",
        "split_k",
        "exec_entry_sha1",
    )
    _NORM_KEY_FIELDS = (
        "dtype_in",
        "dtype_acc",
        "dtype_out",
        "rank",
        "op_type",
        "device",
        "exec_entry_sha1",
    )

    def __init__(
        self, target: str, path: str = None, uri: str = None, port: str = None
    ):
        r"""

        Parameters
        ----------
        target : str
            target device name. CUDA or ROCM
        path : str, optional
            path to the database file. Required in LOCAL mode (i.e. whenever
            `uri` is not given).
        uri : str, optional
            uri to the RESTful API (Not implemented yet)
        port : str, optional
            port to the RESTful API (Not implemented yet)

        Raises
        ------
        NotImplementedError
            if `uri` is given (REMOTE mode is not implemented yet).
        """
        # Define the connection attributes first so that __del__ stays safe
        # even when initialization fails part-way through.
        self._con = None
        self._cur = None

        self._target = target
        self._mode = CacheMode.LOCAL
        self._db_commit_flag = False
        # Some design rationales:
        #   * All tables share the version number, because we are exposing the
        #     the cache version. Using a single version number seems to make it
        #     more clean. One caveat is that we are going to re-create all tables
        #     even if some of them are not involved with the breaking changes.
        #     It seems to be fine as we expect the frequency of cache version
        #     updated to be quite low.
        #   * We only keep a single table (i.e. version) for each category (
        #     gemm, conv and norm) to simplify how we handle breaking changes
        #     and rollbacks caused by failures in the updated version.
        #     For example, if we keep multiple versions (i.e. tables) for gemm,
        #     we would have to consider how we were going to maintain those versions.
        #     We could choose the old working version upon rollback, but we might
        #     leave some content from the failing version in the db. How are we
        #     going to update the db if we update the version again, and so on.
        # TODO: add similar version control for norm
        self._gemm_cache_version = ait_cache_version()
        self._conv_cache_version = ait_cache_version()
        self._conv3d_cache_version = ait_cache_version()
        if uri is not None:
            self._mode = CacheMode.REMOTE

        if self._mode == CacheMode.LOCAL:
            assert path is not None
            self._con = sqlite3.connect(path)
            self._cur = self._con.cursor()
            self._init_db()
        else:
            raise NotImplementedError

    def _init_db(self):
        """Creates all cache tables (gemm, conv, conv3d, normalization)."""
        self._create_gemm_table()
        self._create_conv_table()
        self._create_conv3d_table()
        self._create_norm_table()

    @property
    def gemm_cache_version(self) -> int:
        """Cache version used in the gemm table name."""
        return self._gemm_cache_version

    @property
    def conv_cache_version(self) -> int:
        """Cache version used in the conv table name."""
        return self._conv_cache_version

    @property
    def conv3d_cache_version(self) -> int:
        """Cache version used in the conv3d table name."""
        return self._conv3d_cache_version

    def _create_versioned_table(self, table_kind, version, init_template):
        """Creates the `table_kind` table for `version` if it does not exist.

        Shared implementation of the gemm / conv / conv3d table creation.
        The CREATE statement itself is `IF NOT EXISTS`, so it is always
        executed; the existence check only controls the log output.
        """
        if not self._table_exists(table_kind, version):
            _LOGGER.info(
                f"Temporarily keeping the old {table_kind} cache versions if exist",
            )
            # FIXME: will delete unmatched version once we get into production
            # self._delete_existing_table(table_kind)
            _LOGGER.info(
                f"Creating a new {table_kind} table with {version=}",
            )
        sql = init_template.render(
            dev=self._target,
            version=version,
        )
        self._cur.execute(sql)
        self._con.commit()

    def _create_gemm_table(self):
        """Creates gemm table."""
        self._create_versioned_table(
            "gemm", self.gemm_cache_version, GEMM_INIT_TEMPLATE
        )

    def _create_conv_table(self):
        """Creates conv table."""
        self._create_versioned_table(
            "conv", self.conv_cache_version, CONV_INIT_TEMPLATE
        )

    def _create_conv3d_table(self):
        """Creates conv3d table."""
        self._create_versioned_table(
            "conv3d", self.conv3d_cache_version, CONV3D_INIT_TEMPLATE
        )

    def _create_norm_table(self):
        """Creates normalization table (not versioned)."""
        sql = NORM_INIT_TEMPLATE.render(dev=self._target)
        self._cur.execute(sql)
        self._con.commit()

    def _table_exists(self, table_kind, cache_version):
        """Check if the table of given kind and cache version exists."""
        table_name = f"{self._target}_{table_kind}_{cache_version}"
        sql = CHECK_TABLE_EXISTENCE_TEMPLATE.render(table_name=table_name)
        self._cur.execute(sql)
        tables = self._cur.fetchall()
        if tables:
            _LOGGER.info(
                f"{table_name=} exists in the db",
            )
            return True
        _LOGGER.info(
            f"{table_name=} does not exist in the db, possible version mismatch!",
        )
        return False

    def _delete_existing_table(self, table_kind):
        """Delete the existing table of the given kind from the db.

        Exactly one table of that kind is expected to exist; this is
        enforced with assertions.
        """
        sql = QUERY_ALL_TABLES_TEMPLATE.render()
        self._cur.execute(sql)
        all_tables = self._cur.fetchall()
        if len(all_tables) == 0:
            _LOGGER.info("deleting table: skip empty table")
            return
        # Match "<target>_<kind>" exactly or followed by "_<version>".  A bare
        # prefix match would make kind "conv" also select "conv3d" tables.
        base_name = f"{self._target}_{table_kind}"
        target_tables = [
            name
            for (name,) in all_tables
            if name == base_name or name.startswith(base_name + "_")
        ]
        assert len(target_tables) != 0, f"no {table_kind} table exists"
        # To simplify the logic, we only keep a single table for each kind
        assert (
            len(target_tables) == 1
        ), f"expected only one {table_kind} table but got {target_tables=}"
        _LOGGER.info(f"deleting table {target_tables[0]=}")
        self._cur.execute(f"DROP TABLE {target_tables[0]}")

    def _query(self, sql: str) -> Optional[Tuple[Any, ...]]:
        """a function to query op from cache

        Parameters
        ----------
        sql : str
            sql statement for query

        Returns
        -------
        Tuple or None
            the first matching row, or None if nothing matched

        Raises
        ------
        NotImplementedError
            if the cache is not in LOCAL mode
        """
        if self._mode != CacheMode.LOCAL:
            raise NotImplementedError
        # Flush pending inserts before reading.
        if self._db_commit_flag:
            self._con.commit()
            self._db_commit_flag = False
        self._cur.execute(sql)
        rows = self._cur.fetchall()
        if len(rows) == 0:
            return None
        return rows[0]

    def query_gemm(self, args: Dict[str, Any]) -> Optional[Tuple[Any, ...]]:
        """a function to query gemm op epilogue from cache

        Parameters
        ----------
        args : Dict
            gemm query entry

        Returns
        -------
        Tuple or None
            profiling results (algo, workspace, split_k), or None on a miss
        """
        sql = GEMM_QUERY_TEMPLATE.render(
            dev=self._target,
            version=self.gemm_cache_version,
            **args,
        )
        return self._query(sql)

    def query_conv(self, args: Dict[str, Any]) -> Optional[Tuple[Any, ...]]:
        """a function to query conv op epilogue from cache

        Parameters
        ----------
        args : Dict
            Conv query entry

        Returns
        -------
        Tuple or None
            profiling results (algo, workspace), or None on a miss
        """
        sql = CONV_QUERY_TEMPLATE.render(
            dev=self._target,
            version=self.conv_cache_version,
            **args,
        )
        return self._query(sql)

    def query_conv3d(self, args: Dict[str, Any]) -> Optional[Tuple[Any, ...]]:
        """a function to query conv3d op epilogue from cache

        Parameters
        ----------
        args : Dict
            Conv3d query entry

        Returns
        -------
        Tuple or None
            profiling results (algo, workspace), or None on a miss
        """
        sql = CONV3D_QUERY_TEMPLATE.render(
            dev=self._target,
            version=self.conv3d_cache_version,
            **args,
        )
        return self._query(sql)

    def query_normalization(
        self, args: Dict[str, Any]
    ) -> Optional[Tuple[Any, ...]]:
        """a function to query normalization op from cache

        Parameters
        ----------
        args : Dict
            Normalization query entry

        Returns
        -------
        Tuple or None
            profiling results (algo, workspace), or None on a miss
        """
        sql = NORM_QUERY_TEMPLATE.render(dev=self._target, **args)
        return self._query(sql)

    @staticmethod
    def _select_fields(args: Dict[str, Any], fields) -> Dict[str, Any]:
        """Return the sub-dict of `args` restricted to `fields`.

        Raises KeyError if any of the fields is missing from `args`.
        """
        return {name: args[name] for name in fields}

    def _insert(self, query_sql: str, insert_sql: str) -> None:
        """a function to insert op into cache

        The insert is skipped (and logged) if `query_sql` already matches a
        record. A successful insert is not committed here; it only marks the
        connection as having pending changes.

        Parameters
        ----------
        query_sql : str
            cache query sql
        insert_sql: str
            cache insert sql
        """
        if self._mode == CacheMode.LOCAL:
            self._cur.execute(query_sql)
            existing = self._cur.fetchall()
            if len(existing) == 0:
                self._cur.execute(insert_sql)
                self._db_commit_flag = True
            else:
                _LOGGER.info("Ignore repeat profile_record: " + query_sql)

    def insert_gemm(self, args: Dict[str, Any]) -> None:
        """a function to insert gemm op epilogue into cache

        Parameters
        ----------
        args : Dict
            Gemm Record Entry
        """
        query_sql = GEMM_QUERY_TEMPLATE.render(
            dev=self._target,
            version=self.gemm_cache_version,
            **self._select_fields(args, self._GEMM_KEY_FIELDS),
        )
        insert_sql = GEMM_INSERT_TEMPLATE.render(
            dev=self._target,
            version=self.gemm_cache_version,
            **args,
        )
        self._insert(query_sql, insert_sql)

    def insert_conv(self, args: Dict[str, Any]) -> None:
        """a function to insert conv op epilogue into cache

        Parameters
        ----------
        args : Dict
            Conv Record Entry
        """
        query_sql = CONV_QUERY_TEMPLATE.render(
            dev=self._target,
            version=self.conv_cache_version,
            **self._select_fields(args, self._CONV_KEY_FIELDS),
        )
        insert_sql = CONV_INSERT_TEMPLATE.render(
            dev=self._target,
            version=self.conv_cache_version,
            **args,
        )
        self._insert(query_sql, insert_sql)

    def insert_conv3d(self, args: Dict[str, Any]) -> None:
        """a function to insert conv3d op epilogue into cache

        Parameters
        ----------
        args : Dict
            Conv3d Record Entry
        """
        query_sql = CONV3D_QUERY_TEMPLATE.render(
            dev=self._target,
            version=self.conv3d_cache_version,
            **self._select_fields(args, self._CONV3D_KEY_FIELDS),
        )
        insert_sql = CONV3D_INSERT_TEMPLATE.render(
            dev=self._target,
            version=self.conv3d_cache_version,
            **args,
        )
        self._insert(query_sql, insert_sql)

    def insert_normalization(self, args: Dict[str, Any]) -> None:
        """a function to insert normalization op into cache

        Parameters
        ----------
        args : Dict
            Normalization Record Entry
        """
        query_sql = NORM_QUERY_TEMPLATE.render(
            dev=self._target,
            **self._select_fields(args, self._NORM_KEY_FIELDS),
        )
        insert_sql = NORM_INSERT_TEMPLATE.render(dev=self._target, **args)
        self._insert(query_sql, insert_sql)

    def __del__(self):
        """Commit pending changes and close the connection, if one is open."""
        # __init__ may have failed before a connection was opened (or the
        # object may have been created without __init__), so do not assume
        # the attribute exists.
        con = getattr(self, "_con", None)
        if con is None:
            return
        # Drop the reference first so a repeated call is a no-op.
        self._con = None
        try:
            con.commit()
        finally:
            con.close()