# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from collections import namedtuple, OrderedDict
from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Union
from aitemplate.compiler.base import Tensor
from aitemplate.frontend.nn.parameter import Parameter
class _IncompatibleKeys(
namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"])
):
def __repr__(self):
if not self.missing_keys and not self.unexpected_keys:
return "<All keys matched successfully>"
return super(_IncompatibleKeys, self).__repr__()
__str__ = __repr__
# Trick mypy into not applying contravariance rules to inputs by defining
# forward as a value, rather than a function. See also
# https://github.com/python/mypy/issues/8795
def _forward_unimplemented(self, *input: Any) -> None:
r"""Defines the computation performed at every call.
Should be overridden by all subclasses.
.. note::
Although the recipe for forward pass needs to be defined within
this function, one should call the :class:`Module` instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
"""
raise NotImplementedError(
f'Module [{type(self).__name__}] is missing the required "forward" function'
)
def typename(x):
if hasattr(x, "__class__"):
return x.__class__.__name__
else:
return str(type(x))
def _addindent(s_, numSpaces):
s = s_.split("\n")
# don't do anything for single-line stuff
if len(s) == 1:
return s_
first = s.pop(0)
s = [(numSpaces * " ") + line for line in s]
s = "\n".join(s)
s = first + "\n" + s
return s
[docs]class Module:
r"""Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing to nest them in
a tree structure. You can assign the submodules as regular attributes::
import nn as nn
import nn.functional as F
class Model(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will have their
parameters converted too when you call :meth:`to`, etc.
.. note::
As per the example above, an ``__init__()`` call to the parent class
must be made before assignment on the child.
:ivar training: Boolean represents whether this module is in training or
evaluation mode.
:vartype training: bool
"""
dump_patches: bool = False
_version: int = 1
r"""This allows better BC support for :meth:`load_state_dict`. In
:meth:`state_dict`, the version number will be saved as in the attribute
`_metadata` of the returned state dict, and thus pickled. `_metadata` is a
dictionary with keys that follow the naming convention of state dict. See
``_load_from_state_dict`` on how to use this information in loading.
If new parameters/buffers are added/removed from a module, this number shall
be bumped, and the module's `_load_from_state_dict` method can compare the
version number and do appropriate changes if the state dict is from before
the change."""
_parameters: Dict[str, Optional[Parameter]]
_buffers: Dict[str, Optional[Tensor]]
_modules: Dict[str, Optional["Module"]]
def __init__(self) -> None:
"""
Calls super().__setattr__('a', a) instead of the typical self.a = a
to avoid Module.__setattr__ overhead. Module's __setattr__ has special
handling for parameters, submodules, and buffers but simply calls into
super().__setattr__ for all other attributes.
"""
super().__setattr__("_parameters", OrderedDict())
super().__setattr__("_buffers", OrderedDict())
super().__setattr__("_modules", OrderedDict())
forward: Callable[..., Any] = _forward_unimplemented
[docs] def register_buffer(
self, name: str, tensor: Optional[Tensor], persistent: bool = True
) -> None:
r"""Adds a buffer to the module.
This is typically used to register a buffer that should not to be
considered a model parameter. For example, BatchNorm's ``running_mean``
is not a parameter, but is part of the module's state. Buffers, by
default, are persistent and will be saved alongside parameters. This
behavior can be changed by setting :attr:`persistent` to ``False``. The
only difference between a persistent buffer and a non-persistent buffer
is that the latter will not be a part of this module's
:attr:`state_dict`.
Buffers can be accessed as attributes using given names.
Args:
name (str): name of the buffer. The buffer can be accessed
from this module using the given name
tensor (Tensor or None): buffer to be registered. If ``None``, then operations
that run on buffers, such as :attr:`cuda`, are ignored. If ``None``,
the buffer is **not** included in the module's :attr:`state_dict`.
persistent (bool): whether the buffer is part of this module's
:attr:`state_dict`.
Example::
>>> self.register_buffer('running_mean', zeros(num_features))
"""
if "_buffers" not in self.__dict__:
raise AttributeError("cannot assign buffer before Module.__init__() call")
elif "." in name:
raise KeyError('buffer name can\'t contain "."')
elif name == "":
raise KeyError('buffer name can\'t be empty string ""')
elif hasattr(self, name) and name not in self._buffers:
raise KeyError("attribute '{}' already exists".format(name))
elif tensor is not None and not isinstance(tensor, Tensor):
raise TypeError(
"cannot assign '{}' object to buffer '{}' "
"(torch Tensor or None required)".format(typename(tensor), name)
)
else:
self._buffers[name] = tensor
if persistent:
self._non_persistent_buffers_set.discard(name)
else:
self._non_persistent_buffers_set.add(name)
[docs] def register_parameter(self, name: str, param: Optional[Parameter]) -> None:
r"""Adds a parameter to the module.
The parameter can be accessed as an attribute using given name.
Args:
name (str): name of the parameter. The parameter can be accessed
from this module using the given name
param (Parameter or None): parameter to be added to the module. If
``None``, then operations that run on parameters, such as :attr:`cuda`,
are ignored. If ``None``, the parameter is **not** included in the
module's :attr:`state_dict`.
"""
if "_parameters" not in self.__dict__:
raise AttributeError(
"cannot assign parameter before Module.__init__() call"
)
elif "." in name:
raise KeyError('parameter name can\'t contain "."')
elif name == "":
raise KeyError('parameter name can\'t be empty string ""')
elif hasattr(self, name) and name not in self._parameters:
raise KeyError("attribute '{}' already exists".format(name))
if param is None:
self._parameters[name] = None
elif not isinstance(param, Parameter):
raise TypeError(
"cannot assign '{}' object to parameter '{}' "
"(nn.Parameter or None required)".format(typename(param), name)
)
else:
self._parameters[name] = param
[docs] def add_module(self, name: str, module: Optional["Module"]) -> None:
r"""Adds a child module to the current module.
The module can be accessed as an attribute using the given name.
Args:
name (str): name of the child module. The child module can be
accessed from this module using the given name
module (Module): child module to be added to the module.
"""
if not isinstance(module, Module) and module is not None:
raise TypeError("{} is not a Module subclass".format(typename(module)))
elif hasattr(self, name) and name not in self._modules:
raise KeyError("attribute '{}' already exists".format(name))
elif "." in name:
raise KeyError('module name can\'t contain ".", got: {}'.format(name))
elif name == "":
raise KeyError('module name can\'t be empty string ""')
self._modules[name] = module
[docs] def register_module(self, name: str, module: Optional["Module"]) -> None:
r"""Alias for :func:`add_module`."""
self.add_module(name, module)
[docs] def get_submodule(self, target: str) -> "Module":
"""
Returns the submodule given by ``target`` if it exists,
otherwise throws an error.
For example, let's say you have an ``nn.Module`` ``A`` that
looks like this:
.. code-block:: text
A(
(net_b): Module(
(net_c): Module(
(conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
)
(linear): Linear(in_features=100, out_features=200, bias=True)
)
)
(The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested
submodule ``net_b``, which itself has two submodules ``net_c``
and ``linear``. ``net_c`` then has a submodule ``conv``.)
To check whether or not we have the ``linear`` submodule, we
would call ``get_submodule("net_b.linear")``. To check whether
we have the ``conv`` submodule, we would call
``get_submodule("net_b.net_c.conv")``.
The runtime of ``get_submodule`` is bounded by the degree
of module nesting in ``target``. A query against
``named_modules`` achieves the same result, but it is O(N) in
the number of transitive modules. So, for a simple check to see
if some submodule exists, ``get_submodule`` should always be
used.
Args:
target: The fully-qualified string name of the submodule
to look for. (See above example for how to specify a
fully-qualified string.)
Returns:
nn.Module: The submodule referenced by ``target``
Raises:
AttributeError: If the target string references an invalid
path or resolves to something that is not an
``nn.Module``
"""
if target == "":
return self
atoms: List[str] = target.split(".")
mod: Module = self
for item in atoms:
if not hasattr(mod, item):
raise AttributeError(
mod._get_name() + " has no " "attribute `" + item + "`"
)
mod = getattr(mod, item)
if not isinstance(mod, Module):
raise AttributeError("`" + item + "` is not " "an nn.Module")
return mod
[docs] def get_parameter(self, target: str) -> "Parameter":
"""
Returns the parameter given by ``target`` if it exists,
otherwise throws an error.
See the docstring for ``get_submodule`` for a more detailed
explanation of this method's functionality as well as how to
correctly specify ``target``.
Args:
target: The fully-qualified string name of the Parameter
to look for. (See ``get_submodule`` for how to specify a
fully-qualified string.)
Returns:
nn.Parameter: The Parameter referenced by ``target``
Raises:
AttributeError: If the target string references an invalid
path or resolves to something that is not an
``nn.Parameter``
"""
module_path, _, param_name = target.rpartition(".")
mod: Module = self.get_submodule(module_path)
if not hasattr(mod, param_name):
raise AttributeError(
mod._get_name() + " has no attribute `" + param_name + "`"
)
param: Parameter = getattr(mod, param_name)
if not isinstance(param, Parameter):
raise AttributeError("`" + param_name + "` is not an " "nn.Parameter")
return param
[docs] def get_buffer(self, target: str) -> "Tensor":
"""
Returns the buffer given by ``target`` if it exists,
otherwise throws an error.
See the docstring for ``get_submodule`` for a more detailed
explanation of this method's functionality as well as how to
correctly specify ``target``.
Args:
target: The fully-qualified string name of the buffer
to look for. (See ``get_submodule`` for how to specify a
fully-qualified string.)
Returns:
Tensor: The buffer referenced by ``target``
Raises:
AttributeError: If the target string references an invalid
path or resolves to something that is not a
buffer
"""
module_path, _, buffer_name = target.rpartition(".")
mod: Module = self.get_submodule(module_path)
if not hasattr(mod, buffer_name):
raise AttributeError(
mod._get_name() + " has no attribute `" + buffer_name + "`"
)
buffer: Tensor = getattr(mod, buffer_name)
if buffer_name not in mod._buffers:
raise AttributeError("`" + buffer_name + "` is not a buffer")
return buffer
def _call_impl(self, *input, **kwargs):
forward_call = self.forward
return forward_call(*input, **kwargs)
__call__: Callable[..., Any] = _call_impl
def __getattr__(self, name: str) -> Union[Tensor, "Module"]:
if "_parameters" in self.__dict__:
_parameters = self.__dict__["_parameters"]
if name in _parameters:
return _parameters[name]
if "_buffers" in self.__dict__:
_buffers = self.__dict__["_buffers"]
if name in _buffers:
return _buffers[name]
if "_modules" in self.__dict__:
modules = self.__dict__["_modules"]
if name in modules:
return modules[name]
raise AttributeError(
"'{}' object has no attribute '{}'".format(type(self).__name__, name)
)
def __setattr__(self, name: str, value: Union[Tensor, "Module"]) -> None:
def try_remove():
dicts_or_sets = [
self.__dict__,
self.__dict__.get("_parameters"),
self.__dict__.get("_buffers"),
self.__dict__.get("_modules"),
]
for d in dicts_or_sets:
if name in d:
if isinstance(d, dict):
d.pop(name, None)
else:
d.discard(name)
params = self.__dict__.get("_parameters")
if isinstance(value, Parameter):
if params is None:
raise AttributeError(
"cannot assign parameters before Module.__init__() call"
)
try_remove()
self.register_parameter(name, value)
elif params is not None and name in params:
if value is not None:
raise TypeError(
"cannot assign '{}' as parameter '{}' "
"(nn.Parameter or None expected)".format(typename(value), name)
)
try_remove()
self.register_parameter(name, value)
else:
modules = self.__dict__.get("_modules")
if isinstance(value, Module):
if modules is None:
raise AttributeError(
"cannot assign module before Module.__init__() call"
)
try_remove()
modules[name] = value
elif modules is not None and name in modules:
if value is not None:
raise TypeError(
"cannot assign '{}' as child module '{}' "
"(nn.Module or None expected)".format(typename(value), name)
)
try_remove()
modules[name] = value
else:
buffers = self.__dict__.get("_buffers")
if buffers is not None and name in buffers:
if value is not None and not isinstance(value, Tensor):
raise TypeError(
"cannot assign '{}' as buffer '{}' "
"(Tensor or None expected)".format(typename(value), name)
)
try_remove()
buffers[name] = value
else:
super().__setattr__(name, value)
def __delattr__(self, name):
if name in self._parameters:
del self._parameters[name]
elif name in self._buffers:
del self._buffers[name]
self._non_persistent_buffers_set.discard(name)
elif name in self._modules:
del self._modules[name]
else:
super().__delattr__(name)
def _named_members(self, get_members_fn, prefix="", recurse=True):
r"""Helper method for yielding various names + members of modules."""
memo = set()
modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
for module_prefix, module in modules:
members = get_members_fn(module)
for k, v in members:
if v is None or v in memo:
continue
memo.add(v)
name = module_prefix + ("." if module_prefix else "") + k
yield name, v
[docs] def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
r"""Returns an iterator over module parameters.
This is typically passed to an optimizer.
Args:
recurse (bool): if True, then yields parameters of this module
and all submodules. Otherwise, yields only parameters that
are direct members of this module.
Yields:
Parameter: module parameter
Example::
>>> for param in model.parameters():
>>> print(type(param), param.size())
<class 'Tensor'> (20L,)
<class 'Tensor'> (20L, 1L, 5L, 5L)
"""
for _, param in self.named_parameters(recurse=recurse):
yield param
[docs] def named_parameters(
self, prefix: str = "", recurse: bool = True
) -> Iterator[Tuple[str, Parameter]]:
r"""Returns an iterator over module parameters, yielding both the
name of the parameter as well as the parameter itself.
Args:
prefix (str): prefix to prepend to all parameter names.
recurse (bool): if True, then yields parameters of this module
and all submodules. Otherwise, yields only parameters that
are direct members of this module.
Yields:
(str, Parameter): Tuple containing the name and parameter
Example::
>>> for name, param in self.named_parameters():
>>> if name in ['bias']:
>>> print(param.size())
"""
gen = self._named_members(
lambda module: module._parameters.items(), prefix=prefix, recurse=recurse
)
for elem in gen:
yield elem
[docs] def buffers(self, recurse: bool = True) -> Iterator[Tensor]:
r"""Returns an iterator over module buffers.
Args:
recurse (bool): if True, then yields buffers of this module
and all submodules. Otherwise, yields only buffers that
are direct members of this module.
Yields:
Tensor: module buffer
Example::
>>> for buf in model.buffers():
>>> print(type(buf), buf.size())
<class 'Tensor'> (20L,)
<class 'Tensor'> (20L, 1L, 5L, 5L)
"""
for _, buf in self.named_buffers(recurse=recurse):
yield buf
[docs] def named_buffers(
self, prefix: str = "", recurse: bool = True
) -> Iterator[Tuple[str, Tensor]]:
r"""Returns an iterator over module buffers, yielding both the
name of the buffer as well as the buffer itself.
Args:
prefix (str): prefix to prepend to all buffer names.
recurse (bool): if True, then yields buffers of this module
and all submodules. Otherwise, yields only buffers that
are direct members of this module.
Yields:
(str, Tensor): Tuple containing the name and buffer
Example::
>>> for name, buf in self.named_buffers():
>>> if name in ['running_var']:
>>> print(buf.size())
"""
gen = self._named_members(
lambda module: module._buffers.items(), prefix=prefix, recurse=recurse
)
for elem in gen:
yield elem
[docs] def children(self) -> Iterator["Module"]:
r"""Returns an iterator over immediate children modules.
Yields:
Module: a child module
"""
for _, module in self.named_children():
yield module
[docs] def named_children(self) -> Iterator[Tuple[str, "Module"]]:
r"""Returns an iterator over immediate children modules, yielding both
the name of the module as well as the module itself.
Yields:
(str, Module): Tuple containing a name and child module
Example::
>>> for name, module in model.named_children():
>>> if name in ['conv4', 'conv5']:
>>> print(module)
"""
memo = set()
for name, module in self._modules.items():
if module is not None and module not in memo:
memo.add(module)
yield name, module
[docs] def modules(self) -> Iterator["Module"]:
r"""Returns an iterator over all modules in the network.
Yields:
Module: a module in the network
Note:
Duplicate modules are returned only once. In the following
example, ``l`` will be returned only once.
Example::
>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.modules()):
print(idx, '->', m)
0 -> Sequential(
(0): Linear(in_features=2, out_features=2, bias=True)
(1): Linear(in_features=2, out_features=2, bias=True)
)
1 -> Linear(in_features=2, out_features=2, bias=True)
"""
for _, module in self.named_modules():
yield module
[docs] def name_parameter_tensor(self):
r"""Set the name of the parameter to tensor's name"""
for name, param in self.named_parameters():
param.tensor()._attrs["name"] = name.replace(".", "_")
[docs] def named_modules(
self,
memo: Optional[Set["Module"]] = None,
prefix: str = "",
remove_duplicate: bool = True,
):
r"""Returns an iterator over all modules in the network, yielding
both the name of the module as well as the module itself.
Args:
memo: a memo to store the set of modules already added to the result
prefix: a prefix that will be added to the name of the module
remove_duplicate: whether to remove the duplicated module instances in the result
or not
Yields:
(str, Module): Tuple of name and module
Note:
Duplicate modules are returned only once. In the following
example, ``l`` will be returned only once.
Example::
>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.named_modules()):
print(idx, '->', m)
0 -> ('', Sequential(
(0): Linear(in_features=2, out_features=2, bias=True)
(1): Linear(in_features=2, out_features=2, bias=True)
))
1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
"""
if memo is None:
memo = set()
if self not in memo:
if remove_duplicate:
memo.add(self)
yield prefix, self
for name, module in self._modules.items():
if module is None:
continue
submodule_prefix = prefix + ("." if prefix else "") + name
for m in module.named_modules(memo, submodule_prefix, remove_duplicate):
yield m
def _get_name(self):
return self.__class__.__name__
def __repr__(self):
# We treat the extra repr like the sub-module, one item per line
extra_lines = []
extra_repr = self.extra_repr()
# empty string will be split into list ['']
if extra_repr:
extra_lines = extra_repr.split("\n")
child_lines = []
for key, module in self._modules.items():
mod_str = repr(module)
mod_str = _addindent(mod_str, 2)
child_lines.append("(" + key + "): " + mod_str)
lines = extra_lines + child_lines
main_str = self._get_name() + "("
if lines:
# simple one-liner info, which most builtin Modules will use
if len(extra_lines) == 1 and not child_lines:
main_str += extra_lines[0]
else:
main_str += "\n " + "\n ".join(lines) + "\n"
main_str += ")"
return main_str
def __dir__(self):
module_attrs = dir(self.__class__)
attrs = list(self.__dict__.keys())
parameters = list(self._parameters.keys())
modules = list(self._modules.keys())
buffers = list(self._buffers.keys())
keys = module_attrs + attrs + parameters + modules + buffers
# Eliminate attrs that are not legal Python variable names
keys = [key for key in keys if not key[0].isdigit()]
return sorted(keys)