nvalchemi.dynamics.DistributedPipeline#
- class nvalchemi.dynamics.DistributedPipeline(stages, synchronized=False, debug_mode=False, **dist_kwargs)[source]#
Orchestrates multi-rank pipeline execution.
Maps GPU ranks to pipeline stages and coordinates the distributed step loop. Each rank executes only its assigned stage.
- Parameters:
stages (dict[int, BaseDynamics]) – Mapping from rank to its assigned pipeline stage.
synchronized (bool) – If True, insert a global dist.barrier() across all pipeline ranks after every step() call, forcing every rank to complete its current step before any rank proceeds to the next one. This is primarily useful for debugging ordering or deadlock issues because it eliminates all inter-rank timing skew.

Note: This is distinct from the per-stage comm_mode parameter on _CommunicationMixin, which controls the blocking behavior of pairwise isend/irecv exchanges between adjacent stages. synchronized enforces a global synchronization point across the entire pipeline and will significantly reduce throughput; it should be disabled (False) in production.

debug_mode (bool) – When True, emit detailed debug diagnostics for inter-rank communication and pipeline orchestration.

dist_kwargs (Any) – Additional keyword arguments forwarded to torch.distributed.init_process_group.
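To make the synchronized semantics concrete: nvalchemi's actual synchronization uses torch.distributed.barrier() across GPU ranks, but the ordering guarantee it provides can be sketched with plain Python threads and threading.Barrier standing in for ranks and the collective barrier. This is an illustrative analogy only, not the pipeline's implementation.

```python
# Illustrative sketch only: threads stand in for ranks, threading.Barrier
# stands in for the dist.barrier() inserted after every step() when
# synchronized=True. The guarantee shown: no rank starts step N+1 until
# every rank has finished step N, regardless of per-rank timing skew.
import threading

NUM_RANKS = 3
barrier = threading.Barrier(NUM_RANKS)
events = []
lock = threading.Lock()

def rank_loop(rank, num_steps):
    for step in range(num_steps):
        with lock:
            events.append((rank, step))  # "execute" this rank's step
        barrier.wait()  # analogous to the global dist.barrier() after step()

threads = [threading.Thread(target=rank_loop, args=(r, 2)) for r in range(NUM_RANKS)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# All step-0 events precede all step-1 events.
assert all(s == 0 for _, s in events[:NUM_RANKS])
assert all(s == 1 for _, s in events[NUM_RANKS:])
```

Without the barrier, a fast rank could race ahead into later steps while a slow rank is still on an earlier one; that skew is exactly what synchronized=True eliminates, at the cost of throughput.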
- stages#
Rank-to-stage mapping.
- Type:
dict[int, BaseDynamics]
- synchronized#
Whether a global dist.barrier() is inserted after every step.
- Type:
bool
- _dist_initialized#
Whether this DistributedPipeline instance initialized the distributed process group (used to determine cleanup responsibility).
- Type:
bool
Examples
>>> # Context manager usage (recommended):
>>> pipeline = DistributedPipeline(stages={0: opt_stage, 1: md_stage})
>>> with pipeline:
...     pipeline.run()
...
>>> # Manual usage:
>>> pipeline = DistributedPipeline(stages={0: opt_stage, 1: md_stage})
>>> pipeline.init_distributed()
>>> pipeline.setup()
>>> pipeline.run()
>>> pipeline.cleanup()
>>> # Composing multiple pipelines together:
>>> full_pipeline = pipe1 | pipe2 | pipe3
>>> with full_pipeline:
...     full_pipeline.run()
...
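The `pipe1 | pipe2 | pipe3` composition in the example presumably relies on an `__or__` operator that merges the pipelines' rank-to-stage mappings. The sketch below shows one plausible shape for such an operator on a toy class; it is a hypothetical illustration, not nvalchemi's actual implementation.

```python
# Hypothetical sketch of a `|` composition operator that merges
# rank-to-stage mappings. NOT nvalchemi's source; it only illustrates
# the pattern suggested by the `pipe1 | pipe2 | pipe3` example.
class ToyPipeline:
    def __init__(self, stages):
        self.stages = dict(stages)  # rank -> stage object

    def __or__(self, other):
        overlap = self.stages.keys() & other.stages.keys()
        if overlap:
            # Two stages cannot claim the same rank.
            raise ValueError(f"rank collision: {sorted(overlap)}")
        return ToyPipeline({**self.stages, **other.stages})

p1 = ToyPipeline({0: "opt_stage"})
p2 = ToyPipeline({1: "md_stage"})
combined = p1 | p2
assert combined.stages == {0: "opt_stage", 1: "md_stage"}
```

Merging dictionaries keyed by rank keeps the composed pipeline's invariant that each rank executes exactly one stage.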
- __init__(stages, synchronized=False, debug_mode=False, **dist_kwargs)[source]#
Initialize the pipeline.
- Parameters:
stages (dict[int, BaseDynamics]) – Mapping from global rank to pipeline stage.
synchronized (bool, optional) – If True, insert a global dist.barrier() across all pipeline ranks after every step, preventing any rank from advancing until all ranks have completed the current step. Useful for debugging but significantly reduces throughput. See the class-level docstring for how this differs from the per-stage comm_mode. Default False.

debug_mode (bool, optional) – When True, emit detailed loguru.debug diagnostics for inter-rank communication and pipeline orchestration. Propagated to all stages during setup(). Default False.

**dist_kwargs (Any) – Additional keyword arguments for torch.distributed.init_process_group.
- Return type:
None
Methods
__init__(stages[, synchronized, debug_mode]) – Initialize the pipeline.
cleanup() – Destroy the torch.distributed process group.
init_distributed() – Initialize the torch.distributed process group.
run() – Run the pipeline loop until all stages report done.
setup() – Wire up prior_rank / next_rank between adjacent stages.
step() – Execute one timestep for the local rank's stage.
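The class Examples show that `with pipeline:` replaces the manual init_distributed() / setup() / cleanup() calls. The toy sketch below illustrates that lifecycle mapping with Python's context-manager protocol; the method names mirror this class, but the body is a hypothetical stand-in, not nvalchemi's code.

```python
# Hypothetical sketch of the context-manager lifecycle implied by the
# Examples section: __enter__ performs init_distributed() + setup(),
# __exit__ performs cleanup() even if run() raised.
class ToyLifecycle:
    def __init__(self):
        self.calls = []

    def init_distributed(self):
        self.calls.append("init_distributed")

    def setup(self):
        self.calls.append("setup")

    def run(self):
        self.calls.append("run")

    def cleanup(self):
        self.calls.append("cleanup")

    def __enter__(self):
        self.init_distributed()
        self.setup()
        return self

    def __exit__(self, exc_type, exc, tb):
        self.cleanup()  # guaranteed teardown, even on error
        return False  # propagate any exception from the with-body

with ToyLifecycle() as p:
    p.run()
assert p.calls == ["init_distributed", "setup", "run", "cleanup"]
```

The advantage over the manual sequence is that cleanup() (and hence process-group destruction) runs even when run() fails partway through.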
Attributes
global_rank – Get the global rank for this process.
local_rank – Get the local rank for this process.
local_stage – Get the stage assigned to the rank this process is executing on.
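Distributed launchers such as torchrun conventionally expose the global and per-node ranks through the RANK and LOCAL_RANK environment variables. Whether these exact properties read those variables is an assumption not confirmed by this page; the sketch below only illustrates the common convention.

```python
# Sketch of the conventional rank resolution used by launchers like
# torchrun. That DistributedPipeline reads these exact environment
# variables is an ASSUMPTION, not confirmed by the documentation.
import os

def get_global_rank(env=os.environ):
    # Rank of this process across all nodes.
    return int(env.get("RANK", 0))

def get_local_rank(env=os.environ):
    # Rank of this process within its own node.
    return int(env.get("LOCAL_RANK", 0))

# Simulate an environment a launcher might set for rank 3 on GPU 1:
fake_env = {"RANK": "3", "LOCAL_RANK": "1"}
assert get_global_rank(fake_env) == 3
assert get_local_rank(fake_env) == 1
```

In a stages mapping like {0: opt_stage, 1: md_stage}, the global rank is what selects which entry becomes this process's local_stage.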