nvalchemi.dynamics.FusedStage
- class nvalchemi.dynamics.FusedStage(sub_stages, *, entry_status=0, exit_status=-1, compile_step=False, compile_kwargs=None, init_fn=None, **kwargs)
Composite dynamics engine fusing multiple sub-stages on a single GPU.
FusedStage composes multiple BaseDynamics sub-stages so that they share one Batch and one model forward pass per step, avoiding redundant forward passes when multiple simulation phases (e.g., relaxation then MD) operate on the same batch.

Unlike in BaseDynamics, ``step(batch)`` is overridden. Instead of the standard pre_update → compute → post_update loop, FusedStage (1) makes a single compute() call on the full batch, then (2) iterates over sub-stages, applying masked_update(batch, mask) on each sub-stage’s dynamics for samples whose batch.status matches that sub-stage’s status code. Only ONE forward pass happens per step regardless of the number of sub-stages.

``run(batch)`` is also overridden: the n_steps attribute (inherited from BaseDynamics) and any n_steps argument passed to run() both set the maximum number of steps. The loop runs until all samples have migrated to the exit_status, the sampler is exhausted, or n_steps is reached.

Convergence-driven migration is handled by ConvergenceHook instances auto-registered between adjacent sub-stages: when samples converge in sub-stage i, their batch.status is updated to sub-stage i+1’s code, causing them to be processed by the next dynamics on the following step.

The + operator composes sub-stages: dyn_a + dyn_b creates a FusedStage, and fused + dyn_c appends a third sub-stage. The | operator (inherited from BaseDynamics via _CommunicationMixin) creates a DistributedPipeline for multi-rank execution instead.

Developers generally do NOT subclass FusedStage. Instead, create BaseDynamics subclasses (integrators) and compose them using +. FusedStage handles orchestration automatically. The key requirements are that sub-stage dynamics implement masked_update correctly (inherited from BaseDynamics) and that the batch has a status tensor.

Hook Firing Semantics
Because FusedStage shares a single forward pass across all sub-stages, hook firing differs from standalone BaseDynamics execution. The following hooks fire on each sub-stage during _step_impl:

Fired on sub-stages (in order):
- BEFORE_STEP: at the start of each fused step, before any work.
- AFTER_COMPUTE: after the shared model forward pass completes.
- BEFORE_PRE_UPDATE: before each sub-stage’s masked_update (fires even when no samples match the sub-stage’s status code).
- AFTER_POST_UPDATE: after each sub-stage’s masked_update (fires even when no samples match the sub-stage’s status code).
- AFTER_STEP: after all masked updates are complete.
- ON_CONVERGE: when a sub-stage’s _check_convergence detects converged samples.

NOT fired on sub-stages:
- BEFORE_COMPUTE: the forward pass is shared across all sub-stages, not executed per-sub-stage; there is no meaningful “before compute” point for individual sub-stages.
- AFTER_PRE_UPDATE: masked_update combines pre_update and post_update atomically; there is no intermediate hook point.
- BEFORE_POST_UPDATE: same reason as AFTER_PRE_UPDATE.
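As a rough illustration of the firing rules above, the following pure-Python sketch logs what one fused step does for a small batch. It is a toy model, not nvalchemi code: `fused_step` and its event log are hypothetical names. The way hooks are interleaved across sub-stages (every sub-stage's BEFORE_STEP first, and so on) is an assumption, since the description above only fixes the order seen by each sub-stage. ON_CONVERGE is left out for brevity.

```python
# Toy model of one fused step. All names are illustrative stand-ins,
# not nvalchemi APIs.
def fused_step(status, sub_stage_codes):
    """Return (event log, number of forward passes) for one fused step.

    status: per-sample status codes; sub_stage_codes: codes of the sub-stages.
    """
    events = []
    forward_passes = 0

    for code in sub_stage_codes:
        events.append((code, "BEFORE_STEP"))

    # One shared forward pass on the full batch; no per-sub-stage
    # BEFORE_COMPUTE hook exists in this scheme.
    forward_passes += 1
    for code in sub_stage_codes:
        events.append((code, "AFTER_COMPUTE"))

    for code in sub_stage_codes:
        # Select the samples whose status matches this sub-stage.
        mask = [s == code for s in status]
        # These two hooks fire even when the mask selects nothing.
        events.append((code, "BEFORE_PRE_UPDATE"))
        events.append((code, f"masked_update on {sum(mask)} sample(s)"))
        events.append((code, "AFTER_POST_UPDATE"))

    for code in sub_stage_codes:
        events.append((code, "AFTER_STEP"))
    return events, forward_passes


# Three sub-stages, but no sample currently has status 2.
events, n_forward = fused_step(status=[0, 0, 1], sub_stage_codes=[0, 1, 2])
for code, name in events:
    print(code, name)
print("forward passes:", n_forward)
```

Sub-stage 2 owns no samples in this batch, yet its BEFORE_PRE_UPDATE and AFTER_POST_UPDATE entries still appear in the log, and the forward-pass counter stays at one.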
Step count semantics: Each sub-stage’s step_count is incremented alongside the FusedStage’s own step_count after every fused step, ensuring that hook frequency (e.g., every_n_steps) is respected correctly across all sub-stages.
- param sub_stages:
Ordered (status_code, dynamics) pairs. Status codes are auto-assigned starting from 0 when using the + operator.
- type sub_stages:
list[tuple[int, BaseDynamics]]
- param entry_status:
Status code assigned to incoming samples (default: 0).
- type entry_status:
int
- param exit_status:
Status code that triggers graduation to the next pipeline stage. When left at the default of -1, it is auto-set to len(sub_stages) (one past the last sub-stage code).
- type exit_status:
int
- param compile_step:
If True, replace self.step with torch.compile(self.step, **compile_kwargs).
- type compile_step:
bool
- param compile_kwargs:
Keyword arguments forwarded to torch.compile.
- type compile_kwargs:
dict
- param **kwargs:
Additional keyword arguments forwarded to BaseDynamics.
- sub_stages
Ordered (status_code, dynamics) pairs.
- Type:
list[tuple[int, BaseDynamics]]
- entry_status
Status code for incoming samples.
- Type:
int
- exit_status
Status code that triggers graduation.
- Type:
int
- compile_step
Whether the step method is compiled.
- Type:
bool
- compile_kwargs
Arguments passed to torch.compile.
- Type:
dict
- __needs_keys__
Union of all sub-stage __needs_keys__ sets. Populated automatically during __init__.
- Type:
set[str]
- __provides_keys__
Union of all sub-stage __provides_keys__ sets. Populated automatically during __init__.
- Type:
set[str]
Examples
>>> from nvalchemi.dynamics import FusedStage, BaseDynamics
>>> # model is assumed to be a previously constructed model instance
>>> dynamics0 = BaseDynamics(model=model)
>>> dynamics1 = BaseDynamics(model=model)
>>> fused = FusedStage(sub_stages=[(0, dynamics0), (1, dynamics1)])
>>> fused.exit_status
2
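The class description says that the + operator builds the same kind of object, with status codes auto-assigned from 0 and exit_status defaulting to len(sub_stages). The sketch below imitates that bookkeeping with toy classes so the numbering is easy to follow. `ToyDynamics` and `ToyFused` are hypothetical stand-ins, not the library's implementation, and giving each appended sub-stage the next consecutive code is an assumption consistent with "auto-assigned starting from 0".

```python
# Toy illustration of "+" composition and status-code numbering.
# ToyDynamics and ToyFused are hypothetical, not nvalchemi classes.
class ToyDynamics:
    def __init__(self, name):
        self.name = name

    def __add__(self, other):
        # dyn_a + dyn_b: two sub-stages with codes 0 and 1.
        return ToyFused([(0, self), (1, other)])


class ToyFused:
    def __init__(self, sub_stages, exit_status=-1):
        self.sub_stages = list(sub_stages)
        # -1 is the "auto" sentinel: one past the last sub-stage code.
        if exit_status == -1:
            exit_status = len(self.sub_stages)
        self.exit_status = exit_status

    def __add__(self, other):
        # fused + dyn_c: append a sub-stage with the next free code
        # (assumed to be consecutive).
        next_code = len(self.sub_stages)
        return ToyFused(self.sub_stages + [(next_code, other)])


relax = ToyDynamics("relax")
equilibrate = ToyDynamics("equilibrate")
production = ToyDynamics("production")

fused = relax + equilibrate + production
codes = [code for code, _ in fused.sub_stages]
print(codes)              # [0, 1, 2]
print(fused.exit_status)  # 3
```

In this toy, every + produces a new composite rather than mutating the left operand; whether the real operator does the same is not stated on this page.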
- __init__(sub_stages, *, entry_status=0, exit_status=-1, compile_step=False, compile_kwargs=None, init_fn=None, **kwargs)
Initialize the fused stage.
- Parameters:
sub_stages (list[tuple[int, BaseDynamics]]) – Ordered (status_code, dynamics) pairs.
entry_status (int, optional) – Status code assigned to incoming samples. Default 0.
exit_status (int, optional) – Status code that triggers graduation. Auto-set to len(sub_stages) if -1. Default -1.
compile_step (bool, optional) – If True, compile the step method with torch.compile. Default False.
compile_kwargs (dict[str, Any] | None, optional) – Keyword arguments for torch.compile. Default None.
init_fn (Callable[[Batch], None] | None, optional) – Optional callback invoked on the initial batch immediately after sampler.build_initial_batch() returns, before the first step. Use this to populate fields that the sampler does not set, such as velocities or forces. Only called in Mode 2 (inflight batching with batch=None). Default None.
**kwargs (Any) – Additional keyword arguments forwarded to BaseDynamics.
- Raises:
ValueError – If sub-stages have different device_type values.
- Return type:
None
Methods
__init__(sub_stages, *[, entry_status, ...]) – Initialize the fused stage.
all_complete(batch, exit_status) – Check if all samples have reached the exit status.
compile(**kwargs) – Compile the fused step with torch.compile.
compute(batch) – Perform the model forward pass to compute forces and energies.
masked_update(batch, mask) – Apply pre_update and post_update only to selected samples in the batch.
post_update(batch) – Perform the second half of the integration step.
pre_update(batch) – Perform the first half of the integration step.
refill_check(batch, exit_status) – Replace graduated samples via index-select and append.
register_bookkeeping_key(key, init_fn) – Register a graph-level bookkeeping field to survive refill_check.
register_fused_hook(hook) – Register a hook that fires at the FusedStage level on the full batch.
register_hook(hook) – Register a hook to be executed at its designated stage(s).
run([batch, n_steps]) – Run the fused stage until all samples converge or the sampler is exhausted.
step(batch) – Execute one fused step: single forward pass + masked updates.
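To make the relationship between run() and step() more concrete, here is a minimal pure-Python sketch of the termination rule from the class description: n_steps is an upper bound, and the loop ends early once every sample carries the exit status. The function `run_until_exit` and the `steps_needed` convergence model are hypothetical, invented for illustration only. The sketch also assumes consecutive status codes and leaves out the sampler-exhaustion condition.

```python
# Toy run loop: samples migrate 0 -> 1 -> ... -> exit_status as they converge.
# run_until_exit and steps_needed are hypothetical, not nvalchemi APIs.
def run_until_exit(status, steps_needed, exit_status, n_steps):
    """Return (final status codes, number of fused steps actually taken).

    steps_needed[i] is how many steps sample i spends in each sub-stage
    before it counts as converged (a made-up convergence model).
    """
    status = list(status)
    progress = [0] * len(status)
    steps_taken = 0

    for _ in range(n_steps):  # n_steps is a maximum, not a fixed count
        if all(s == exit_status for s in status):
            break  # every sample has graduated
        steps_taken += 1
        for i, s in enumerate(status):
            if s == exit_status:
                continue  # graduated samples are no longer updated
            progress[i] += 1
            if progress[i] >= steps_needed[i]:
                # Converged in sub-stage s: hand over to sub-stage s + 1,
                # which picks the sample up on the following step.
                status[i] = s + 1
                progress[i] = 0
    return status, steps_taken


final, steps = run_until_exit([0, 0], [1, 3], exit_status=2, n_steps=100)
print(final, steps)    # [2, 2] 6

capped, capped_steps = run_until_exit([0, 0], [1, 3], exit_status=2, n_steps=4)
print(capped, capped_steps)  # [2, 1] 4
```

In the first call the loop stops after 6 of the 100 allowed steps because both samples have reached the exit status. In the second, the cap of 4 is hit while the slower sample is still in sub-stage 1.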
Attributes
active_batch_has_room – Return whether the active batch can accept more samples.
active_batch_size – Return the number of samples currently in the active batch.
device – Compute the torch device for this rank.
global_rank – Get the global rank for this process.
has_neighbor – Convenient property to see if rank is isolated.
inflight_mode – Return whether inflight batching is enabled.
is_final_stage – Return whether this is the last stage in the pipeline.
is_first_stage – Return whether this is the first stage in the pipeline.
local_rank – Get the node-local rank for this process.
model_is_conservative – Returns whether or not the model uses conservative forces.
room_in_active_batch – Return the number of additional samples the active batch can hold.
stream – Return the active CUDA stream, if any.
- Parameters:
sub_stages (list[tuple[int, BaseDynamics]])
entry_status (int)
exit_status (int)
compile_step (bool)
compile_kwargs (dict[str, Any] | None)
init_fn (Callable[[Batch], None] | None)
kwargs (Any)