concat

Module for supporting tracing of concat operations.

Note that this implementation assumes that the concat operation/symbol is the _only_ searchable symbol within dependent symbols. This enables us to simplify partial dependencies that would otherwise arise from having multiple concats linked together.

There is one small exception to this: whenever one concat depends on another concat, we can disable the independent concat, which simplifies the representation.

However, truly linked concats cannot be handled, e.g., torch.cat([x1, x2], dim=1) + torch.cat([y1, y2], dim=1), since there is no way to disable one without also disabling the other.
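To see why the example above is problematic, the following sketch models each tensor only by its channel count along dim=1. It is a hypothetical illustration, not code from this module; the helper names concat_channels and add_is_valid are made up for the example. Because the two concat outputs are added elementwise, pruning channels of x1 alone breaks the addition, so the same positions would also have to be removed from y1, linking inputs across two different concats.

```python
from typing import List

# Hypothetical toy model: each tensor is represented only by its channel count
# along dim=1, and a concat output by the list of its inputs' channel counts.
# x = torch.cat([x1, x2], dim=1); y = torch.cat([y1, y2], dim=1); out = x + y


def concat_channels(segments: List[int]) -> int:
    """Channel count of the concatenated output along dim=1."""
    return sum(segments)


def add_is_valid(x_segments: List[int], y_segments: List[int]) -> bool:
    """Elementwise x + y requires both concat outputs to have equal channel counts."""
    return concat_channels(x_segments) == concat_channels(y_segments)


x_segments = [4, 4]  # channels of x1, x2
y_segments = [4, 4]  # channels of y1, y2
assert add_is_valid(x_segments, y_segments)

# Pruning two channels of x1 on its own breaks the addition ...
pruned_x = [2, 4]
assert not add_is_valid(pruned_x, y_segments)

# ... so the same channel positions must also be removed from y, i.e. from y1.
# x1 and y1 would then be linked across two different concats, which this
# module cannot handle.
pruned_y = [2, 4]
assert add_is_valid(pruned_x, pruned_y)
```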

Classes

ConcatSymbol

Symbol storing an ordered list of linked symbolic inputs to concat.

ConcatNodeProcessor

Node for handling concat-specific tracing logic.

class ConcatNodeProcessor

Bases: NodeProcessor

Node for handling concat-specific tracing logic.

__init__(*args, **kwargs)

Initialize the processor.

Return type:

None

is_special_node(node, target)

Return whether node is a concat node.

Parameters:
  • node (Node) –

  • target (Module | Callable) –

Return type:

bool

post_process()

Revert to the original symbols.

Return type:

None

process(node, id, input_nodes)

Make all inputs for the concat searchable.

Parameters:
  • node (Node) –

  • id (int) –

  • input_nodes (List[Node]) –

Return type:

None

reset()

Reset state.

Return type:

None
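The processor's methods suggest a lifecycle: detect concat nodes, make their inputs searchable during tracing, revert to the original symbols afterwards, and clear state on reset. The sketch below is a hypothetical, heavily simplified model of that lifecycle; ToySymbol and ToyConcatProcessor are invented names, a symbol's state is reduced to a single searchable flag, and recognizing concat targets by their function name is an assumption, not how this module necessarily does it.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple


@dataclass
class ToySymbol:
    """Hypothetical stand-in for a traced symbol; only a searchable flag is modeled."""
    name: str
    is_searchable: bool = False


class ToyConcatProcessor:
    """Hypothetical sketch of the is_special_node/process/post_process/reset lifecycle."""

    # Assumption: concat nodes are recognized by the name of their target callable.
    CONCAT_NAMES = {"cat", "concat", "concatenate"}

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        # node id -> list of (symbol, original searchable flag)
        self._saved: Dict[int, List[Tuple[ToySymbol, bool]]] = {}

    def is_special_node(self, target: Callable) -> bool:
        return getattr(target, "__name__", "") in self.CONCAT_NAMES

    def process(self, node_id: int, input_syms: List[ToySymbol]) -> None:
        # Remember the original state, then make every concat input searchable.
        self._saved[node_id] = [(s, s.is_searchable) for s in input_syms]
        for s in input_syms:
            s.is_searchable = True

    def post_process(self) -> None:
        # Revert every processed input to its original state.
        for saved in self._saved.values():
            for sym, flag in saved:
                sym.is_searchable = flag


def cat(tensors, dim=0):  # placeholder target named like a concat op
    return tensors


proc = ToyConcatProcessor()
a, b = ToySymbol("a"), ToySymbol("b", is_searchable=True)
assert proc.is_special_node(cat) and not proc.is_special_node(len)
proc.process(0, [a, b])
assert a.is_searchable and b.is_searchable
proc.post_process()
assert not a.is_searchable and b.is_searchable
```

Saving the original state per node in process() is what lets post_process() undo the changes without re-tracing.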

class ConcatSymbol

Bases: Symbol

Symbol storing an ordered list of linked symbolic inputs to concat.

class Input

Bases: Symbol

Special Symbol to represent an input to a ConcatSymbol.

This symbol is an augmented version of the regular Symbol to handle the interaction with the concat operation and is used to monkey-patch the original symbol.

__init__(*args, **kwargs)

Constructor.

property concat_sym: ConcatSymbol

Return the concat symbol.

static convert(orig_sym, cat_dim)

Modify and convert the symbol in place to a valid ConcatSymbol.Input.

Parameters:
  • orig_sym (Symbol) –

  • cat_dim (int) –

Return type:

Input | ConcatSymbol
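In-place conversion ("monkey-patching") of an existing object can be done in plain Python by reassigning its __class__, so every existing reference to the object sees the augmented behavior. The sketch below illustrates that general technique under the assumption that it resembles what convert does; Symbol, ConcatInput, attach, and convert_in_place here are hypothetical stand-ins, not this module's classes.

```python
class Symbol:
    """Hypothetical minimal stand-in for a traced symbol."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.is_searchable = False


class ConcatInput(Symbol):
    """Hypothetical augmented symbol used as an input to a concat."""

    def attach(self, cat_dim: int) -> None:
        # Extra state that only concat inputs carry.
        self.cat_dim = cat_dim
        self.is_searchable = True


def convert_in_place(sym: Symbol, cat_dim: int) -> ConcatInput:
    # Reassigning __class__ changes the type of the existing object, so every
    # other reference to it (e.g. one held by a graph node) sees the new behavior.
    sym.__class__ = ConcatInput
    sym.attach(cat_dim)
    return sym


orig = Symbol("conv1.out_channels")
alias = orig  # a second reference held elsewhere
converted = convert_in_place(orig, cat_dim=1)
assert converted is alias and isinstance(alias, ConcatInput) and alias.cat_dim == 1
```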

create_linked_copy()

Get a linked deepcopy of self that is monkey-patched to the original symbol class.

Return type:

Symbol
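A linked copy that is patched to the original class can be sketched with copy.deepcopy followed by a __class__ reassignment on the copy. This is a hypothetical illustration: the links list is an invented way to represent "linked", and the actual linking mechanism of this module may differ.

```python
import copy
from typing import List


class Symbol:
    """Hypothetical minimal stand-in for a traced symbol."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.links: List["Symbol"] = []  # hypothetical record of linked symbols


class ConcatInput(Symbol):
    """Hypothetical subclass that an original Symbol was monkey-patched into."""


def create_linked_copy(sym: Symbol, orig_cls: type) -> Symbol:
    dup = copy.deepcopy(sym)  # independent copy of all state
    dup.__class__ = orig_cls  # patch the copy to the original symbol class
    dup.links.append(sym)  # keep the copy tied to the source symbol
    return dup


src = Symbol("x")
src.__class__ = ConcatInput  # src has been monkey-patched earlier
dup = create_linked_copy(src, Symbol)
assert type(dup) is Symbol and type(src) is ConcatInput
assert dup.links[-1] is src and src.links == []
```

The source keeps its augmented class, while the copy behaves like a regular symbol.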

Link self to the other symbol by simply disabling both.

ConcatSymbol.Input can never be linked to another symbol, but other symbols can be linked to it without disabling it!

Parameters:

other (Symbol) –

Return type:

None
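The asymmetry described above (a ConcatSymbol.Input never follows another symbol, but other symbols may follow it) can be sketched as follows. All names are hypothetical; in particular, link is a made-up method name for this illustration, and a single parent attribute stands in for whatever dependency structure the real Symbol uses.

```python
from typing import Optional


class Symbol:
    """Hypothetical minimal stand-in for a traced symbol."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.is_disabled = False
        self.parent: Optional["Symbol"] = None  # symbol this one follows

    def disable(self) -> None:
        self.is_disabled = True

    def link(self, other: "Symbol") -> None:
        # Hypothetical regular linking: self simply follows other.
        self.parent = other


class ConcatInput(Symbol):
    def link(self, other: Symbol) -> None:
        # A concat input can never follow another symbol, so disable both.
        self.disable()
        other.disable()


a, inp, b = Symbol("a"), ConcatInput("inp"), Symbol("b")
a.link(inp)  # another symbol links to the concat input: input stays enabled
assert a.parent is inp and not inp.is_disabled
inp.link(b)  # the concat input links to another symbol: both get disabled
assert inp.is_disabled and b.is_disabled
```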

__init__(symbols, cl_type=CLType.NONE, elastic_dims=None)

Initializes Symbol from input symbols.

Parameters:
  • symbols (List[Symbol]) –

  • cl_type (CLType) –

  • elastic_dims (Set[int] | None) –

disable(_memo=None)

Disable all symbols including input symbols.

We handle input symbols by temporarily adding them to the dependency list. This is safe because the dependency list is cleared at the end anyway.

Parameters:

_memo (Set[Symbol] | None) –

Return type:

None
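The trick of temporarily adding the inputs to the dependency list can be sketched with a memoized recursive disable. This is a hypothetical model, assuming the base disable visits every dependency once (tracked in _memo) and clears the dependency list afterwards; the classes below are stand-ins, not this module's implementation.

```python
from typing import List, Optional, Set


class Symbol:
    """Hypothetical minimal stand-in for a traced symbol with dependencies."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.is_disabled = False
        self._dependencies: List["Symbol"] = []

    def disable(self, _memo: Optional[Set["Symbol"]] = None) -> None:
        _memo = set() if _memo is None else _memo
        if self in _memo:
            return  # already visited; guards against cycles
        _memo.add(self)
        self.is_disabled = True
        for dep in self._dependencies:
            dep.disable(_memo)
        self._dependencies.clear()  # disabled symbols keep no dependencies


class ConcatSymbol(Symbol):
    def __init__(self, name: str, inputs: List[Symbol]) -> None:
        super().__init__(name)
        self.input_syms = inputs

    def disable(self, _memo: Optional[Set[Symbol]] = None) -> None:
        if _memo is not None and self in _memo:
            return
        # Temporarily add the inputs as dependencies so the base traversal
        # reaches them; the base method clears the list at the end.
        self._dependencies.extend(self.input_syms)
        super().disable(_memo)


x1, x2 = Symbol("x1"), Symbol("x2")
cat = ConcatSymbol("cat", [x1, x2])
cat.disable()
assert cat.is_disabled and x1.is_disabled and x2.is_disabled
assert cat._dependencies == [] and cat.input_syms == [x1, x2]
```

Reusing the base traversal this way means the concat needs no separate recursion logic for its inputs.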

property input_syms: List[Input | ConcatSymbol]

Return the input symbols.

property is_constant: bool

Return whether the symbol is constant.

Unlike a regular symbol, where this is determined by a manually set flag, is_constant for a concat is True only if ALL input symbols are constant.

property is_searchable: bool

Return whether the symbol is searchable.

Unlike a regular symbol, where this is determined by a manually set flag, is_searchable for a concat is True if ANY input symbol is searchable.
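The two aggregation rules for is_constant and is_searchable can be expressed directly with all() and any(). The sketch uses hypothetical stand-in classes (InputSym, ToyConcat) rather than this module's Symbol types.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class InputSym:
    """Hypothetical input symbol with manually set flags."""
    is_constant: bool = False
    is_searchable: bool = False


class ToyConcat:
    """Hypothetical concat symbol whose flags are derived from its inputs."""

    def __init__(self, inputs: List[InputSym]) -> None:
        self.input_syms = inputs

    @property
    def is_constant(self) -> bool:
        # Constant only if every input is constant.
        return all(s.is_constant for s in self.input_syms)

    @property
    def is_searchable(self) -> bool:
        # Searchable as soon as any input is searchable.
        return any(s.is_searchable for s in self.input_syms)


cat = ToyConcat([InputSym(is_constant=True), InputSym(is_searchable=True)])
assert not cat.is_constant and cat.is_searchable
```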

Link self to the other symbol.

Parameters:

other (Symbol) –

Return type:

None