core

Classes

ConstantTarget

ConstantTarget(name: 'str', value: 'Any')

ExternalTarget

External target for stitched modules.

FunctionTarget

FunctionTarget(name: 'str', function: 'Callable[..., Any]')

IOReducer

IOReducer()

InputDescriptor

InputDescriptor(target: 'Target', input_name: 'str' = '', input_adapter: 'InputAdapter' = <function default_input_adapter_fn>, reducer: 'InputReducer' = <factory>)

InputReducer

InputReducer(reducer_fn: 'Callable[[InputArgs, InputArgs, InputArgs, int, list[InputArgs]], InputArgs]' = <function default_input_reducer_fn>)

ModuleTarget

ModuleTarget(name: 'str', module: 'nn.Module')

Needle

Node

Node(target: 'Target', stitches_to: 'list[StitchDescriptor]' = <factory>, stitches_from: 'list[StitchDescriptor]' = <factory>)

OutputDescriptor

OutputDescriptor(target: 'Target', output_name: 'str' = '', output_adapter: 'OutputAdapter' = <function default_output_adapter_fn>, reducer: 'OutputReducer' = <factory>)

OutputReducer

OutputReducer(reducer_fn: 'Callable[[OutputValue, OutputValue, Optional[OutputValue], int, list[OutputValue]], OutputValue]' = <function default_output_reducer_fn>, requires_original_output: 'bool' = False)

RemoteDataDescriptor

RemoteDataDescriptor(key: 'str')

RemotePythonDataDescriptor

RemotePythonDataDescriptor(key: 'str', value: 'Any')

RemoteTarget

RemoteTarget(peer_rank: 'Union[int, Sequence[int]]', process_group: 'Optional[torch.distributed.ProcessGroup]' = None, blocking: 'bool' = True)

RemoteTensorDataDescriptor

RemoteTensorDataDescriptor(key: 'str', device: "Literal['cuda', 'cpu']", dtype: 'torch.dtype', shape: 'torch.Size')

Singleton

StitchDescriptor

StitchDescriptor(source_descriptor: 'IODescriptor', destination_descriptor: 'IODescriptor')

StitchedModule

StitchedModuleOutput

StitchedModuleOutput(captured_inputs: 'dict[str, InputArgs]', captured_outputs: 'dict[str, Any]')

Target

Target()

TargetWithInput

TargetWithInput()

TargetWithNamedInputs

TargetWithNamedInputs()

TargetWithNamedOutputs

TargetWithNamedOutputs()

TargetWithOutput

TargetWithOutput()

Functions

default_input_adapter_fn

default_input_reducer_fn

default_output_adapter_fn

default_output_reducer_fn

exception CantResolveNodeDependenciesException

Bases: StitchedModuleException

class ConstantTarget

Bases: TargetWithOutput

ConstantTarget(name: 'str', value: 'Any')

__init__(name, value)
Parameters:
  • name (str)

  • value (Any)

Return type:

None

name: str
value: Any
class ExternalTarget

Bases: TargetWithNamedInputs, TargetWithNamedOutputs

External target for stitched modules.

__init__()
Return type:

None

class FunctionTarget

Bases: TargetWithInput, TargetWithOutput

FunctionTarget(name: 'str', function: 'Callable[..., Any]')

__init__(name, function)
Parameters:
  • name (str)

  • function (Callable[[...], Any])

Return type:

None

function: Callable[[...], Any]
name: str
class IOReducer

Bases: object

IOReducer()

__init__()
Return type:

None

class InputDescriptor

Bases: object

InputDescriptor(target: 'Target', input_name: 'str' = '', input_adapter: 'InputAdapter' = <function default_input_adapter_fn>, reducer: 'InputReducer' = <factory>)

__init__(target, input_name='', input_adapter=<function default_input_adapter_fn>, reducer=<factory>)
Parameters:
  • target (Target)

  • input_name (str)

  • input_adapter (InputAdapter)

  • reducer (InputReducer)

Return type:

None

input_adapter(input_values)
Parameters:

input_values (InputArgs)

Return type:

InputArgs

input_name: str = ''
reducer: InputReducer
target: Target
class InputReducer

Bases: IOReducer

InputReducer(reducer_fn: 'Callable[[InputArgs, InputArgs, InputArgs, int, list[InputArgs]], InputArgs]' = <function default_input_reducer_fn>)

__init__(reducer_fn=<function default_input_reducer_fn>)
Parameters:

reducer_fn (Callable[[InputArgs, InputArgs, InputArgs, int, list[InputArgs]], InputArgs])

Return type:

None

classmethod default()
Return type:

InputReducer

reducer_fn(acc, input_override, *args)
Parameters:
  • acc (InputArgs)

  • input_override (InputArgs)

Return type:

InputArgs

exception InputsLoopFoundException

Bases: LoopFoundException

exception KnotException

Bases: Exception

exception LoopFoundException

Bases: KnotException

class ModuleTarget

Bases: TargetWithNamedInputs, TargetWithNamedOutputs

ModuleTarget(name: 'str', module: 'nn.Module')

__init__(name, module)
Parameters:
  • name (str)

  • module (Module)

Return type:

None

module: Module
name: str
exception MultipleExternalNodesException

Bases: KnotException

class Needle

Bases: object

__init__()
Return type:

None

get_node_for_target(target)
Parameters:

target (Target)

Return type:

Node

knot(capture_cache_outputs_predicate=<function always_false_predicate>, early_exit=True, ignore_extra_overrides=False)
Return type:

StitchedModule

stitch(src, dst)
Parameters:
Return type:

Self

class Node

Bases: object

Node(target: 'Target', stitches_to: 'list[StitchDescriptor]' = <factory>, stitches_from: 'list[StitchDescriptor]' = <factory>)

__init__(target, stitches_to=<factory>, stitches_from=<factory>)
Parameters:
  • target (Target)

  • stitches_to (list[StitchDescriptor])

  • stitches_from (list[StitchDescriptor])

Return type:

None

stitches_from: list[StitchDescriptor]
stitches_to: list[StitchDescriptor]
target: Target
exception OnlyInternalNodesException

Bases: KnotException

class OutputDescriptor

Bases: object

OutputDescriptor(target: 'Target', output_name: 'str' = '', output_adapter: 'OutputAdapter' = <function default_output_adapter_fn>, reducer: 'OutputReducer' = <factory>)

__init__(target, output_name='', output_adapter=<function default_output_adapter_fn>, reducer=<factory>)
Parameters:
  • target (Target)

  • output_name (str)

  • output_adapter (Callable[[...], Any])

  • reducer (OutputReducer)

Return type:

None

output_adapter(v)
Parameters:

v (Any)

Return type:

Any

output_name: str = ''
reducer: OutputReducer
target: Target
class OutputReducer

Bases: IOReducer

OutputReducer(reducer_fn: 'Callable[[OutputValue, OutputValue, Optional[OutputValue], int, list[OutputValue]], OutputValue]' = <function default_output_reducer_fn>, requires_original_output: 'bool' = False)

__init__(reducer_fn=<function default_output_reducer_fn>, requires_original_output=False)
Parameters:
  • reducer_fn (Callable[[Any, Any, Any | None, int, list[Any]], Any])

  • requires_original_output (bool)

Return type:

None

classmethod default()
Return type:

OutputReducer

reducer_fn(acc, input_override, *args)
Parameters:
  • acc (Any)

  • input_override (Any)

requires_original_output: bool = False
exception OutputsLoopFoundException

Bases: LoopFoundException

class RemoteDataDescriptor

Bases: ABC

RemoteDataDescriptor(key: 'str')

__init__(key)
Parameters:

key (str)

Return type:

None

key: str
class RemotePythonDataDescriptor

Bases: RemoteDataDescriptor

RemotePythonDataDescriptor(key: 'str', value: 'Any')

__init__(key, value)
Parameters:
  • key (str)

  • value (Any)

Return type:

None

value: Any
class RemoteTarget

Bases: Target

RemoteTarget(peer_rank: 'Union[int, Sequence[int]]', process_group: 'Optional[torch.distributed.ProcessGroup]' = None, blocking: 'bool' = True)

__init__(peer_rank, process_group=None, blocking=True)
Parameters:
  • peer_rank (int | Sequence[int])

  • process_group (ProcessGroup | None)

  • blocking (bool)

Return type:

None

blocking: bool = True
peer_rank: int | Sequence[int]
process_group: ProcessGroup | None = None
value(name, adapter=<function default_output_adapter_fn>, reducer=OutputReducer(reducer_fn=<function default_output_reducer_fn>, requires_original_output=False))
Parameters:
  • name (str)

  • adapter (Callable[[...], Any])

  • reducer (OutputReducer)

Return type:

OutputDescriptor

class RemoteTensorDataDescriptor

Bases: RemoteDataDescriptor

RemoteTensorDataDescriptor(key: 'str', device: "Literal['cuda', 'cpu']", dtype: 'torch.dtype', shape: 'torch.Size')

__init__(key, device, dtype, shape)
Parameters:
  • key (str)

  • device (Literal['cuda', 'cpu'])

  • dtype (dtype)

  • shape (Size)

Return type:

None

device: Literal['cuda', 'cpu']
dtype: dtype
shape: Size
class Singleton

Bases: type

class StitchDescriptor

Bases: object

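StitchDescriptor(source_descriptor: 'IODescriptor', destination_descriptor: 'IODescriptor')

__init__(source_descriptor, destination_descriptor)
Parameters:
  • source_descriptor (InputDescriptor | OutputDescriptor)

  • destination_descriptor (InputDescriptor | OutputDescriptor)

Return type: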

None

destination_descriptor: InputDescriptor | OutputDescriptor
source_descriptor: InputDescriptor | OutputDescriptor
class StitchedModule

Bases: Module

__init__(nodes, capture_cache_outputs_predicate=<function always_false_predicate>, early_exit=True, ignore_extra_overrides=False)
Parameters:
  • nodes (dict[Target, Node])

  • capture_cache_outputs_predicate (Callable[[str, Module], bool])

Return type:

None

create_input_overrides(values_to_node)
Parameters:

values_to_node (dict[InputDescriptor | OutputDescriptor, Any])

Return type:

PassageInputOverrides

create_output_overrides(values_to_node)
Parameters:

values_to_node (dict[InputDescriptor | OutputDescriptor, Any])

Return type:

PassageOutputOverrides

forward(input_overrides, output_overrides, *args, **kwargs)
Parameters:
  • input_overrides (dict[str, Any])

  • output_overrides (dict[str, Any])

Return type:

StitchedModuleOutput

exception StitchedModuleException

Bases: Exception

class StitchedModuleOutput

Bases: object

StitchedModuleOutput(captured_inputs: 'dict[str, InputArgs]', captured_outputs: 'dict[str, Any]')

__init__(captured_inputs, captured_outputs)
Parameters:
  • captured_inputs (dict[str, InputArgs])

  • captured_outputs (dict[str, Any])

Return type:

None

captured_inputs: dict[str, InputArgs]
captured_outputs: dict[str, Any]
class Target

Bases: object

Target()

__init__()
Return type:

None

class TargetWithInput

Bases: Target

TargetWithInput()

__init__()
Return type:

None

input(adapter=<function default_input_adapter_fn>, reducer=InputReducer(reducer_fn=<function default_input_reducer_fn>))
Parameters:
  • adapter (InputAdapter)

  • reducer (InputReducer)

Return type:

InputDescriptor

class TargetWithNamedInputs

Bases: Target

TargetWithNamedInputs()

__init__()
Return type:

None

input(name, adapter=<function default_input_adapter_fn>, reducer=InputReducer(reducer_fn=<function default_input_reducer_fn>))
Parameters:
  • name (str)

  • adapter (InputAdapter)

  • reducer (InputReducer)

Return type:

InputDescriptor

class TargetWithNamedOutputs

Bases: Target

TargetWithNamedOutputs()

__init__()
Return type:

None

output(name, adapter=<function default_output_adapter_fn>, reducer=OutputReducer(reducer_fn=<function default_output_reducer_fn>, requires_original_output=False))
Parameters:
  • name (str)

  • adapter (Callable[[...], Any])

  • reducer (OutputReducer)

Return type:

OutputDescriptor

class TargetWithOutput

Bases: Target

TargetWithOutput()

__init__()
Return type:

None

output(adapter=<function default_output_adapter_fn>, reducer=OutputReducer(reducer_fn=<function default_output_reducer_fn>, requires_original_output=False))
Parameters:
  • adapter (Callable[[...], Any])

  • reducer (OutputReducer)

Return type:

OutputDescriptor

default_input_adapter_fn(input_values)
Parameters:

input_values (InputArgs)

Return type:

InputArgs

default_input_reducer_fn(acc, input_override, *args)
Parameters:
  • acc (InputArgs)

  • input_override (InputArgs)

Return type:

InputArgs

default_output_adapter_fn(v)
Parameters:

v (Any)

Return type:

Any

default_output_reducer_fn(acc, input_override, *args)
Parameters:
  • acc (Any)

  • input_override (Any)