core
Classes
- ExternalTarget: External target for stitched modules.
- FunctionTarget(name: 'str', function: 'Callable[..., Any]')
- ModuleTarget(name: 'str', module: 'nn.Module')
- RemoteTarget(peer_rank: 'Union[int, Sequence[int]]', process_group: 'Optional[torch.distributed.ProcessGroup]' = None, blocking: 'bool' = True)
- InputReducer(reducer_fn: 'Callable[[InputArgs, InputArgs, InputArgs, int, list[InputArgs]], InputArgs]' = <function default_input_reducer_fn>)
- exception CantResolveNodeDependenciesException
Bases:
StitchedModuleException
- class ExternalTarget
Bases:
TargetWithNamedInputs, TargetWithNamedOutputs
External target for stitched modules.
- __init__()
- Return type:
None
- class FunctionTarget
Bases:
TargetWithInput, TargetWithOutput
FunctionTarget(name: 'str', function: 'Callable[..., Any]')
- __init__(name, function)
- Parameters:
name (str)
function (Callable[..., Any])
- Return type:
None
- function: Callable[..., Any]
- name: str
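The summary line `FunctionTarget(name: 'str', function: 'Callable[..., Any]')` has the shape of the docstring that `@dataclass` synthesizes for a class without its own docstring when the module uses `from __future__ import annotations`, so FunctionTarget is most likely a dataclass. The stand-in below is hypothetical (`FunctionTargetSketch` and `scale` are not part of the library); it carries the same two documented fields and shows where such a signature-style docstring comes from.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class FunctionTargetSketch:
    # Hypothetical stand-in. No docstring on purpose: @dataclass then
    # synthesizes one from the generated __init__ signature.
    name: str
    function: Callable[..., Any]


def scale(x: float, factor: float = 2.0) -> float:
    return x * factor


target = FunctionTargetSketch(name="scale", function=scale)
print(FunctionTargetSketch.__doc__)  # signature-style string, like the summary line
print(target.function(3.0))          # 6.0
```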
- class InputReducer
Bases:
IOReducer
InputReducer(reducer_fn: 'Callable[[InputArgs, InputArgs, InputArgs, int, list[InputArgs]], InputArgs]' = <function default_input_reducer_fn>)
- __init__(reducer_fn=<function default_input_reducer_fn>)
- classmethod default()
- Return type:
- exception KnotException
Bases:
Exception
- class ModuleTarget
Bases:
TargetWithNamedInputs, TargetWithNamedOutputs
ModuleTarget(name: 'str', module: 'nn.Module')
- __init__(name, module)
- Parameters:
name (str)
module (Module)
- Return type:
None
- module: Module
- name: str
- exception MultipleExternalNodesException
Bases:
KnotException
- class Needle
Bases:
object
- __init__()
- Return type:
None
- get_node_for_target(target)
- Parameters:
target (Target)
- Return type:
Node
- knot(capture_cache_outputs_predicate=<function always_false_predicate>, early_exit=True, ignore_extra_overrides=False)
- Return type:
- stitch(src, dst)
- Parameters:
src (InputDescriptor | OutputDescriptor)
dst (InputDescriptor | OutputDescriptor)
- Return type:
Self
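Because `stitch` returns `Self`, calls can be chained before `knot` is invoked. The class below is a hypothetical stand-in (`ToyNeedle`, with plain strings in place of `InputDescriptor`/`OutputDescriptor`, and made-up descriptor names) that only records the wiring. It illustrates the fluent-builder pattern the signature implies, not what `Needle.knot` actually builds.

```python
from __future__ import annotations


class ToyNeedle:
    """Hypothetical stand-in that only records (src, dst) connections."""

    def __init__(self) -> None:
        self.edges: list[tuple[str, str]] = []

    def stitch(self, src: str, dst: str) -> ToyNeedle:
        # Record the connection and return self so calls can be chained.
        self.edges.append((src, dst))
        return self

    def knot(self) -> list[tuple[str, str]]:
        # This stand-in just returns a copy of the recorded wiring.
        return list(self.edges)


wiring = (
    ToyNeedle()
    .stitch("external.input_ids", "embed.input")
    .stitch("embed.output", "head.input")
    .knot()
)
print(wiring)
```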
- exception OnlyInternalNodesException
Bases:
KnotException
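`MultipleExternalNodesException` and `OnlyInternalNodesException` both derive from `KnotException`, so a single `except KnotException` handler covers either. The sketch mirrors that documented hierarchy with local stand-in classes; `validate_external_count` is hypothetical, and the conditions it checks are guessed from the exception names rather than taken from the library.

```python
class KnotException(Exception):
    """Stand-in mirroring the documented base class."""


class MultipleExternalNodesException(KnotException):
    pass


class OnlyInternalNodesException(KnotException):
    pass


def validate_external_count(num_external: int) -> None:
    # Hypothetical check; the conditions are inferred from the exception names.
    if num_external == 0:
        raise OnlyInternalNodesException("graph has no external node")
    if num_external > 1:
        raise MultipleExternalNodesException(f"graph has {num_external} external nodes")


for count in (0, 1, 2):
    try:
        validate_external_count(count)
        print(count, "ok")
    except KnotException as err:  # one handler catches both subclasses
        print(count, type(err).__name__)
```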
- exception OutputsLoopFoundException
Bases:
LoopFoundException
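The names `OutputsLoopFoundException` and `CantResolveNodeDependenciesException` suggest that knotting orders nodes by their dependencies and rejects loops. This reference does not say which algorithm the library uses; as a sketch under that assumption, Kahn's topological sort detects a loop when some nodes never reach in-degree zero. `topological_order` and the `LoopFoundException` stand-in (whose real base class is not documented here) are hypothetical.

```python
from __future__ import annotations

from collections import deque


class LoopFoundException(Exception):
    """Stand-in; the real base of LoopFoundException is not documented here."""


class OutputsLoopFoundException(LoopFoundException):
    pass


def topological_order(deps: dict[str, set[str]]) -> list[str]:
    """Order nodes so each one comes after the nodes it depends on (Kahn's algorithm)."""
    nodes = set(deps) | {d for ds in deps.values() for d in ds}
    indegree = {n: 0 for n in nodes}
    dependents: dict[str, list[str]] = {n: [] for n in nodes}
    for node, ds in deps.items():
        for d in ds:
            indegree[node] += 1          # node waits on d
            dependents[d].append(node)   # finishing d unblocks node
    ready = deque(sorted(n for n in nodes if indegree[n] == 0))
    order: list[str] = []
    while ready:
        n = ready.popleft()
        order.append(n)
        for m in dependents[n]:
            indegree[m] -= 1
            if indegree[m] == 0:
                ready.append(m)
    if len(order) != len(nodes):
        # Nodes still waiting on something are part of, or downstream of, a loop.
        stuck = sorted(n for n in nodes if indegree[n] > 0)
        raise OutputsLoopFoundException(f"loop among {stuck}")
    return order


print(topological_order({"embed": set(), "block": {"embed"}, "head": {"block"}}))
```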
- class RemoteTarget
Bases:
Target
RemoteTarget(peer_rank: 'Union[int, Sequence[int]]', process_group: 'Optional[torch.distributed.ProcessGroup]' = None, blocking: 'bool' = True)
- __init__(peer_rank, process_group=None, blocking=True)
- Parameters:
peer_rank (int | Sequence[int])
process_group (ProcessGroup | None)
blocking (bool)
- Return type:
None
- blocking: bool = True
- peer_rank: int | Sequence[int]
- process_group: ProcessGroup | None = None
- value(name, adapter=<function default_output_adapter_fn>, reducer=OutputReducer(reducer_fn=<function default_output_reducer_fn>, requires_original_output=False))
- Parameters:
name (str)
adapter (Callable[..., Any])
reducer (OutputReducer)
- Return type:
OutputDescriptor
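`peer_rank` accepts either a single rank or a sequence of ranks (`int | Sequence[int]`). Code that consumes such a field usually normalizes it first. `normalize_peer_ranks` below is a hypothetical helper, not part of the documented API, and how `RemoteTarget` handles the value internally is not stated here.

```python
from __future__ import annotations

from collections.abc import Sequence


def normalize_peer_ranks(peer_rank: int | Sequence[int]) -> tuple[int, ...]:
    """Return peer ranks as a tuple, accepting one rank or several."""
    # bool is a subclass of int, so reject it explicitly.
    if isinstance(peer_rank, bool):
        raise TypeError("peer_rank must be an int or a sequence of ints, not bool")
    if isinstance(peer_rank, int):
        return (peer_rank,)
    ranks = tuple(peer_rank)
    if not all(isinstance(r, int) and not isinstance(r, bool) for r in ranks):
        raise TypeError("every peer rank must be an int")
    return ranks


print(normalize_peer_ranks(1))       # (1,)
print(normalize_peer_ranks([0, 2]))  # (0, 2)
```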
- class StitchedModule
Bases:
Module
- __init__(nodes, capture_cache_outputs_predicate=<function always_false_predicate>, early_exit=True, ignore_extra_overrides=False)
- Parameters:
nodes (dict[Target, Node])
capture_cache_outputs_predicate (Callable[[str, Module], bool])
- Return type:
None
- create_input_overrides(values_to_node)
- Parameters:
values_to_node (dict[InputDescriptor | OutputDescriptor, Any])
- Return type:
- create_output_overrides(values_to_node)
- Parameters:
values_to_node (dict[InputDescriptor | OutputDescriptor, Any])
- Return type:
- forward(input_overrides, output_overrides, *args, **kwargs)
- Parameters:
input_overrides (dict[str, Any])
output_overrides (dict[str, Any])
- Return type:
StitchedModuleOutput
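`capture_cache_outputs_predicate` is typed `Callable[[str, Module], bool]` and defaults to `always_false_predicate`. The sketch below shows two callables with that shape. It assumes the string argument is the submodule's qualified name (as produced by `nn.Module.named_modules()`), which the signature suggests but this reference does not state. `capture_by_prefix` is a hypothetical factory, and `Any` stands in for `nn.Module` so the example runs without PyTorch.

```python
from collections.abc import Callable
from typing import Any

Predicate = Callable[[str, Any], bool]


def always_false(name: str, module: Any) -> bool:
    # Presumably what the default always_false_predicate does: capture nothing.
    return False


def capture_by_prefix(prefix: str) -> Predicate:
    # Hypothetical factory: capture outputs of submodules whose name starts with prefix.
    def predicate(name: str, module: Any) -> bool:
        return name.startswith(prefix)

    return predicate


pick_blocks = capture_by_prefix("blocks.")
print(pick_blocks("blocks.3.mlp", object()))  # True
print(pick_blocks("lm_head", object()))       # False
```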