symbols
Utilities to describe symbols found in common torch modules.
Classes
| SymInfo | A simple class to hold relevant information about the symbolic nature of a given module. |
| SymMap | A class to hold the symbolic representations of a model. |
| Symbol | A symbolic parameter (Symbol) of a SymModule. |
- class SymInfo
- Bases: object
- A simple class to hold relevant information about the symbolic nature of a given module.
- __init__(is_shape_preserving=False, **kwargs)
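- Initialize the instance with the given symbolic information.
- Parameters:
- is_shape_preserving (bool) – Whether the module is shape-preserving.
- kwargs (Symbol) – The module’s symbols, passed as keyword arguments keyed by symbol name.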
 
 
 - property is_shape_preserving: bool
- Return indicator whether module is shape-preserving. 
 
- class SymMap
- Bases: object
- A class to hold the symbolic representations of a model.
- __init__(model)
- Initialize with the desired model.
- Return type:
- None 
 
 - add_sym_info(key, sym_info)
- Manually add the sym_info of one of the model’s modules.
- Parameters:
- key (Module) 
- sym_info (SymInfo) 
 
- Return type:
- None 
 
 - get_symbol(mod, name)
- Get the symbol with the given name from the given module.
- Parameters:
- mod (Module) 
- name (str) 
 
- Return type:
- Symbol

 - is_shape_preserving(key)
- Return whether the symbolic module is shape-preserving.
- Parameters:
- key (Module) 
- Return type:
- bool 
 
 - items()
- Return an iterator over the (module, symbol dictionary) pairs in the map.
- Return type:
- Generator[tuple[Module, dict[str, Symbol]], None, None] 
 
 - named_modules()
- Yield the name (from self._mod_to_name) and the associated module.
- Return type:
- Generator[tuple[str, Module], None, None] 
 
 - named_sym_dicts()
- Yield the name (from self._mod_to_name) and the associated symbolic module.
- Return type:
- Generator[tuple[str, dict[str, Symbol]], None, None] 
 
 - named_symbols(key=None, free=None, dynamic=None, searchable=None, constant=None)
- Yield the name and symbol of symbols in either all symbolic modules or a specific one.
- Parameters:
- key (Module | None) – The module to get symbols from. If not provided, recurse through all modules.
- free (bool | None) – Whether to include free symbols. 
- dynamic (bool | None) – Whether to include dynamic symbols. 
- searchable (bool | None) – Whether to include searchable symbols. 
- constant (bool | None) – Whether to include constant symbols. 
 
- Yields:
- (name, Symbol) – Tuple containing the name and symbol. 
- Return type:
- Generator[tuple[str, Symbol], None, None] 
- By default, all symbols (free, dynamic, searchable, or constant) are iterated over. Set the arguments accordingly to iterate over only some of them. When any of free, dynamic, searchable, or constant is set to True, only symbols of that type are iterated over. When any of them is set to False, symbols of that type are skipped.
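The filtering rules above can be illustrated with a small standalone sketch. It does not use the real SymMap: ToySymbol and the matches helper are hypothetical stand-ins, and the sketch assumes that when several flags are True a symbol is kept if it matches any of them (the exact combination rule is not spelled out above).

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToySymbol:
    """Hypothetical stand-in holding the four state flags of a Symbol."""
    free: bool = False
    dynamic: bool = False
    searchable: bool = False
    constant: bool = False


def matches(sym: ToySymbol, free: Optional[bool] = None, dynamic: Optional[bool] = None,
            searchable: Optional[bool] = None, constant: Optional[bool] = None) -> bool:
    """Return True if `sym` passes the filters (assumed semantics, see lead-in)."""
    filters = {"free": free, "dynamic": dynamic, "searchable": searchable, "constant": constant}
    # Any flag set to False excludes symbols of that type.
    if any(getattr(sym, name) for name, flag in filters.items() if flag is False):
        return False
    # If at least one flag is True, keep only symbols of (any of) those types.
    wanted = [name for name, flag in filters.items() if flag is True]
    if wanted:
        return any(getattr(sym, name) for name in wanted)
    # All flags left as None: keep everything.
    return True


symbols = {
    "hidden": ToySymbol(free=True, searchable=True),  # searchable symbols are also free
    "kernel": ToySymbol(constant=True),
    "out": ToySymbol(dynamic=True),
}

print([n for n, s in symbols.items() if matches(s)])                  # ['hidden', 'kernel', 'out']
print([n for n, s in symbols.items() if matches(s, searchable=True)])  # ['hidden']
print([n for n, s in symbols.items() if matches(s, constant=False)])   # ['hidden', 'out']
```

Treating None as "no constraint" is what lets the default call iterate over every symbol.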
 - pop(key)
- Remove the given module from the dictionary and return its symbolic representation.
- Parameters:
- key (Module) 
- Return type:
- dict[str, Symbol] 
 
 - prune()
- Prune the map by removing modules with constant-only symbols.
- Return type:
- None 
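To make the effect of prune() concrete, here is a plain-dictionary sketch that assumes a module is dropped only when every one of its symbols is constant. The module names and the dictionary layout are hypothetical; the real SymMap keys on Module objects and stores Symbol instances rather than booleans.

```python
def prune(sym_map: dict[str, dict[str, bool]]) -> None:
    """Remove modules whose symbols are all constant (values are is_constant flags)."""
    # Collect keys first so the dict is not mutated while iterating over it.
    constant_only = [mod for mod, syms in sym_map.items() if all(syms.values())]
    for mod in constant_only:
        del sym_map[mod]


sym_map = {
    "conv1": {"in_channels": True, "out_channels": False, "kernel_size": True},
    "bn1": {"num_features": True},
    "fc": {"in_features": False, "out_features": False},
}
prune(sym_map)
print(sorted(sym_map))  # ['conv1', 'fc']
```

A module with at least one non-constant symbol survives, since it still has something that can change.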
 
 - classmethod register(nn_cls, is_explicit_leaf=True)
- Use this to register a function that defines the symbols for a given nn module.
- Parameters:
- nn_cls (type[Module] | list[type[Module]]) – The nn module class for which the function is registered. 
- is_explicit_leaf (bool) – Whether the module is an explicit leaf, i.e., whether it should be treated as leaf during tracing. 
 
- Returns:
- A decorator that registers the given function for the given nn module class. 
- Return type:
- Callable[[Callable[[Module], SymInfo]], Callable[[Module], SymInfo]] 
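- An example of registering the symbolic information of a module is shown below:

```python
from torch import nn

# SymMap, SymInfo, and Symbol are the classes documented in this module.


@SymMap.register(nn.Linear)
def get_linear_sym_info(mod: nn.Linear) -> SymInfo:
    # in_features depends on the incoming tensor; its last dimension (-1) is elastic.
    in_features = Symbol(cl_type=Symbol.CLType.INCOMING, elastic_dims={-1})
    # out_features can be searched over and determines the outgoing tensor's last dimension.
    out_features = Symbol(
        is_searchable=True, cl_type=Symbol.CLType.OUTGOING, elastic_dims={-1}
    )
    return SymInfo(in_features=in_features, out_features=out_features)
```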
 - set_symbol(mod, name, symbol)
- Set the symbol with the given name on the given module.
- Parameters:
- mod (Module) 
- name (str) 
- symbol (Symbol) 
 
- Return type:
- None 
 
 - classmethod unregister(nn_cls)
- Unregister a module class that was previously registered.
- Raises a KeyError if the module class is not registered.
- Parameters:
- nn_cls (type[Module]) 
- Return type:
- None 
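The register/unregister pair follows a common class-keyed registry pattern. The sketch below reproduces that pattern in isolation with a hypothetical ToyRegistry, using placeholder classes instead of torch.nn modules and returning plain dictionaries instead of SymInfo; it is not the library's implementation. It assumes that passing a list of classes registers the same function for each class, as the nn_cls type hint suggests.

```python
from typing import Callable, Union


class ToyRegistry:
    """Hypothetical registry mapping a module class to its sym-info function."""

    _registry: dict[type, Callable] = {}

    @classmethod
    def register(cls, nn_cls: Union[type, list[type]]) -> Callable[[Callable], Callable]:
        classes = nn_cls if isinstance(nn_cls, list) else [nn_cls]

        def decorator(fn: Callable) -> Callable:
            for c in classes:
                cls._registry[c] = fn
            return fn  # return the function unchanged so it stays directly callable

        return decorator

    @classmethod
    def unregister(cls, nn_cls: type) -> None:
        # `del` raises KeyError if the class was never registered.
        del cls._registry[nn_cls]


# Placeholder classes standing in for torch.nn modules.
class Linear: ...
class Conv2d: ...


@ToyRegistry.register([Linear, Conv2d])
def get_sym_info(mod) -> dict:
    return {"type": type(mod).__name__}


print(ToyRegistry._registry[Conv2d](Conv2d()))  # {'type': 'Conv2d'}
ToyRegistry.unregister(Linear)
try:
    ToyRegistry.unregister(Linear)
except KeyError:
    print("Linear is no longer registered")
```

Returning the original function from the decorator keeps get_sym_info usable on its own, which is the usual convention for registration decorators.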
 
 
- class Symbol
- Bases: object
- A symbolic parameter (Symbol) of a SymModule.
- An example of a Symbol is the kernel_size of a convolution layer.
- Note that a symbol can be in one of the following states (mutually exclusive):
  - free: the symbol is not bound to any value
  - searchable: the symbol is free and can be searched over
  - constant: the symbol’s value cannot be changed
  - dynamic: the symbol’s value is determined by its parent symbol
- In addition, a symbol can exhibit properties related to its cross-layer significance:
  - incoming: the symbol depends on the input tensor to the module
  - outgoing: the module’s output tensor depends on the symbol
  - none: the symbol is not cross-layer significant (it only affects the internals of the module)
- Based on these basic properties, we define a few useful compound properties:
  - is_cross_layer: the symbol is incoming or outgoing
  - is_dangling: the symbol is free and cross-layer
- __init__(is_searchable=False, is_sortable=True, cl_type=CLType.NONE, elastic_dims=None)
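- Initializes Symbol with tracing-relevant information.
- Parameters:
- is_searchable (bool) – Whether the symbol can be searched over.
- is_sortable (bool) – Whether the symbol is sortable (see is_sortable).
- cl_type (CLType) – The cross-layer type of the symbol (incoming, outgoing, or none).
- elastic_dims (set[int] | None) – The incoming/outgoing tensor dimensions that refer to this symbol (see elastic_dims).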
 
 
 - disable(_memo=None)
- Disable the symbol and mark it, together with its whole dependency tree, as constant via depth-first search (DFS).
- After this call, is_constant == True.
- Parameters:
- _memo (set[Symbol] | None) 
- Return type:
- None 
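The depth-first propagation in disable() can be sketched with a toy symbol that stores its dependent symbols explicitly. ToySym, its dependents list, and its constant flag are hypothetical; how the real Symbol stores its dependency tree is not shown here. The _memo set plays the same role as the documented _memo parameter: it keeps a symbol that is reachable along several paths from being visited twice.

```python
from typing import Optional


class ToySym:
    """Hypothetical symbol with an explicit list of dependent symbols."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.dependents: list["ToySym"] = []
        self.constant = False

    def disable(self, _memo: Optional[set] = None) -> None:
        """Mark this symbol and everything reachable from it as constant (DFS)."""
        if _memo is None:
            _memo = set()
        if self in _memo:
            return  # already visited via another path
        _memo.add(self)
        self.constant = True
        for dep in self.dependents:
            dep.disable(_memo)


root, left, right, leaf, other = (ToySym(n) for n in ("root", "left", "right", "leaf", "other"))
root.dependents = [left, right]
left.dependents = [leaf]
right.dependents = [leaf]  # leaf is shared, so the memo set prevents a second visit
root.disable()
print([s.constant for s in (root, left, right, leaf, other)])  # [True, True, True, True, False]
```

Symbols outside the tree (other) are untouched, matching the idea that only the disabled symbol's own dependency tree becomes constant.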
 
 - property elastic_dims: set[int]
- Returns the set of tensor dimensions that refer to this symbol.
- Note that this refers to the dimensions of the tensor incoming to or outgoing from the layer, not to the parameters of the module. For example, for a Conv2d layer, this refers to the “C” in “NCHW” of the incoming/outgoing tensor.
- This must be defined for incoming/outgoing symbols and must be empty for all others.
- Also note that this is a set of dimensions, although only one actual dimension per Symbol can be elastic. Using a set instead of a single integer makes it possible to describe the same dimension in different indexing notations (e.g., {1, -3} for Conv2d, where both 1 and -3 refer to the “C” dimension in “NCHW”).
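A quick check of why a set such as {1, -3} still describes a single elastic dimension: for a 4-D NCHW tensor, positive and negative indices map to the same axis once normalized modulo the tensor rank. The normalize_dims helper is hypothetical and works on a plain shape tuple rather than a real tensor.

```python
def normalize_dims(dims: set[int], ndim: int) -> set[int]:
    """Map possibly negative dimension indices to their non-negative form."""
    for d in dims:
        if not -ndim <= d < ndim:
            raise IndexError(f"dimension {d} out of range for a {ndim}-D tensor")
    return {d % ndim for d in dims}


shape_nchw = (8, 64, 32, 32)  # N, C, H, W
elastic_dims = {1, -3}
print(normalize_dims(elastic_dims, len(shape_nchw)))  # {1}
print(shape_nchw[1] == shape_nchw[-3])  # True: both index the C dimension
```

Because both entries collapse to one axis, the set is a notational convenience rather than a way to mark several elastic dimensions.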
 - property is_constant: bool
- Return indicator whether symbol is constant. 
 - property is_cross_layer: bool
- Return indicator whether symbol is cross-layer. 
 - property is_dangling: bool
- Return indicator whether symbol is dangling (cross-layer and free). 
 - property is_dynamic: bool
- Return indicator whether symbol is dynamic. 
 - property is_free: bool
- Return indicator whether symbol is free. 
 - property is_incoming: bool
- Return indicator whether symbol is cross-layer incoming. 
 - property is_outgoing: bool
- Return indicator whether symbol is cross-layer outgoing. 
 - property is_searchable: bool
- Return indicator whether symbol is searchable. 
 - property is_sortable: bool
- Return indicator whether symbols in dependency tree are sortable.
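The compound properties follow directly from the basic ones. The toy class below encodes only the documented relationships: is_cross_layer means incoming or outgoing, and is_dangling means free and cross-layer. ToyCLType and ToySymbolState are hypothetical stand-ins for Symbol.CLType and Symbol, and the choice to default a symbol to free is an assumption made for the sketch.

```python
from dataclasses import dataclass
from enum import Enum


class ToyCLType(Enum):
    NONE = "none"
    INCOMING = "incoming"
    OUTGOING = "outgoing"


@dataclass
class ToySymbolState:
    cl_type: ToyCLType = ToyCLType.NONE
    is_free: bool = True  # not bound to a value (searchable symbols are free too)

    @property
    def is_incoming(self) -> bool:
        return self.cl_type is ToyCLType.INCOMING

    @property
    def is_outgoing(self) -> bool:
        return self.cl_type is ToyCLType.OUTGOING

    @property
    def is_cross_layer(self) -> bool:
        return self.is_incoming or self.is_outgoing

    @property
    def is_dangling(self) -> bool:
        return self.is_free and self.is_cross_layer


out_features = ToySymbolState(cl_type=ToyCLType.OUTGOING)
kernel_size = ToySymbolState()  # internal to the module
bound_in = ToySymbolState(cl_type=ToyCLType.INCOMING, is_free=False)
print(out_features.is_dangling, kernel_size.is_dangling, bound_in.is_dangling)
# True False False
```

Deriving the compound flags as properties, rather than storing them, keeps them consistent with the basic state whenever that state changes.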