# Source code for aitemplate.compiler.symbolic

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
"""
Symbolic helpers for AITemplate.
AITemplate leverages Sympy to do symbolic computations for shapes.
Sympy is built around the class "Symbol". We can apply operations to Symbols
(e.g. add/mul/power), which lets us do basic arithmetic with unknown values.
A value is symbolic when its representation includes a Symbol (e.g. sym_1 + 100).

Example Usage:
A = IntVar(...)
sym_A = A.symbolic_value() # equivalent of A._attrs["symbolic_value"]

# Do something with sym_A. Some common usages include:
new_sym = sym_A + 100
new_sym = sym_A - 100
new_sym = sym_A * 2
new_sym = sym_A * sym_B

# We could then assign the symbolic value to a new IntVar.
new_var = IntVar(..., symbolic_value=new_sym)

For more advanced usage on Sympy, check: https://docs.sympy.org/latest/tutorials/intro-tutorial/intro.html
"""
from __future__ import annotations

import itertools

from numbers import Number
from typing import Any, List, Optional, Set

import sympy


# Module-level registries shared by the helpers in this file.

# Symbol name -> object registered through store_intvar() and looked up by
# get_intvar() (presumably an IntVar, judging by the names; confirm at callers).
_k_symbolic_to_intvar = {}
# Running counter used by create_new_symbol() to build "_sym_<N>" names when
# the caller does not supply one.
_k_symbolic_index = 0
# Symbol name -> sorted, de-duplicated list of candidate values recorded by
# create_new_symbol(), or None when no values were given.
_k_symbolic_value = {}


def create_new_symbol(
    name: Optional[str] = None,
    values: Optional[List[int]] = None,
    check_duplicate: bool = False,
) -> sympy.Symbol:
    """
    Create a sympy symbol and record its value range in the module registry.

    Parameters
    ----------
    name : Optional[str]
        The symbol name to use. If None, a fresh "_sym_<N>" name that is not
        yet registered is generated.
    values : Optional[List[int]]
        The values the symbol may take. They are de-duplicated and sorted
        before being recorded; None is recorded as-is.
    check_duplicate : bool
        If True and the name is already registered, the recorded values must
        equal the (normalized) values passed in.

    Returns
    -------
    sympy.Symbol
        A symbol carrying the chosen name.

    Raises
    ------
    ValueError
        If check_duplicate is True and the name is already registered with
        different values.
    """
    global _k_symbolic_index, _k_symbolic_value

    if name is None:
        # Keep consuming counter values until we hit a name nobody registered.
        candidate = f"_sym_{_k_symbolic_index}"
        _k_symbolic_index += 1
        while candidate in _k_symbolic_value:
            candidate = f"_sym_{_k_symbolic_index}"
            _k_symbolic_index += 1
        name = candidate

    # Normalize so that comparisons against recorded values are order- and
    # duplicate-insensitive.
    if values is not None:
        values = sorted(set(values))

    if check_duplicate and name in _k_symbolic_value:
        if _k_symbolic_value[name] != values:
            raise ValueError(
                f"Symbol ({name}) has different values! New value is {values}, stored value is {_k_symbolic_value[name]}"
            )

    # Record (or overwrite) the value range for this name.
    _k_symbolic_value[name] = values
    return sympy.Symbol(name)


def is_symbol(sym_val: Any) -> bool:
    """
    Check whether sym_val is an instance of sympy.Symbol.
    """
    symbol_cls = sympy.Symbol
    return isinstance(sym_val, symbol_cls)


def is_symbolic(sym_val: Any) -> bool:
    """
    Check whether sym_val is a sympy object.

    Parameters
    ----------
    sym_val : Any
        The value to inspect.

    Returns
    -------
    bool
        True if sym_val is an instance of sympy.Basic, False otherwise.
    """
    return isinstance(sym_val, sympy.Basic)


def is_integer(sym_val: Any) -> bool:
    """
    Check whether sym_val holds an integral value.

    Parameters
    ----------
    sym_val : Any
        A sympy object, a Python number, or anything else.

    Returns
    -------
    bool
        For sympy objects: True when the object is a number and truncating it
        to int leaves it unchanged. For Python numbers: True when truncating
        to int leaves it unchanged. False for every other input.
    """
    # We wrap this since None is returned if sympy can't determine the property.
    if is_symbolic(sym_val):
        return sym_val.is_number and int(sym_val) - sym_val == 0
    elif isinstance(sym_val, Number):
        return int(sym_val) - sym_val == 0
    return False


def get_global_symbol_set() -> Set:
    """
    Return the set of all symbol names currently in the value registry.
    """
    # Only reads the registry, so no `global` declaration is needed.
    return set(_k_symbolic_value)


def get_intvar(sym_name: str):
    """
    Look up the object stored for sym_name via store_intvar().

    Returns None when nothing has been stored under that name.
    """
    return _k_symbolic_to_intvar.get(sym_name)


def store_intvar(sym_name: str, int_var) -> None:
    """
    Register int_var under sym_name, replacing any previous entry.
    """
    # Item assignment mutates the dict in place; no `global` needed.
    _k_symbolic_to_intvar[sym_name] = int_var


def simplify_intvar_values(sym_val: sympy.Basic) -> List[int]:
    """
    Given a symbolic value, resolve the symbol's value range.

    Every combination of the recorded values of the free symbols in sym_val
    is substituted into the expression, and the distinct integer results are
    returned in ascending order.

    Example:
        'symbol_A' has value range of [10, 20]
        simplify_intvar_values(symbol_A * 3 + 4) returns [34, 64]

    Parameters
    ----------
    sym_val : sympy.Basic
        The symbolic expression to evaluate.

    Returns
    -------
    List[int]
        Sorted, de-duplicated results of evaluating sym_val.

    Raises
    ------
    KeyError
        If a free symbol's name is not present in the value registry.

    Notes
    -----
    Symbols whose recorded value range is None are not substituted. The
    expression then still contains a free symbol, so the int() conversion is
    expected to fail (NOTE(review): confirm the exact exception with sympy).
    """
    # Keep each symbol paired with its own value list while dropping the ones
    # with no recorded values. Filtering only the value lists would shift the
    # remaining lists onto the wrong symbols when zipping below.
    bound = [(sym, _k_symbolic_value[sym.name]) for sym in sym_val.free_symbols]
    bound = [(sym, vals) for sym, vals in bound if vals is not None]
    bound_symbols = [sym for sym, _ in bound]
    value_lists = [vals for _, vals in bound]

    # With no bound symbols, product() yields a single empty combination, so a
    # constant expression evaluates to a one-element list.
    resolved = {
        int(sym_val.subs(list(zip(bound_symbols, combo))))
        for combo in itertools.product(*value_lists)
    }
    return sorted(resolved)