concat
Module for supporting tracing of concat operations.
Note that this implementation assumes that the concat operation/symbol is the _only_ searchable symbol among its dependent symbols. This assumption lets us simplify the partial dependencies that would otherwise arise when multiple concats are linked together.
There is one small exception: whenever one concat depends on another concat, we can disable the independent concat and thereby simplify the representation.
However, truly linked concats cannot be handled, e.g.,
torch.cat([x1, x2], dim=1) + torch.cat([y1, y2], dim=1),
as there is no way of disabling one but not the other.
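To see why the two concats in this example are coupled, the following pure-Python sketch tracks only the sizes along the concat dimension (no PyTorch needed). The helpers cat_size and can_add are hypothetical and model shapes only: the element-wise sum needs both concat outputs to have the same size along dim=1 (ignoring broadcasting), so shrinking an input of one concat forces a matching change in the other.

```python
from typing import List


def cat_size(input_sizes: List[int]) -> int:
    """Size of the concat output along the concat dim: the sum of the input sizes."""
    return sum(input_sizes)


def can_add(lhs_sizes: List[int], rhs_sizes: List[int]) -> bool:
    """Element-wise addition (ignoring broadcasting) needs equal sizes along the concat dim."""
    return cat_size(lhs_sizes) == cat_size(rhs_sizes)


# Hypothetical channel counts (dim=1) for x1, x2 and y1, y2.
x_sizes = [4, 8]
y_sizes = [6, 6]
print(can_add(x_sizes, y_sizes))  # True: 12 == 12

# Pruning x1 from 4 to 2 channels changes only the left-hand concat.
x_pruned = [2, 8]
print(can_add(x_pruned, y_sizes))  # False: 10 != 12
```

Keeping the sum valid would require shrinking y1 or y2 in lockstep, a joint decision that disabling a single concat cannot express.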
Classes
| ConcatSymbol | Symbol storing an ordered list of linked symbolic inputs to concat. |
| ConcatNodeProcessor | Node for handling concat-specific tracing logic. |
- class ConcatNodeProcessor
Bases: NodeProcessor
Node for handling concat-specific tracing logic.
- __init__(*args, **kwargs)
Initialize the node processor.
- Return type:
None
- is_special_node(node, target)
Return whether node is a concat node.
- Parameters:
node (Node) –
target (Module | Callable) –
- Return type:
bool
- post_process()
Revert to the original symbols.
- Return type:
None
- process(node, id, input_nodes)
Make all inputs to the concat searchable.
- Parameters:
node (Node) –
id (int) –
input_nodes (List[Node]) –
- Return type:
None
- reset()
Reset state.
- Return type:
None
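The processor's methods suggest a lifecycle of detecting concat nodes, making their inputs searchable, and later reverting. The sketch below illustrates that lifecycle with simplified, hypothetical stand-ins (Sym, SimpleNode, ToyConcatProcessor, and the CONCAT_TARGETS name set are all assumptions); it reverts a boolean flag rather than the monkey-patched symbols the real post_process restores, and matches targets by name instead of by the real Node/target objects.

```python
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class Sym:
    """Hypothetical stand-in for a symbol attached to a node."""
    name: str
    is_searchable: bool = False


@dataclass
class SimpleNode:
    """Hypothetical stand-in for a graph node: target name plus input symbols."""
    target_name: str
    input_syms: List[Sym] = field(default_factory=list)


class ToyConcatProcessor:
    """Illustrates is_special_node -> process -> post_process -> reset."""

    CONCAT_TARGETS = {"cat", "concat", "concatenate"}  # assumed set of concat names

    def __init__(self) -> None:
        self._saved: List[Tuple[Sym, bool]] = []  # (symbol, original is_searchable)

    def is_special_node(self, node: SimpleNode) -> bool:
        return node.target_name in self.CONCAT_TARGETS

    def process(self, node: SimpleNode) -> None:
        # Make every input of the concat searchable, remembering the original flag.
        for sym in node.input_syms:
            self._saved.append((sym, sym.is_searchable))
            sym.is_searchable = True

    def post_process(self) -> None:
        # Restore in reverse order so a symbol seen twice ends with its first saved value.
        for sym, original in reversed(self._saved):
            sym.is_searchable = original

    def reset(self) -> None:
        self._saved.clear()


proc = ToyConcatProcessor()
x1, x2 = Sym("x1"), Sym("x2", is_searchable=True)
node = SimpleNode("cat", [x1, x2])
if proc.is_special_node(node):
    proc.process(node)
print(x1.is_searchable, x2.is_searchable)  # True True
proc.post_process()
print(x1.is_searchable, x2.is_searchable)  # False True
proc.reset()
```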
- class ConcatSymbol
Bases: Symbol
Symbol storing an ordered list of linked symbolic inputs to concat.
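Because torch.cat places its inputs back to back along the concat dimension, the order of the input symbols determines which slice of the output each input owns. The sketch below computes those slices from per-input sizes; input_slices is a hypothetical helper for illustration, not part of this module.

```python
from itertools import accumulate
from typing import List, Tuple


def input_slices(sizes: List[int]) -> List[Tuple[int, int]]:
    """Return the [start, stop) range of the concat output covered by each input."""
    stops = list(accumulate(sizes))
    starts = [0] + stops[:-1]
    return list(zip(starts, stops))


print(input_slices([4, 8, 2]))  # [(0, 4), (4, 12), (12, 14)]
```

Reordering the inputs changes these ranges, which is one reason to keep the inputs as an ordered list rather than an unordered collection.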
- class Input
Bases: Symbol
Special Symbol to represent an input to a ConcatSymbol.
This symbol is an augmented version of the regular Symbol to handle the interaction with the concat operation and is used to monkey-patch the original symbol.
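One common Python idiom for converting an existing object in place, so that every reference to it sees the new behavior, is reassigning its __class__ to a subclass. The sketch below assumes that idiom to illustrate monkey-patching a symbol into a concat input and back; the classes Symbol and ConcatInput, the method revert, and the attributes cat_dim and _orig_cls are simplified hypothetical stand-ins, and whether ConcatSymbol.Input.convert works exactly this way is an assumption.

```python
class Symbol:
    """Simplified stand-in for a regular symbol (hypothetical)."""

    def __init__(self, name: str) -> None:
        self.name = name

    def describe(self) -> str:
        return f"Symbol({self.name})"


class ConcatInput(Symbol):
    """Simplified stand-in for ConcatSymbol.Input (hypothetical)."""

    def describe(self) -> str:
        return f"ConcatInput({self.name}, cat_dim={self.cat_dim})"

    @staticmethod
    def convert(orig_sym: Symbol, cat_dim: int) -> "ConcatInput":
        # Remember the original class so the patch can be undone later.
        orig_sym._orig_cls = type(orig_sym)
        # Swap the class in place: every existing reference now sees a ConcatInput.
        orig_sym.__class__ = ConcatInput
        orig_sym.cat_dim = cat_dim
        return orig_sym

    def revert(self) -> Symbol:
        # Restore the original class and drop the concat-specific attributes.
        self.__class__ = self._orig_cls
        del self._orig_cls, self.cat_dim
        return self


sym = Symbol("conv1.out")
alias = sym  # another reference, e.g. one held by the traced graph
ConcatInput.convert(sym, cat_dim=1)
print(alias.describe())  # ConcatInput(conv1.out, cat_dim=1)
sym.revert()
print(alias.describe())  # Symbol(conv1.out)
```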
- __init__(*args, **kwargs)
Constructor.
- property concat_sym: ConcatSymbol
Return the concat symbol.
- static convert(orig_sym, cat_dim)
Modify and convert the symbol in place into a valid ConcatSymbol.Input.
- Parameters:
orig_sym (Symbol) –
cat_dim (int) –
- Return type:
- create_linked_copy()
Get a linked deepcopy of self that is monkey-patched to the original symbol class.
- Return type:
- __init__(symbols, cl_type=CLType.NONE, elastic_dims=None)
Initialize the ConcatSymbol from the input symbols.
- disable(_memo=None)
Disable this symbol and all of its input symbols.
Input symbols are handled by fake-adding them to the dependency list. This is safe because the dependency list is cleared at the end anyway.
- Parameters:
_memo (Set[Symbol] | None) –
- Return type:
None
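The "fake add" trick can be sketched with simplified, hypothetical classes (ToySymbol, ToyConcatSymbol, and the attributes disabled and _dependencies are assumptions). The sketch assumes the base disable marks the symbol disabled, recurses into its dependencies using _memo to avoid revisiting symbols on cycles, and then clears the dependency list; the concat symbol appends its inputs to that list before delegating.

```python
from typing import List, Optional, Set


class ToySymbol:
    """Hypothetical symbol: disabling it also disables its dependencies."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.disabled = False
        self._dependencies: List["ToySymbol"] = []

    def disable(self, _memo: Optional[Set["ToySymbol"]] = None) -> None:
        memo = set() if _memo is None else _memo
        if self in memo:  # already visited in this pass: stops infinite recursion on cycles
            return
        memo.add(self)
        self.disabled = True
        for dep in self._dependencies:
            dep.disable(memo)
        self._dependencies.clear()  # assumed: dependencies are dropped once disabled


class ToyConcatSymbol(ToySymbol):
    """Hypothetical concat symbol that disables its inputs via fake dependencies."""

    def __init__(self, name: str, input_syms: List[ToySymbol]) -> None:
        super().__init__(name)
        self.input_syms = input_syms

    def disable(self, _memo: Optional[Set[ToySymbol]] = None) -> None:
        if _memo is not None and self in _memo:
            return  # already handled; do not re-add the fake dependencies
        # Fake-add the inputs so the base class disables them too; the base class
        # clears the dependency list afterwards, so nothing leaks.
        self._dependencies.extend(self.input_syms)
        super().disable(_memo)


a, b = ToySymbol("a"), ToySymbol("b")
cat = ToyConcatSymbol("cat", [a, b])
a._dependencies.append(cat)  # a cycle: a depends back on the concat
cat.disable()
print(a.disabled, b.disabled, cat.disabled)  # True True True
print(cat._dependencies)  # []
```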
- property input_syms: List[Input | ConcatSymbol]
Return the input symbols.
- property is_constant: bool
Return indicator whether symbol is constant.
Unlike a regular symbol, where this is determined by a manually set flag, is_constant for a concat symbol is True only if ALL input symbols are constant.
- property is_searchable: bool
Return indicator whether symbol is searchable.
Unlike a regular symbol, where this is determined by a manually set flag, is_searchable for a concat symbol is True if ANY input symbol is searchable.
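The two derived properties amount to an all/any reduction over the inputs. A minimal sketch, assuming hypothetical InputSym and ToyConcatSymbol stand-ins whose inputs carry manually set flags like regular symbols:

```python
from dataclasses import dataclass
from typing import List


@dataclass
class InputSym:
    """Hypothetical input symbol with manually set flags, like a regular symbol."""
    is_constant: bool = False
    is_searchable: bool = False


class ToyConcatSymbol:
    """Hypothetical concat symbol whose flags are derived from its inputs."""

    def __init__(self, input_syms: List[InputSym]) -> None:
        self.input_syms = input_syms

    @property
    def is_constant(self) -> bool:
        # Constant only if ALL inputs are constant.
        return all(s.is_constant for s in self.input_syms)

    @property
    def is_searchable(self) -> bool:
        # Searchable if ANY input is searchable.
        return any(s.is_searchable for s in self.input_syms)


cat = ToyConcatSymbol([InputSym(is_constant=True), InputSym(is_searchable=True)])
print(cat.is_constant, cat.is_searchable)  # False True
```

A single searchable input is enough to make the whole concat searchable, while one non-constant input is enough to make it non-constant.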