symbols

Utilities to describe symbols found in common torch modules.

Classes

Symbol

A symbolic parameter (Symbol) of a SymModule.

SymMap

A class to hold the symbolic representations of a model.

SymInfo

A simple class to hold relevant information about the symbolic nature of a given module.

class SymInfo

Bases: object

A simple class to hold relevant information about the symbolic nature of a given module.

SymDict

alias of Dict[str, Symbol]

__init__(is_shape_preserving=False, **kwargs)

Initialize the instance with the given symbolic information.

Parameters:
  • is_shape_preserving (bool) – Whether the module is shape-preserving.

  • kwargs (Symbol) – The symbols of the module, keyed by name.

property is_shape_preserving: bool

Return whether the module is shape-preserving.
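The sketch below shows a minimal, hypothetical stand-in for SymInfo. It assumes the keyword-argument symbols are simply stored in a name-keyed dictionary (matching the SymDict alias) next to the shape-preserving flag. `ToySymbol` and `ToySymInfo` are placeholder names, not the library's classes.

```python
from typing import Dict


class ToySymbol:
    """Hypothetical placeholder for Symbol; carries no state in this sketch."""


class ToySymInfo:
    """Sketch of a SymInfo-like container (assumed design, not the real class)."""

    def __init__(self, is_shape_preserving: bool = False, **kwargs: ToySymbol) -> None:
        # Keyword arguments become the module's symbols, keyed by name (cf. SymDict).
        self.symbols: Dict[str, ToySymbol] = dict(kwargs)
        self._is_shape_preserving = is_shape_preserving

    @property
    def is_shape_preserving(self) -> bool:
        return self._is_shape_preserving


info = ToySymInfo(in_features=ToySymbol(), out_features=ToySymbol())
print(sorted(info.symbols))      # ['in_features', 'out_features']
print(info.is_shape_preserving)  # False
```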

class SymMap

Bases: object

A class to hold the symbolic representations of a model.

SymRegisterFunc

alias of Callable[[Module], SymInfo]

__init__(model)

Initialize with the desired module.

Return type:

None

add_sym_info(key, sym_info)

Manually add the sym_info for one of the model's modules.

Parameters:
  • key (Module) –

  • sym_info (SymInfo) –

Return type:

None

get_symbol(mod, name)

Get the symbol with the given name from the given module.

Parameters:
  • mod (Module) –

  • name (str) –

Return type:

Symbol

is_shape_preserving(key)

Return whether the symbolic module is shape-preserving.

Parameters:

key (Module) –

Return type:

bool

items()

Yield (module, symbol dictionary) pairs for all modules in the map.

Return type:

Generator[Tuple[Module, Dict[str, Symbol]], None, None]

named_modules()

Yield the name (from self._mod_to_name) and the associated module.

Return type:

Generator[Tuple[str, Module], None, None]

named_sym_dicts()

Yield the name (from self._mod_to_name) and the associated symbol dictionary.

Return type:

Generator[Tuple[str, Dict[str, Symbol]], None, None]
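To illustrate how the name-based iterators relate to the underlying module map, here is a hypothetical sketch. It assumes the map stores module -> {symbol name: symbol} entries and resolves names through a `_mod_to_name` dictionary, as the docstrings above suggest. `ToyModule` replaces `torch.nn.Module` and plain strings replace Symbol objects so the example runs without torch; this is not the library's implementation.

```python
from typing import Dict, Generator, Tuple


class ToyModule:
    """Hypothetical stand-in for torch.nn.Module (no torch dependency)."""


class ToySymMap:
    """Sketch of SymMap-style bookkeeping; an assumed design, not the real class."""

    def __init__(self, mod_to_name: Dict[ToyModule, str]) -> None:
        # Assumption: module -> qualified name, e.g. as produced by model.named_modules().
        self._mod_to_name = dict(mod_to_name)
        # Module -> {symbol name: symbol}; strings stand in for Symbol objects.
        self._map: Dict[ToyModule, Dict[str, str]] = {}

    def add_sym_info(self, key: ToyModule, symbols: Dict[str, str]) -> None:
        self._map[key] = dict(symbols)

    def get_symbol(self, mod: ToyModule, name: str) -> str:
        return self._map[mod][name]

    def items(self) -> Generator[Tuple[ToyModule, Dict[str, str]], None, None]:
        yield from self._map.items()

    def named_modules(self) -> Generator[Tuple[str, ToyModule], None, None]:
        for mod in self._map:
            yield self._mod_to_name[mod], mod

    def named_sym_dicts(self) -> Generator[Tuple[str, Dict[str, str]], None, None]:
        for mod, sym_dict in self._map.items():
            yield self._mod_to_name[mod], sym_dict


fc1, fc2 = ToyModule(), ToyModule()
sym_map = ToySymMap({fc1: "fc1", fc2: "fc2"})
sym_map.add_sym_info(fc1, {"out_features": "sym_a"})
sym_map.add_sym_info(fc2, {"in_features": "sym_b"})
print([name for name, _ in sym_map.named_modules()])  # ['fc1', 'fc2']
print(sym_map.get_symbol(fc2, "in_features"))         # sym_b
```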

named_symbols(key=None, free=None, dynamic=None, searchable=None, constant=None)

Yield the name and symbol of symbols in either all symbolic modules or a specific one.

Parameters:
  • key (Module | None) – The module to get symbols from. If not provided, recurse through all modules.

  • free (bool | None) – Whether to include free symbols.

  • dynamic (bool | None) – Whether to include dynamic symbols.

  • searchable (bool | None) – Whether to include searchable symbols.

  • constant (bool | None) – Whether to include constant symbols.

Yields:

(name, Symbol) – Tuple containing the name and symbol.

Return type:

Generator[Tuple[str, Symbol], None, None]

By default, all symbols are yielded, whether free, dynamic, searchable, or constant. Set the arguments accordingly to iterate over a subset only. When any of free, dynamic, searchable, or constant is set to True, only symbols of that type are yielded. When any of them is set to False, symbols of that type are skipped.
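The filtering rules above can be sketched as follows. The sketch assumes that every flag that is not None must equal the matching `is_*` property of the symbol, and that these checks are combined with AND; how the real method combines several flags is not stated here. `ToySymbol` and `filter_symbols` are hypothetical names.

```python
from dataclasses import dataclass
from typing import Iterable, Iterator, Optional, Tuple


@dataclass
class ToySymbol:
    """Hypothetical symbol exposing only the four state flags used for filtering."""

    is_free: bool = False
    is_dynamic: bool = False
    is_searchable: bool = False
    is_constant: bool = False


def filter_symbols(
    named: Iterable[Tuple[str, ToySymbol]],
    free: Optional[bool] = None,
    dynamic: Optional[bool] = None,
    searchable: Optional[bool] = None,
    constant: Optional[bool] = None,
) -> Iterator[Tuple[str, ToySymbol]]:
    """Yield (name, symbol) pairs that pass every flag that is not None.

    Assumption: a non-None flag must equal the matching ``is_*`` property and
    all such checks are AND-combined; ``None`` means "do not filter on this".
    """
    checks = {
        "is_free": free,
        "is_dynamic": dynamic,
        "is_searchable": searchable,
        "is_constant": constant,
    }
    for name, sym in named:
        if all(want is None or getattr(sym, attr) == want for attr, want in checks.items()):
            yield name, sym


symbols = [
    ("kernel_size", ToySymbol(is_constant=True)),
    ("out_features", ToySymbol(is_free=True, is_searchable=True)),
    ("in_features", ToySymbol(is_dynamic=True)),
]
print([n for n, _ in filter_symbols(symbols)])                   # ['kernel_size', 'out_features', 'in_features']
print([n for n, _ in filter_symbols(symbols, searchable=True)])  # ['out_features']
print([n for n, _ in filter_symbols(symbols, constant=False)])   # ['out_features', 'in_features']
```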

pop(key)

Remove the given module from the dictionary and return its symbolic representation.

Parameters:

key (Module) –

Return type:

Dict[str, Symbol]

prune()

Prune the map by removing modules with constant-only symbols.

Return type:

None
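A hypothetical sketch of the pruning idea: drop every module whose symbols are all constant. Here module names stand in for modules and each symbol is reduced to its `is_constant` flag. Keeping modules that have no symbols at all is a choice made for this sketch; whether prune() does the same is an assumption.

```python
from typing import Dict, List


def prune_constant_only(sym_map: Dict[str, Dict[str, bool]]) -> List[str]:
    """Remove entries whose symbols are all constant; return the removed keys.

    Hypothetical sketch: keys are module names and each value maps a symbol
    name to its ``is_constant`` flag. Entries with no symbols are kept.
    """
    removed = [key for key, symbols in sym_map.items() if symbols and all(symbols.values())]
    for key in removed:
        del sym_map[key]
    return removed


sym_map = {
    "conv1": {"kernel_size": True, "out_channels": False},
    "bn1": {"num_features": True},
    "fc": {"in_features": True, "out_features": True},
}
print(prune_constant_only(sym_map))  # ['bn1', 'fc']
print(list(sym_map))                 # ['conv1']
```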

classmethod register(nn_cls, is_explicit_leaf=True)

Register a function that defines the symbols for the given nn module class.

Parameters:
  • nn_cls (Type[Module] | List[Type[Module]]) – The nn module class for which the function is registered.

  • is_explicit_leaf (bool) – Whether the module is an explicit leaf, i.e., whether it should be treated as leaf during tracing.

Returns:

A decorator that registers the given function for the given nn module class.

Return type:

Callable[[Callable[[Module], SymInfo]], Callable[[Module], SymInfo]]

An example for registering the symbolic information of a module is shown below:

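```python
import torch.nn as nn

# Assumed import path for this `symbols` module; adjust to your installation.
from modelopt.torch.trace.symbols import SymInfo, SymMap, Symbol


@SymMap.register(nn.Linear)
def get_linear_sym_info(mod: nn.Linear) -> SymInfo:
    # in_features depends on the last dimension of the incoming tensor.
    in_features = Symbol(cl_type=Symbol.CLType.INCOMING, elastic_dims={-1})
    # out_features is searchable and determines the last dimension of the outgoing tensor.
    out_features = Symbol(is_searchable=True, cl_type=Symbol.CLType.OUTGOING, elastic_dims={-1})
    return SymInfo(in_features=in_features, out_features=out_features)
```
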
set_symbol(mod, name, symbol)

Set the symbol with the given name on the given module.

Parameters:
  • mod (Module) –

  • name (str) –

  • symbol (Symbol) –

Return type:

None

classmethod unregister(nn_cls)

Unregister a module class that has previously been registered.

Raises a KeyError if the module class is not registered.

Parameters:

nn_cls (Type[Module]) –

Return type:

None
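The register()/unregister() pair follows a common class-level registry pattern. The sketch below is a hypothetical illustration of that pattern, not SymMap's actual code: `ToySymMap`, `_registry`, and `_explicit_leaves` are invented names, and a plain class replaces `nn.Linear`. It accepts a single class or a list (as documented for nn_cls) and raises KeyError when unregistering an unknown class.

```python
from typing import Callable, Dict, List, Set, Type, Union


class ToySymMap:
    """Hypothetical class-level registry mirroring register()/unregister()."""

    _registry: Dict[type, Callable] = {}
    _explicit_leaves: Set[type] = set()

    @classmethod
    def register(cls, nn_cls: Union[Type, List[Type]], is_explicit_leaf: bool = True) -> Callable:
        # Accept a single class or a list of classes, as documented for nn_cls.
        classes = nn_cls if isinstance(nn_cls, list) else [nn_cls]

        def decorator(func: Callable) -> Callable:
            for c in classes:
                cls._registry[c] = func
                if is_explicit_leaf:
                    cls._explicit_leaves.add(c)
            return func  # return the function unchanged so it stays callable

        return decorator

    @classmethod
    def unregister(cls, nn_cls: Type) -> None:
        # `del` on a missing key raises KeyError, as documented for unregister().
        del cls._registry[nn_cls]
        cls._explicit_leaves.discard(nn_cls)


class Linear:
    """Hypothetical stand-in for nn.Linear."""


@ToySymMap.register(Linear)
def get_linear_sym_info(mod: Linear) -> dict:
    return {"in_features": "incoming", "out_features": "outgoing"}


print(Linear in ToySymMap._registry)  # True
ToySymMap.unregister(Linear)
try:
    ToySymMap.unregister(Linear)
except KeyError:
    print("KeyError: Linear is not registered")
```

Returning the decorated function unchanged keeps it usable as an ordinary function after registration.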

class Symbol

Bases: object

A symbolic parameter (Symbol) of a SymModule.

An example of a Symbol could be the kernel_size of a conv.

Note that a symbol can have the following states (mutually exclusive):
  • free: the symbol is not bound to any value

  • searchable: the symbol is free and can be searched over

  • constant: the symbol’s value cannot be changed

  • dynamic: the symbol’s value is determined by its parent symbol

In addition, a symbol can exhibit properties related to its cross-layer significance:
  • incoming: the symbol depends on the input tensor to the module

  • outgoing: the module’s output tensor depends on the symbol

  • none: the symbol is not cross-layer significant (only affects the internals of the module)

Based on these basic properties, we define a few useful compound properties:
  • is_cross_layer: the symbol is incoming or outgoing

  • is_dangling: the symbol is free and cross-layer
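The two compound properties follow directly from the basic ones. This hypothetical sketch mirrors the documented CLType values and derives is_cross_layer and is_dangling from them. `ToySymbol` is an invented class, and its `is_free` flag is simply passed in, since how the real Symbol derives it is not shown here.

```python
from enum import Enum


class CLType(Enum):
    """Mirrors the documented Symbol.CLType values."""

    NONE = 1
    INCOMING = 2
    OUTGOING = 3


class ToySymbol:
    """Hypothetical sketch of the compound properties; not the real Symbol."""

    def __init__(self, cl_type: CLType = CLType.NONE, is_free: bool = True) -> None:
        self.cl_type = cl_type
        # Assumption: "free" is passed in directly for this sketch.
        self.is_free = is_free

    @property
    def is_cross_layer(self) -> bool:
        # Cross-layer means incoming or outgoing.
        return self.cl_type in (CLType.INCOMING, CLType.OUTGOING)

    @property
    def is_dangling(self) -> bool:
        # Dangling means free and cross-layer.
        return self.is_free and self.is_cross_layer


out_sym = ToySymbol(CLType.OUTGOING)
kernel = ToySymbol(CLType.NONE)
bound_in = ToySymbol(CLType.INCOMING, is_free=False)
print(out_sym.is_dangling, kernel.is_dangling, bound_in.is_dangling)  # True False False
```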

class CLType

Bases: Enum

Cross-layer type for the symbol.

INCOMING = 2
NONE = 1
OUTGOING = 3

__init__(is_searchable=False, is_sortable=True, cl_type=CLType.NONE, elastic_dims=None)

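Initialize the Symbol with tracing-relevant information.

Parameters:
  • is_searchable (bool) – Whether the symbol can be searched over.

  • is_sortable (bool) – Whether the symbols in the symbol's dependency tree are sortable.

  • cl_type (CLType) – The cross-layer type of the symbol.

  • elastic_dims (Set[int] | None) – The tensor dimensions that refer to this symbol; see elastic_dims.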

property cl_type: CLType

Return the cross-layer type of the symbol.

disable(_memo=None)

Disable the symbol and mark it as constant, together with its whole dependency tree, via depth-first search (DFS).

After this call, is_constant == True.

Parameters:

_memo (Set[Symbol] | None) –

Return type:

None
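A depth-first traversal with a memo set, as in disable(), can be sketched like this. The class `ToySymbol` and the helper `add_parent` are hypothetical. The sketch also assumes that the "whole dependency tree" includes both the parent and the dependent children, so disabling any node reaches every linked symbol. The memo set ensures each symbol is visited once.

```python
from typing import List, Optional, Set


class ToySymbol:
    """Hypothetical symbol with a dependency tree; not the real Symbol class."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.parent: Optional["ToySymbol"] = None
        self.dependencies: List["ToySymbol"] = []  # symbols that depend on this one
        self.is_constant = False

    def add_parent(self, parent: "ToySymbol") -> None:
        # Hypothetical helper: make this symbol depend on ``parent``.
        self.parent = parent
        parent.dependencies.append(self)

    def disable(self, _memo: Optional[Set["ToySymbol"]] = None) -> None:
        """Mark this symbol and every symbol linked to it as constant (DFS)."""
        memo = set() if _memo is None else _memo
        if self in memo:
            return  # already visited
        memo.add(self)
        self.is_constant = True
        # Assumption: the dependency tree spans the parent and the children.
        neighbors = list(self.dependencies)
        if self.parent is not None:
            neighbors.append(self.parent)
        for sym in neighbors:
            sym.disable(memo)


root, a, b = ToySymbol("root"), ToySymbol("a"), ToySymbol("b")
a.add_parent(root)
b.add_parent(root)
a.disable()
print([s.is_constant for s in (root, a, b)])  # [True, True, True]
```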

property elastic_dims: Set[int]

Returns the set of tensor dimensions that refer to this symbol.

Note that this refers to the dimension of the tensor incoming to or outgoing from the layer, not to the parameters of the module. For example, for a Conv2d layer, this refers to the “C” in “NCHW” of the incoming/outgoing tensor.

This must be defined for incoming/outgoing symbols, and must be empty for all others.

Also note that this is a set of dimensions, although only one actual dimension per Symbol can be elastic. Using a set instead of a single integer makes it possible to describe the same dimension in different indexing notations (e.g. {1, -3} for Conv2d, where both 1 and -3 refer to the “C” dimension in “NCHW”).
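Why {1, -3} names a single dimension can be checked by normalizing negative indices against the tensor rank, the same way Python sequence indexing does. The helper `normalize_dims` below is hypothetical and only illustrates the arithmetic: for a 4-D NCHW tensor, -3 % 4 == 1.

```python
from typing import Set


def normalize_dims(dims: Set[int], ndim: int) -> Set[int]:
    """Map every index (positive or negative) to its non-negative form."""
    for d in dims:
        if not -ndim <= d < ndim:
            raise IndexError(f"dim {d} out of range for a {ndim}-D tensor")
    return {d % ndim for d in dims}


# For an NCHW tensor (ndim=4), 1 and -3 both name the channel dimension "C".
print(normalize_dims({1, -3}, ndim=4))  # {1}
# {1, -1} would name two different dimensions (C and W), so it is not a valid single elastic dim.
print(len(normalize_dims({1, -1}, ndim=4)) == 1)  # False
```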

property is_constant: bool

Return whether the symbol is constant.

property is_cross_layer: bool

Return whether the symbol is cross-layer.

property is_dangling: bool

Return whether the symbol is dangling (cross-layer and free).

property is_dynamic: bool

Return whether the symbol is dynamic.

property is_free: bool

Return whether the symbol is free.

property is_incoming: bool

Return whether the symbol is cross-layer incoming.

property is_outgoing: bool

Return whether the symbol is cross-layer outgoing.

property is_searchable: bool

Return whether the symbol is searchable.

property is_sortable: bool

Return whether the symbols in the dependency tree are sortable.

Register a parent symbol, i.e., make this symbol dependent on the parent.

Parameters:

sp_parent (Symbol) –

Return type:

None

property parent: Symbol | None

Return the parent symbol.