# Source code for aitemplate.frontend.nn.module

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
from collections import namedtuple, OrderedDict
from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Union

from aitemplate.compiler.base import Tensor
from aitemplate.frontend.nn.parameter import Parameter


class _IncompatibleKeys(
    namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"])
):
    def __repr__(self):
        if not self.missing_keys and not self.unexpected_keys:
            return "<All keys matched successfully>"
        return super(_IncompatibleKeys, self).__repr__()

    __str__ = __repr__


# Trick mypy into not applying contravariance rules to inputs by defining
# forward as a value, rather than a function.  See also
# https://github.com/python/mypy/issues/8795
def _forward_unimplemented(self, *input: Any) -> None:
    r"""Defines the computation performed at every call.

    Should be overridden by all subclasses.

    .. note::
        Although the recipe for forward pass needs to be defined within
        this function, one should call the :class:`Module` instance afterwards
        instead of this since the former takes care of running the
        registered hooks while the latter silently ignores them.
    """
    raise NotImplementedError(
        f'Module [{type(self).__name__}] is missing the required "forward" function'
    )


def typename(x):
    """Return the class name of ``x`` for use in error messages.

    Falls back to ``str(type(x))`` for an object that exposes no
    ``__class__`` attribute.
    """
    if not hasattr(x, "__class__"):
        return str(type(x))
    return x.__class__.__name__


def _addindent(s_, numSpaces):
    s = s_.split("\n")
    # don't do anything for single-line stuff
    if len(s) == 1:
        return s_
    first = s.pop(0)
    s = [(numSpaces * " ") + line for line in s]
    s = "\n".join(s)
    s = first + "\n" + s
    return s


class Module:
    r"""Base class for all neural network modules.

    Your models should also subclass this class.

    Modules can also contain other Modules, allowing to nest them in a tree
    structure. You can assign the submodules as regular attributes::

        from aitemplate.frontend import nn

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.block1 = SubModule()
                self.block2 = SubModule()

            def forward(self, x):
                return self.block2(self.block1(x))

    (``SubModule`` stands for any :class:`Module` subclass.)

    Attributes assigned this way are routed by :meth:`__setattr__`:
    :class:`Module` values go to ``_modules``, :class:`Parameter` values to
    ``_parameters``, and tensors previously registered via
    :meth:`register_buffer` stay in ``_buffers``. These tables are what
    :meth:`named_modules`, :meth:`named_parameters` and
    :meth:`named_buffers` walk.

    .. note::
        As per the example above, an ``__init__()`` call to the parent class
        must be made before assignment on the child.
    """

    dump_patches: bool = False

    # Version number carried over from the PyTorch-derived design, where it is
    # saved alongside serialized module state for backward compatibility.
    # Bump it when parameters/buffers of a module are added or removed. No
    # method in this class currently reads it.
    _version: int = 1

    _parameters: Dict[str, Optional[Parameter]]
    _buffers: Dict[str, Optional[Tensor]]
    _modules: Dict[str, Optional["Module"]]
    _non_persistent_buffers_set: Set[str]

    def __init__(self) -> None:
        """
        Calls super().__setattr__('a', a) instead of the typical self.a = a
        to avoid Module.__setattr__ overhead. Module's __setattr__ has special
        handling for parameters, submodules, and buffers but simply calls into
        super().__setattr__ for all other attributes.
        """
        super().__setattr__("_parameters", OrderedDict())
        super().__setattr__("_buffers", OrderedDict())
        super().__setattr__("_modules", OrderedDict())
        # Names of buffers registered with ``persistent=False``. Read and
        # written by ``register_buffer`` and ``__delattr__``; without this
        # initialization both fell through to ``__getattr__`` and raised
        # AttributeError.
        super().__setattr__("_non_persistent_buffers_set", set())

    forward: Callable[..., Any] = _forward_unimplemented

    def register_buffer(
        self, name: str, tensor: Optional[Tensor], persistent: bool = True
    ) -> None:
        r"""Adds a buffer to the module.

        This is typically used to register a tensor that should not be
        considered a model parameter. Buffers are persistent by default;
        passing ``persistent=False`` records ``name`` in
        ``_non_persistent_buffers_set`` instead. Re-registering an existing
        buffer updates that flag.

        Buffers can be accessed as attributes using given names.

        Args:
            name (str): name of the buffer. The buffer can be accessed
                from this module using the given name
            tensor (Tensor or None): buffer to be registered. ``None`` is
                stored as-is and is skipped by :meth:`named_buffers`.
            persistent (bool): whether the buffer is marked persistent.

        Raises:
            AttributeError: if called before ``Module.__init__()``.
            KeyError: if ``name`` contains ``"."``, is empty, or already names
                an attribute that is not a buffer.
            TypeError: if ``tensor`` is neither a :class:`Tensor` nor ``None``.

        Example::

            >>> self.register_buffer('running_mean', some_tensor)
        """
        if "_buffers" not in self.__dict__:
            raise AttributeError(
                "cannot assign buffer before Module.__init__() call"
            )
        elif "." in name:
            raise KeyError('buffer name can\'t contain "."')
        elif name == "":
            raise KeyError('buffer name can\'t be empty string ""')
        elif hasattr(self, name) and name not in self._buffers:
            raise KeyError("attribute '{}' already exists".format(name))
        elif tensor is not None and not isinstance(tensor, Tensor):
            raise TypeError(
                "cannot assign '{}' object to buffer '{}' "
                "(torch Tensor or None required)".format(typename(tensor), name)
            )
        else:
            self._buffers[name] = tensor
            if persistent:
                self._non_persistent_buffers_set.discard(name)
            else:
                self._non_persistent_buffers_set.add(name)

    def register_parameter(self, name: str, param: Optional[Parameter]) -> None:
        r"""Adds a parameter to the module.

        The parameter can be accessed as an attribute using given name.

        Args:
            name (str): name of the parameter. The parameter can be accessed
                from this module using the given name
            param (Parameter or None): parameter to be added to the module.
                ``None`` is stored as-is and is skipped by
                :meth:`named_parameters`.

        Raises:
            AttributeError: if called before ``Module.__init__()``.
            KeyError: if ``name`` contains ``"."``, is empty, or already names
                an attribute that is not a parameter.
            TypeError: if ``param`` is neither a :class:`Parameter` nor
                ``None``.
        """
        if "_parameters" not in self.__dict__:
            raise AttributeError(
                "cannot assign parameter before Module.__init__() call"
            )
        elif "." in name:
            raise KeyError('parameter name can\'t contain "."')
        elif name == "":
            raise KeyError('parameter name can\'t be empty string ""')
        elif hasattr(self, name) and name not in self._parameters:
            raise KeyError("attribute '{}' already exists".format(name))

        if param is None:
            self._parameters[name] = None
        elif not isinstance(param, Parameter):
            raise TypeError(
                "cannot assign '{}' object to parameter '{}' "
                "(nn.Parameter or None required)".format(typename(param), name)
            )
        else:
            self._parameters[name] = param

    def add_module(self, name: str, module: Optional["Module"]) -> None:
        r"""Adds a child module to the current module.

        The module can be accessed as an attribute using the given name.

        Args:
            name (str): name of the child module. The child module can be
                accessed from this module using the given name
            module (Module or None): child module to be added to the module.

        Raises:
            TypeError: if ``module`` is neither a :class:`Module` nor ``None``.
            KeyError: if ``name`` already names a non-module attribute,
                contains ``"."``, or is empty.
        """
        if not isinstance(module, Module) and module is not None:
            raise TypeError("{} is not a Module subclass".format(typename(module)))
        elif hasattr(self, name) and name not in self._modules:
            raise KeyError("attribute '{}' already exists".format(name))
        elif "." in name:
            raise KeyError('module name can\'t contain ".", got: {}'.format(name))
        elif name == "":
            raise KeyError('module name can\'t be empty string ""')
        self._modules[name] = module

    def register_module(self, name: str, module: Optional["Module"]) -> None:
        r"""Alias for :func:`add_module`."""
        self.add_module(name, module)

    def get_submodule(self, target: str) -> "Module":
        """
        Returns the submodule given by ``target`` if it exists,
        otherwise throws an error.

        For example, let's say you have an ``nn.Module`` ``A`` that
        looks like this:

        .. code-block:: text

            A(
              (net_b): Module(
                (net_c): Module(
                  (conv): Conv2d(...)
                )
                (linear): Linear(...)
              )
            )

        (The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested
        submodule ``net_b``, which itself has two submodules ``net_c``
        and ``linear``. ``net_c`` then has a submodule ``conv``.)

        To check whether or not we have the ``linear`` submodule, we
        would call ``get_submodule("net_b.linear")``. To check whether
        we have the ``conv`` submodule, we would call
        ``get_submodule("net_b.net_c.conv")``.

        The runtime of ``get_submodule`` is bounded by the degree
        of module nesting in ``target``. A query against
        ``named_modules`` achieves the same result, but it is O(N) in
        the number of transitive modules. So, for a simple check to see
        if some submodule exists, ``get_submodule`` should always be
        used.

        Args:
            target: The fully-qualified string name of the submodule
                to look for. (See above example for how to specify a
                fully-qualified string.) An empty string returns ``self``.

        Returns:
            nn.Module: The submodule referenced by ``target``

        Raises:
            AttributeError: If the target string references an invalid
                path or resolves to something that is not an
                ``nn.Module``
        """
        if target == "":
            return self

        atoms: List[str] = target.split(".")
        mod: Module = self

        for item in atoms:
            if not hasattr(mod, item):
                raise AttributeError(
                    mod._get_name() + " has no " "attribute `" + item + "`"
                )

            mod = getattr(mod, item)

            if not isinstance(mod, Module):
                raise AttributeError("`" + item + "` is not " "an nn.Module")

        return mod

    def get_parameter(self, target: str) -> "Parameter":
        """
        Returns the parameter given by ``target`` if it exists,
        otherwise throws an error.

        See the docstring for ``get_submodule`` for a more detailed
        explanation of this method's functionality as well as how to
        correctly specify ``target``.

        Args:
            target: The fully-qualified string name of the Parameter
                to look for. (See ``get_submodule`` for how to specify a
                fully-qualified string.)

        Returns:
            nn.Parameter: The Parameter referenced by ``target``

        Raises:
            AttributeError: If the target string references an invalid
                path or resolves to something that is not an
                ``nn.Parameter``
        """
        module_path, _, param_name = target.rpartition(".")

        mod: Module = self.get_submodule(module_path)

        if not hasattr(mod, param_name):
            raise AttributeError(
                mod._get_name() + " has no attribute `" + param_name + "`"
            )

        param: Parameter = getattr(mod, param_name)

        if not isinstance(param, Parameter):
            raise AttributeError("`" + param_name + "` is not an " "nn.Parameter")

        return param

    def get_buffer(self, target: str) -> "Tensor":
        """
        Returns the buffer given by ``target`` if it exists,
        otherwise throws an error.

        See the docstring for ``get_submodule`` for a more detailed
        explanation of this method's functionality as well as how to
        correctly specify ``target``.

        Args:
            target: The fully-qualified string name of the buffer
                to look for. (See ``get_submodule`` for how to specify a
                fully-qualified string.)

        Returns:
            Tensor: The buffer referenced by ``target``

        Raises:
            AttributeError: If the target string references an invalid
                path or resolves to something that is not a
                buffer
        """
        module_path, _, buffer_name = target.rpartition(".")

        mod: Module = self.get_submodule(module_path)

        if not hasattr(mod, buffer_name):
            raise AttributeError(
                mod._get_name() + " has no attribute `" + buffer_name + "`"
            )

        buffer: Tensor = getattr(mod, buffer_name)

        if buffer_name not in mod._buffers:
            raise AttributeError("`" + buffer_name + "` is not a buffer")

        return buffer

    def _call_impl(self, *input, **kwargs):
        # Calling a module simply forwards all arguments to ``forward``.
        forward_call = self.forward
        return forward_call(*input, **kwargs)

    __call__: Callable[..., Any] = _call_impl

    def __getattr__(self, name: str) -> Union[Tensor, "Module"]:
        # Python only calls this when normal lookup fails; resolve names stored
        # in the parameter, buffer and module tables, in that order.
        if "_parameters" in self.__dict__:
            _parameters = self.__dict__["_parameters"]
            if name in _parameters:
                return _parameters[name]
        if "_buffers" in self.__dict__:
            _buffers = self.__dict__["_buffers"]
            if name in _buffers:
                return _buffers[name]
        if "_modules" in self.__dict__:
            modules = self.__dict__["_modules"]
            if name in modules:
                return modules[name]
        raise AttributeError(
            "'{}' object has no attribute '{}'".format(type(self).__name__, name)
        )

    def __setattr__(self, name: str, value: Union[Tensor, "Module"]) -> None:
        """Route attribute assignment into the parameter/module/buffer tables.

        * :class:`Parameter` values are registered as parameters.
        * Names already registered as parameters accept only ``None``.
        * :class:`Module` values are stored as child modules; names already
          registered as modules accept only ``None``.
        * Names already registered as buffers accept a :class:`Tensor` or
          ``None``.
        * Everything else becomes a plain instance attribute.
        """

        def try_remove():
            # Drop any previous binding of ``name`` from the instance dict and
            # the three registries before it is re-bound below. Registries are
            # skipped if they were never created (e.g. before __init__).
            dicts_or_sets = [
                self.__dict__,
                self.__dict__.get("_parameters"),
                self.__dict__.get("_buffers"),
                self.__dict__.get("_modules"),
            ]
            for d in dicts_or_sets:
                if d is not None and name in d:
                    if isinstance(d, dict):
                        d.pop(name, None)
                    else:
                        d.discard(name)

        params = self.__dict__.get("_parameters")
        if isinstance(value, Parameter):
            if params is None:
                raise AttributeError(
                    "cannot assign parameters before Module.__init__() call"
                )
            try_remove()
            self.register_parameter(name, value)
        elif params is not None and name in params:
            if value is not None:
                raise TypeError(
                    "cannot assign '{}' as parameter '{}' "
                    "(nn.Parameter or None expected)".format(typename(value), name)
                )
            try_remove()
            self.register_parameter(name, value)
        else:
            modules = self.__dict__.get("_modules")
            if isinstance(value, Module):
                if modules is None:
                    raise AttributeError(
                        "cannot assign module before Module.__init__() call"
                    )
                try_remove()
                modules[name] = value
            elif modules is not None and name in modules:
                if value is not None:
                    raise TypeError(
                        "cannot assign '{}' as child module '{}' "
                        "(nn.Module or None expected)".format(typename(value), name)
                    )
                try_remove()
                modules[name] = value
            else:
                buffers = self.__dict__.get("_buffers")
                if buffers is not None and name in buffers:
                    if value is not None and not isinstance(value, Tensor):
                        raise TypeError(
                            "cannot assign '{}' as buffer '{}' "
                            "(Tensor or None expected)".format(typename(value), name)
                        )
                    try_remove()
                    buffers[name] = value
                else:
                    super().__setattr__(name, value)

    def __delattr__(self, name):
        # Deleting a registered name removes it from its registry; buffers are
        # also dropped from the non-persistent set.
        if name in self._parameters:
            del self._parameters[name]
        elif name in self._buffers:
            del self._buffers[name]
            self._non_persistent_buffers_set.discard(name)
        elif name in self._modules:
            del self._modules[name]
        else:
            super().__delattr__(name)

    def _named_members(self, get_members_fn, prefix="", recurse=True):
        r"""Helper method for yielding various names + members of modules.

        ``None`` members are skipped, and a member object shared under several
        names is yielded only once (the first name encountered).
        """
        memo = set()
        modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
        for module_prefix, module in modules:
            members = get_members_fn(module)
            for k, v in members:
                if v is None or v in memo:
                    continue
                memo.add(v)
                name = module_prefix + ("." if module_prefix else "") + k
                yield name, v

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        r"""Returns an iterator over module parameters.

        Args:
            recurse (bool): if True, then yields parameters of this module
                and all submodules. Otherwise, yields only parameters that
                are direct members of this module.

        Yields:
            Parameter: module parameter

        Example::

            >>> for param in model.parameters():
            ...     print(type(param).__name__)
        """
        for _, param in self.named_parameters(recurse=recurse):
            yield param

    def named_parameters(
        self, prefix: str = "", recurse: bool = True
    ) -> Iterator[Tuple[str, Parameter]]:
        r"""Returns an iterator over module parameters, yielding both the
        name of the parameter as well as the parameter itself.

        Args:
            prefix (str): prefix to prepend to all parameter names.
            recurse (bool): if True, then yields parameters of this module
                and all submodules. Otherwise, yields only parameters that
                are direct members of this module.

        Yields:
            (str, Parameter): Tuple containing the dotted name and parameter

        Example::

            >>> for name, param in self.named_parameters():
            ...     if name in ['bias']:
            ...         print(name)
        """
        yield from self._named_members(
            lambda module: module._parameters.items(), prefix=prefix, recurse=recurse
        )

    def buffers(self, recurse: bool = True) -> Iterator[Tensor]:
        r"""Returns an iterator over module buffers.

        Args:
            recurse (bool): if True, then yields buffers of this module
                and all submodules. Otherwise, yields only buffers that
                are direct members of this module.

        Yields:
            Tensor: module buffer

        Example::

            >>> for buf in model.buffers():
            ...     print(type(buf).__name__)
        """
        for _, buf in self.named_buffers(recurse=recurse):
            yield buf

    def named_buffers(
        self, prefix: str = "", recurse: bool = True
    ) -> Iterator[Tuple[str, Tensor]]:
        r"""Returns an iterator over module buffers, yielding both the
        name of the buffer as well as the buffer itself.

        Args:
            prefix (str): prefix to prepend to all buffer names.
            recurse (bool): if True, then yields buffers of this module
                and all submodules. Otherwise, yields only buffers that
                are direct members of this module.

        Yields:
            (str, Tensor): Tuple containing the dotted name and buffer

        Example::

            >>> for name, buf in self.named_buffers():
            ...     if name in ['running_var']:
            ...         print(name)
        """
        yield from self._named_members(
            lambda module: module._buffers.items(), prefix=prefix, recurse=recurse
        )

    def children(self) -> Iterator["Module"]:
        r"""Returns an iterator over immediate children modules.

        Yields:
            Module: a child module (each distinct instance once)
        """
        for _, module in self.named_children():
            yield module

    def named_children(self) -> Iterator[Tuple[str, "Module"]]:
        r"""Returns an iterator over immediate children modules, yielding both
        the name of the module as well as the module itself.

        ``None`` children are skipped, and a child registered under several
        names is yielded only under the first one.

        Yields:
            (str, Module): Tuple containing a name and child module

        Example::

            >>> for name, module in model.named_children():
            ...     if name in ['conv4', 'conv5']:
            ...         print(module)
        """
        memo = set()
        for name, module in self._modules.items():
            if module is not None and module not in memo:
                memo.add(module)
                yield name, module

    def modules(self) -> Iterator["Module"]:
        r"""Returns an iterator over all modules in the network, including
        ``self``.

        Yields:
            Module: a module in the network

        Note:
            Duplicate modules are returned only once. With ``net`` holding the
            same ``leaf`` instance under two attributes (see
            :meth:`named_modules`), ``len(list(net.modules()))`` is 2.
        """
        for _, module in self.named_modules():
            yield module

    def name_parameter_tensor(self):
        r"""Set each parameter's tensor name to its dotted parameter path,
        with ``"."`` replaced by ``"_"``."""
        for name, param in self.named_parameters():
            param.tensor()._attrs["name"] = name.replace(".", "_")

    def named_modules(
        self,
        memo: Optional[Set["Module"]] = None,
        prefix: str = "",
        remove_duplicate: bool = True,
    ):
        r"""Returns an iterator over all modules in the network, yielding
        both the name of the module as well as the module itself.

        The traversal is depth-first, pre-order, starting with ``self`` under
        the name ``prefix``.

        Args:
            memo: a memo to store the set of modules already added to the result
            prefix: a prefix that will be added to the name of the module
            remove_duplicate: whether to remove the duplicated module instances
                in the result or not

        Yields:
            (str, Module): Tuple of name and module

        Note:
            Duplicate modules are returned only once when ``remove_duplicate``
            is True. In the following example, ``leaf`` is returned only once.

        Example::

            >>> class Net(Module):
            ...     def __init__(self, leaf):
            ...         super().__init__()
            ...         self.a = leaf
            ...         self.b = leaf
            >>> net = Net(Module())
            >>> [name for name, _ in net.named_modules()]
            ['', 'a']
            >>> [name for name, _ in net.named_modules(remove_duplicate=False)]
            ['', 'a', 'b']
        """
        if memo is None:
            memo = set()
        if self not in memo:
            if remove_duplicate:
                memo.add(self)
            yield prefix, self
            for name, module in self._modules.items():
                if module is None:
                    continue
                submodule_prefix = prefix + ("." if prefix else "") + name
                yield from module.named_modules(
                    memo, submodule_prefix, remove_duplicate
                )

    def _get_name(self):
        return self.__class__.__name__

    def extra_repr(self) -> str:
        r"""Set the extra representation of the module

        To print customized extra information, you should re-implement
        this method in your own modules. Both single-line and multi-line
        strings are acceptable.
        """
        return ""

    def __repr__(self):
        # We treat the extra repr like the sub-module, one item per line
        extra_lines = []
        extra_repr = self.extra_repr()
        # empty string will be split into list ['']
        if extra_repr:
            extra_lines = extra_repr.split("\n")
        child_lines = []
        for key, module in self._modules.items():
            mod_str = repr(module)
            mod_str = _addindent(mod_str, 2)
            child_lines.append("(" + key + "): " + mod_str)
        lines = extra_lines + child_lines

        main_str = self._get_name() + "("
        if lines:
            # simple one-liner info, which most builtin Modules will use
            if len(extra_lines) == 1 and not child_lines:
                main_str += extra_lines[0]
            else:
                # Two spaces, matching the _addindent(mod_str, 2) applied to
                # nested children, so each child's closing paren lines up with
                # its "(key):" line.
                main_str += "\n  " + "\n  ".join(lines) + "\n"

        main_str += ")"
        return main_str

    def __dir__(self):
        module_attrs = dir(self.__class__)
        attrs = list(self.__dict__.keys())
        parameters = list(self._parameters.keys())
        modules = list(self._modules.keys())
        buffers = list(self._buffers.keys())
        keys = module_attrs + attrs + parameters + modules + buffers

        # Eliminate attrs that are not legal Python variable names
        keys = [key for key in keys if not key[0].isdigit()]

        return sorted(keys)