core

Classes

ExternalTarget

External target for stitched modules.

FunctionTarget

FunctionTarget(name: 'str', function: 'Callable[..., Any]')

ModuleTarget

ModuleTarget(name: 'str', module: 'nn.Module')

RemoteTarget

RemoteTarget(peer_rank: 'Union[int, Sequence[int]]', process_group: 'Optional[torch.distributed.ProcessGroup]' = None, blocking: 'bool' = True)

Needle

StitchedModule

InputReducer

InputReducer(reducer_fn: 'Callable[[InputArgs, InputArgs, InputArgs, int, list[InputArgs]], InputArgs]' = <function default_input_reducer_fn>)

exception CantResolveNodeDependenciesException

Bases: StitchedModuleException

class ExternalTarget

Bases: TargetWithNamedInputs, TargetWithNamedOutputs

External target for stitched modules.

__init__()
Return type:

None

class FunctionTarget

Bases: TargetWithInput, TargetWithOutput

FunctionTarget(name: 'str', function: 'Callable[..., Any]')

__init__(name, function)
Parameters:
  • name (str)

  • function (Callable[..., Any])

Return type:

None

function: Callable[..., Any]
name: str
class InputReducer

Bases: IOReducer

InputReducer(reducer_fn: 'Callable[[InputArgs, InputArgs, InputArgs, int, list[InputArgs]], InputArgs]' = <function default_input_reducer_fn>)

__init__(reducer_fn=<function default_input_reducer_fn>)
Parameters:

reducer_fn (Callable[[InputArgs, InputArgs, InputArgs, int, list[InputArgs]], InputArgs])

Return type:

None

classmethod default()
Return type:

InputReducer

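reducer_fn(input_override, *args)
Parameters:
  • input_override (InputArgs)

  • *args — the remaining positional arguments of the reducer_fn field type: (InputArgs, InputArgs, int, list[InputArgs])

Return type:

InputArgs
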
exception KnotException

Bases: Exception

class ModuleTarget

Bases: TargetWithNamedInputs, TargetWithNamedOutputs

ModuleTarget(name: 'str', module: 'nn.Module')

__init__(name, module)
Parameters:
  • name (str)

  • module (Module)

Return type:

None

module: Module
name: str
exception MultipleExternalNodesException

Bases: KnotException

class Needle

Bases: object

__init__()
Return type:

None

get_node_for_target(target)
Parameters:

target (Target)

Return type:

Node

knot(capture_cache_outputs_predicate=<function always_false_predicate>, early_exit=True, ignore_extra_overrides=False)
Return type:

StitchedModule

stitch(src, dst)
Parameters:
  • src (InputDescriptor | OutputDescriptor)

  • dst (InputDescriptor | OutputDescriptor)

Return type:

Self

exception OnlyInternalNodesException

Bases: KnotException

exception OutputsLoopFoundException

Bases: LoopFoundException

class RemoteTarget

Bases: Target

RemoteTarget(peer_rank: 'Union[int, Sequence[int]]', process_group: 'Optional[torch.distributed.ProcessGroup]' = None, blocking: 'bool' = True)

__init__(peer_rank, process_group=None, blocking=True)
Parameters:
  • peer_rank (int | Sequence[int])

  • process_group (ProcessGroup | None)

  • blocking (bool)

Return type:

None

blocking: bool = True
peer_rank: int | Sequence[int]
process_group: ProcessGroup | None = None
value(name, adapter=<function default_output_adapter_fn>, reducer=OutputReducer(reducer_fn=<function default_output_reducer_fn>, requires_original_output=False))
Parameters:
  • name (str)

  • adapter (Callable[..., Any])

  • reducer (OutputReducer)

Return type:

OutputDescriptor

class StitchedModule

Bases: Module

__init__(nodes, capture_cache_outputs_predicate=<function always_false_predicate>, early_exit=True, ignore_extra_overrides=False)
Parameters:
  • nodes (dict[Target, Node])

  • capture_cache_outputs_predicate (Callable[[str, Module], bool])

  • early_exit (bool)

  • ignore_extra_overrides (bool)

Return type:

None

create_input_overrides(values_to_node)
Parameters:

values_to_node (dict[InputDescriptor | OutputDescriptor, Any])

Return type:

PassageInputOverrides

create_output_overrides(values_to_node)
Parameters:

values_to_node (dict[InputDescriptor | OutputDescriptor, Any])

Return type:

PassageOutputOverrides

forward(input_overrides, output_overrides, *args, **kwargs)
Parameters:
  • input_overrides (dict[str, Any])

  • output_overrides (dict[str, Any])

Return type:

StitchedModuleOutput