symbols
Utilities to describe symbols found in common torch modules.
Classes
Symbol | A symbolic parameter (Symbol) of a SymModule. |
SymMap | A class to hold the symbolic representations of a model. |
SymInfo | A simple class to hold relevant information about the symbolic nature of a given module. |
- class SymInfo
Bases:
object
A simple class to hold relevant information about the symbolic nature of a given module.
- __init__(is_shape_preserving=False, **kwargs)
Initialize the instance with the given symbolic information.
- Parameters:
is_shape_preserving (bool) –
kwargs (Symbol) –
- property is_shape_preserving: bool
Return indicator whether module is shape-preserving.
- class SymMap
Bases:
object
A class to hold the symbolic representations of a model.
- __init__(model)
Initialize with the desired module.
- Return type:
None
- add_sym_info(key, sym_info)
Manually add the sym_info for one of the model’s modules.
- Parameters:
key (Module) –
sym_info (SymInfo) –
- Return type:
None
- get_symbol(mod, name)
Get symbol from the given module with the given name.
- Parameters:
mod (Module) –
name (str) –
- Return type:
Symbol
- is_shape_preserving(key)
Return whether the symbolic module is shape preserving.
- Parameters:
key (Module) –
- Return type:
bool
- items()
Return an iterator over the dictionary.
- Return type:
Generator[Tuple[Module, Dict[str, Symbol]], None, None]
- named_modules()
Yield the name (from self._mod_to_name) and the associated module.
- Return type:
Generator[Tuple[str, Module], None, None]
- named_sym_dicts()
Yield the name (from self._mod_to_name) and the associated symbolic module.
- Return type:
Generator[Tuple[str, Dict[str, Symbol]], None, None]
- named_symbols(key=None, free=None, dynamic=None, searchable=None, constant=None)
Yield the name and symbol of symbols in either all symbolic modules or a specific one.
- Parameters:
key (Module | None) – The module to get symbols from. If not provided, recurse through all modules.
free (bool | None) – Whether to include free symbols.
dynamic (bool | None) – Whether to include dynamic symbols.
searchable (bool | None) – Whether to include searchable symbols.
constant (bool | None) – Whether to include constant symbols.
- Yields:
(name, Symbol) – Tuple containing the name and symbol.
- Return type:
Generator[Tuple[str, Symbol], None, None]
Default behavior is to iterate over free, dynamic, searchable, and constant symbols. Set args accordingly to only iterate over some. When any of free, dynamic, searchable, or constant is set to True, only symbols of that type are iterated over. If any of free, dynamic, searchable, or constant is set to False, symbols of that type are skipped over.
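The tri-state flags described above (None, True, False) can be sketched in isolation. The following is a minimal illustration of the documented semantics, not the library's implementation: the State enum, the filter_named_states helper, and the example symbol names are all hypothetical stand-ins.

```python
from enum import Enum
from typing import Dict, Iterator, Optional, Tuple


class State(Enum):
    """Hypothetical stand-in for the four mutually exclusive symbol states."""

    FREE = "free"
    DYNAMIC = "dynamic"
    SEARCHABLE = "searchable"
    CONSTANT = "constant"


def filter_named_states(
    named_states: Dict[str, State],
    free: Optional[bool] = None,
    dynamic: Optional[bool] = None,
    searchable: Optional[bool] = None,
    constant: Optional[bool] = None,
) -> Iterator[Tuple[str, State]]:
    """Yield (name, state) pairs according to the tri-state flags."""
    flags = {
        State.FREE: free,
        State.DYNAMIC: dynamic,
        State.SEARCHABLE: searchable,
        State.CONSTANT: constant,
    }
    # Any flag set to True switches to "only these types" mode.
    selected = {state for state, flag in flags.items() if flag is True}
    if not selected:
        # Otherwise keep every type that was not explicitly set to False.
        selected = {state for state, flag in flags.items() if flag is not False}
    for name, state in named_states.items():
        if state in selected:
            yield name, state


# Hypothetical symbols of a single module.
symbols = {
    "in_features": State.FREE,
    "out_features": State.SEARCHABLE,
    "kernel_size": State.CONSTANT,
}
all_names = [name for name, _ in filter_named_states(symbols)]
only_searchable = [name for name, _ in filter_named_states(symbols, searchable=True)]
no_constant = [name for name, _ in filter_named_states(symbols, constant=False)]
```

With all flags left at None every symbol is yielded; searchable=True narrows the result to searchable symbols only; constant=False yields everything except constant symbols.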
- pop(key)
Remove the given module from the dictionary and return its symbolic representation.
- Parameters:
key (Module) –
- Return type:
Dict[str, Symbol]
- prune()
Prune the map by removing modules with constant-only symbols.
- Return type:
None
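As a rough illustration of what pruning "constant-only" modules means, the sketch below works on a plain dictionary. It is an assumption-laden stand-in: prune_constant_only is a hypothetical helper, modules are represented by strings, and each symbol is reduced to a boolean is_constant flag. How the library treats a module without any symbols is not stated here; this sketch removes it.

```python
from typing import Dict


def prune_constant_only(sym_dicts: Dict[str, Dict[str, bool]]) -> None:
    """Remove, in place, every entry whose symbols are all constant.

    sym_dicts maps a (hypothetical) module key to {symbol name: is_constant}.
    """
    # Collect keys first so the dict is not mutated while iterating over it.
    to_remove = [
        key for key, symbols in sym_dicts.items() if all(symbols.values())
    ]
    for key in to_remove:
        del sym_dicts[key]


sym_dicts = {
    "linear": {"in_features": False, "out_features": False},
    "norm": {"num_features": True},
    "conv": {"kernel_size": True, "out_channels": False},
}
prune_constant_only(sym_dicts)
remaining = sorted(sym_dicts)
```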
- classmethod register(nn_cls, is_explicit_leaf=True)
Use this to register a function that defines the symbols for a given nn module.
- Parameters:
nn_cls (Type[Module] | List[Type[Module]]) – The nn module class for which the function is registered.
is_explicit_leaf (bool) – Whether the module is an explicit leaf, i.e., whether it should be treated as leaf during tracing.
- Returns:
A decorator that registers the given function for the given nn module class.
- Return type:
Callable[[Callable[[Module], SymInfo]], Callable[[Module], SymInfo]]
An example for registering the symbolic information of a module is shown below:
    @SymMap.register(nn.Linear)
    def get_linear_sym_info(mod: nn.Linear) -> SymInfo:
        in_features = Symbol(cl_type=Symbol.CLType.INCOMING, elastic_dims={-1})
        out_features = Symbol(
            is_searchable=True, cl_type=Symbol.CLType.OUTGOING, elastic_dims={-1}
        )
        return SymInfo(in_features=in_features, out_features=out_features)
- set_symbol(mod, name, symbol)
Set the symbol with the given name on the given module.
- Parameters:
mod (Module) –
name (str) –
symbol (Symbol) –
- Return type:
None
- classmethod unregister(nn_cls)
Unregister a module class that has previously been registered.
Raises a KeyError if the module class is not registered.
- Parameters:
nn_cls (Type[Module]) –
- Return type:
None
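The register and unregister classmethods documented above follow a common class-level registry pattern. The sketch below shows that pattern on its own, without torch: Registry, FakeLinear, FakeConv, and describe are hypothetical names, the is_explicit_leaf flag is left out, and none of this should be read as the actual SymMap implementation.

```python
from typing import Callable, Dict, List, Union


class Registry:
    """Hypothetical stand-in for a class-level registry of per-class functions."""

    _registry: Dict[type, Callable] = {}

    @classmethod
    def register(cls, nn_cls: Union[type, List[type]]) -> Callable[[Callable], Callable]:
        # Accept either a single class or a list of classes.
        classes = nn_cls if isinstance(nn_cls, list) else [nn_cls]

        def decorator(func: Callable) -> Callable:
            for klass in classes:
                cls._registry[klass] = func
            # Return the function unchanged so it stays directly callable.
            return func

        return decorator

    @classmethod
    def unregister(cls, nn_cls: type) -> None:
        # Deleting a missing key raises KeyError, matching the documented behavior.
        del cls._registry[nn_cls]


class FakeLinear:
    """Placeholder for a module class such as nn.Linear."""


class FakeConv:
    """Placeholder for a second module class."""


@Registry.register([FakeLinear, FakeConv])
def describe(mod: object) -> dict:
    return {"module": type(mod).__name__}


Registry.unregister(FakeConv)
try:
    Registry.unregister(FakeConv)
    second_unregister_failed = False
except KeyError:
    second_unregister_failed = True
```

Returning the decorated function unchanged is a design choice of this sketch; it keeps the function usable outside the registry, for example in tests.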
- class Symbol
Bases:
object
A symbolic parameter (Symbol) of a SymModule.
An example of a Symbol could be the kernel_size of a conv.
Note that a symbol can have the following states (mutually exclusive):
  - free: the symbol is not bound to any value
  - searchable: the symbol is free and can be searched over
  - constant: the symbol’s value cannot be changed
  - dynamic: the symbol’s value is determined by its parent symbol
In addition, a symbol can exhibit properties related to its cross-layer significance:
  - incoming: the symbol depends on the input tensor to the module
  - outgoing: the module’s output tensor depends on the symbol
  - none: the symbol is not cross-layer significant (only affects the internals of the module)
Based on these basic properties, we define a few useful compound properties:
  - is_cross_layer: the symbol is incoming or outgoing
  - is_dangling: the symbol is free and cross-layer
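The relationship between states, cross-layer types, and the compound properties can be made concrete with a small model. This is a sketch under assumptions: MiniSymbol and the State enum are hypothetical, and only the CLType member names (NONE, INCOMING, OUTGOING) are taken from this page.

```python
from enum import Enum, auto


class State(Enum):
    """Hypothetical enum for the four mutually exclusive states."""

    FREE = auto()
    SEARCHABLE = auto()
    CONSTANT = auto()
    DYNAMIC = auto()


class CLType(Enum):
    """Cross-layer significance of a symbol."""

    NONE = auto()
    INCOMING = auto()
    OUTGOING = auto()


class MiniSymbol:
    """Hypothetical, simplified symbol holding one state and one cross-layer type."""

    def __init__(self, state: State = State.FREE, cl_type: CLType = CLType.NONE) -> None:
        self.state = state
        self.cl_type = cl_type

    @property
    def is_cross_layer(self) -> bool:
        # Compound property: incoming or outgoing.
        return self.cl_type in (CLType.INCOMING, CLType.OUTGOING)

    @property
    def is_dangling(self) -> bool:
        # Compound property: free and cross-layer.
        return self.state is State.FREE and self.is_cross_layer


in_features = MiniSymbol(State.FREE, CLType.INCOMING)
out_features = MiniSymbol(State.SEARCHABLE, CLType.OUTGOING)
kernel_size = MiniSymbol(State.FREE, CLType.NONE)
```

Here in_features is dangling, out_features is cross-layer but not dangling because it is searchable rather than free, and kernel_size is neither because it only affects the internals of the module.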
- __init__(is_searchable=False, is_sortable=True, cl_type=CLType.NONE, elastic_dims=None)
Initializes Symbol with tracing-relevant information.
- Parameters:
is_searchable (bool) –
is_sortable (bool) –
cl_type (CLType) –
elastic_dims (Set[int] | None) –
- disable(_memo=None)
Disable symbol and mark it as constant together with its whole dependency tree via DFS.
After this call, is_constant == True.
- Parameters:
_memo (Set[Symbol] | None) –
- Return type:
None
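The role of the _memo argument in disable can be illustrated with a generic depth-first traversal. The sketch below is not the library's code: Node, its linked list, and the link helper are hypothetical, and the way real symbols reference their dependencies is an assumption made only for illustration.

```python
from typing import List, Optional, Set


class Node:
    """Hypothetical symbol-like node with links to dependent nodes."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.is_constant = False
        self.linked: List["Node"] = []

    def link(self, other: "Node") -> None:
        # Record the dependency in both directions.
        self.linked.append(other)
        other.linked.append(self)

    def disable(self, _memo: Optional[Set["Node"]] = None) -> None:
        # The memo set records visited nodes so that cycles terminate.
        memo = set() if _memo is None else _memo
        if self in memo:
            return
        memo.add(self)
        self.is_constant = True
        # Depth-first: recurse into each linked node before returning.
        for node in self.linked:
            node.disable(memo)


a, b, c, d = Node("a"), Node("b"), Node("c"), Node("d")
a.link(b)
b.link(c)
c.link(a)  # cycle a -> b -> c -> a
a.disable()
constant_names = [n.name for n in (a, b, c, d) if n.is_constant]
```

Callers invoke disable() without arguments; the memo is created on the first call and threaded through the recursion, which is why it appears as a private, optional parameter.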
- property elastic_dims: Set[int]
Returns the set of tensor dimensions that refer to this symbol.
Note that this refers to the dimension of the tensor incoming or outgoing from the layer, not the parameters of the module. E.g. for a Conv2d layer, this refers to the “C” in “NCHW” of the incoming/outgoing tensor.
This must be defined for incoming/outgoing symbols, and must be empty for all others.
Also note that this is a set of dimensions, although only one actual dimension per Symbol can be elastic. Using a set instead of a single index makes it possible to describe the same dimension in different indexing notations (e.g. {1,-3} for Conv2d, where both 1 and -3 refer to the “C” dimension in “NCHW”).
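To see why {1,-3} names a single dimension of an NCHW tensor, the indices can be normalized against the tensor rank. The helper below, normalize_dims, is a hypothetical illustration and not part of the documented API.

```python
from typing import Set


def normalize_dims(elastic_dims: Set[int], ndim: int) -> Set[int]:
    """Map possibly negative dimension indices to non-negative ones."""
    normalized = set()
    for dim in elastic_dims:
        if not -ndim <= dim < ndim:
            raise IndexError(f"dimension {dim} is out of range for rank {ndim}")
        # Python's modulo maps a negative index to its non-negative equivalent.
        normalized.add(dim % ndim)
    return normalized


# Conv2d with an NCHW tensor (rank 4): 1 and -3 both refer to "C".
conv_dims = normalize_dims({1, -3}, ndim=4)
# A {-1} entry, as in the nn.Linear example, follows the last dimension at any rank.
linear_2d = normalize_dims({-1}, ndim=2)
linear_3d = normalize_dims({-1}, ndim=3)
```

After normalization the Conv2d set collapses to one index, which matches the statement that only one actual dimension per Symbol is elastic.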
- property is_constant: bool
Return indicator whether symbol is constant.
- property is_cross_layer: bool
Return indicator whether symbol is cross-layer.
- property is_dangling: bool
Return indicator whether symbol is dangling (cross-layer and free).
- property is_dynamic: bool
Return indicator whether symbol is dynamic.
- property is_free: bool
Return indicator whether symbol is free.
- property is_incoming: bool
Return indicator whether symbol is cross-layer incoming.
- property is_outgoing: bool
Return indicator whether symbol is cross-layer outgoing.
- property is_searchable: bool
Return indicator whether symbol is searchable.
- property is_sortable: bool
Return indicator whether symbols in dependency tree are sortable.