nvalchemi.dynamics.FusedStage

class nvalchemi.dynamics.FusedStage(sub_stages, *, entry_status=0, exit_status=-1, compile_step=False, compile_kwargs=None, init_fn=None, **kwargs)

Composite dynamics engine fusing multiple sub-stages on a single GPU.

FusedStage composes multiple BaseDynamics sub-stages to share one Batch and one model forward pass per step, avoiding redundant forward passes when multiple simulation phases (e.g., relaxation then MD) operate on the same batch.

Unlike BaseDynamics, step(batch) is overridden. Instead of the standard pre_update, compute, post_update loop, FusedStage performs: (1) a single compute() call on the full batch, then (2) iterates over sub-stages, applying masked_update(batch, mask) on each sub-stage’s dynamics for samples whose batch.status matches that sub-stage’s status code. Only one forward pass happens per step, regardless of the number of sub-stages.

run(batch) is also overridden. The n_steps attribute (inherited from BaseDynamics) and any n_steps argument passed to run() both set the maximum number of steps; the loop runs until all samples have migrated to the exit_status, the sampler is exhausted, or n_steps is reached.

Convergence-driven migration is handled by ConvergenceHook instances auto-registered between adjacent sub-stages: when samples converge in sub-stage i, their batch.status is updated to sub-stage i+1’s code, causing them to be processed by the next dynamics on the following step.

The + operator composes sub-stages: dyn_a + dyn_b creates a FusedStage, and fused + dyn_c appends a third sub-stage. The | operator (inherited from BaseDynamics via _CommunicationMixin) creates a DistributedPipeline for multi-rank execution instead.
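The step and run semantics described above can be illustrated with a small pure-Python sketch. It does not use nvalchemi: ToyBatch, the harmonic force, the per-stage tolerance, and the tuple layout (status_code, step_size, tolerance) are all hypothetical stand-ins chosen for illustration, and sampler exhaustion is not modelled.

```python
from dataclasses import dataclass, field


@dataclass
class ToyBatch:
    """Hypothetical stand-in for Batch: one coordinate and one status per sample."""

    positions: list
    status: list
    forces: list = field(default_factory=list)


def compute(batch, counters):
    """Shared 'forward pass': harmonic force pulling every sample toward 0."""
    batch.forces = [-x for x in batch.positions]
    counters["forward_passes"] += 1


def masked_update(batch, mask, step_size):
    """Update only the samples selected by the mask."""
    for i, selected in enumerate(mask):
        if selected:
            batch.positions[i] += step_size * batch.forces[i]


def fused_step(batch, sub_stages, exit_status, counters):
    """One fused step: a single compute(), then one masked update per sub-stage."""
    compute(batch, counters)
    # Snapshot the status so a sample advances at most one stage per step.
    old_status = list(batch.status)
    for idx, (code, step_size, tol) in enumerate(sub_stages):
        mask = [s == code for s in old_status]
        masked_update(batch, mask, step_size)
        # Toy convergence rule (assumption): |position| below the stage tolerance.
        is_last = idx + 1 == len(sub_stages)
        next_code = exit_status if is_last else sub_stages[idx + 1][0]
        for i, selected in enumerate(mask):
            if selected and abs(batch.positions[i]) < tol:
                batch.status[i] = next_code


def run(batch, sub_stages, exit_status, n_steps):
    """Step until every sample has reached exit_status or n_steps is hit."""
    counters = {"forward_passes": 0, "steps": 0}
    for _ in range(n_steps):
        if all(s == exit_status for s in batch.status):
            break
        fused_step(batch, sub_stages, exit_status, counters)
        counters["steps"] += 1
    return counters


# (status_code, step_size, tolerance): a coarse stage followed by a fine stage.
stages = [(0, 0.5, 0.1), (1, 0.5, 0.01)]
batch = ToyBatch(positions=[1.0, 0.05], status=[0, 0])
counters = run(batch, stages, exit_status=2, n_steps=100)
print(batch.status, counters)
```

In this sketch the number of forward passes always equals the number of fused steps, however many sub-stages are listed, and n_steps acts only as an upper bound on the loop.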

Developers generally do NOT subclass FusedStage. Instead, create BaseDynamics subclasses (integrators) and compose them using +. FusedStage handles orchestration automatically. The key requirement is that sub-stage dynamics must implement masked_update correctly (inherited from BaseDynamics) and that the batch must have a status tensor.
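As a rough illustration of how + composition and automatic status-code assignment could work, the following toy classes mimic the behaviour described in this section. ToyDynamics and ToyFused are hypothetical names and are not the library's implementation; only the documented outcomes (codes starting at 0, exit_status one past the last code) are taken from the text.

```python
class ToyFused:
    """Hypothetical composite holding ordered (status_code, dynamics) pairs."""

    def __init__(self, sub_stages, exit_status=-1):
        self.sub_stages = list(sub_stages)
        # -1 is a sentinel meaning "one past the last sub-stage code".
        self.exit_status = len(self.sub_stages) if exit_status == -1 else exit_status

    def __add__(self, other):
        # fused + dyn_c appends a sub-stage with the next free status code.
        next_code = len(self.sub_stages)
        return ToyFused(self.sub_stages + [(next_code, other)])


class ToyDynamics:
    """Hypothetical integrator stand-in; only the name matters here."""

    def __init__(self, name):
        self.name = name

    def __add__(self, other):
        # dyn_a + dyn_b creates a two-stage composite with codes 0 and 1.
        return ToyFused([(0, self), (1, other)])


relax, nvt, nve = ToyDynamics("relax"), ToyDynamics("nvt"), ToyDynamics("nve")
fused = relax + nvt + nve
codes = [code for code, _ in fused.sub_stages]
names = [dyn.name for _, dyn in fused.sub_stages]
print(codes, names, fused.exit_status)
```

Because Python evaluates relax + nvt + nve left to right, the first + builds the composite and the second appends to it, which matches the dyn_a + dyn_b and fused + dyn_c cases described above.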

Hook Firing Semantics

Because FusedStage shares a single forward pass across all sub-stages, hook firing differs from standalone BaseDynamics execution. The following hooks fire on each sub-stage during _step_impl:

Fired on sub-stages (in order):

  • BEFORE_STEP — at the start of each fused step, before any work.

  • AFTER_COMPUTE — after the shared model forward pass completes.

  • BEFORE_PRE_UPDATE — before each sub-stage’s masked_update (fires even when no samples match the sub-stage’s status code).

  • AFTER_POST_UPDATE — after each sub-stage’s masked_update (fires even when no samples match the sub-stage’s status code).

  • AFTER_STEP — after all masked updates are complete.

  • ON_CONVERGE — when a sub-stage’s _check_convergence detects converged samples.

NOT fired on sub-stages:

  • BEFORE_COMPUTE — the forward pass is shared across all sub-stages, not executed per-sub-stage; there is no meaningful “before compute” point for individual sub-stages.

  • AFTER_PRE_UPDATE — masked_update combines pre_update and post_update atomically; there is no intermediate hook point.

  • BEFORE_POST_UPDATE — same reason as AFTER_PRE_UPDATE.
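The two lists above can be summarised as an event trace for one fused step. The sketch below is a toy recorder, not nvalchemi code: the function name, the sub-stage names, and the exact interleaving across sub-stages (for example, whether BEFORE_STEP fires on every sub-stage before compute) are assumptions made for illustration. ON_CONVERGE is omitted because it depends on convergence state.

```python
SKIPPED_POINTS = ("BEFORE_COMPUTE", "AFTER_PRE_UPDATE", "BEFORE_POST_UPDATE")


def fused_step_events(sub_stage_names):
    """Return the (hook_point, sub_stage) events of one hypothetical fused step."""
    events = []

    def fire(point, names):
        events.extend((point, name) for name in names)

    fire("BEFORE_STEP", sub_stage_names)
    events.append(("compute", "shared"))  # the single forward pass
    fire("AFTER_COMPUTE", sub_stage_names)
    for name in sub_stage_names:
        # These two fire even if no sample currently carries this stage's code.
        fire("BEFORE_PRE_UPDATE", [name])
        events.append(("masked_update", name))
        fire("AFTER_POST_UPDATE", [name])
    fire("AFTER_STEP", sub_stage_names)
    return events


trace = fused_step_events(["relax", "md"])
for point, owner in trace:
    print(f"{point:<18} {owner}")
```

The trace contains exactly one compute entry and none of the three skipped hook points, which is the practical difference from running each sub-stage as a standalone BaseDynamics.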

Step count semantics: Each sub-stage’s step_count is incremented alongside the FusedStage’s own step_count after every fused step, ensuring that hook frequency (e.g., every_n_steps) is respected correctly across all sub-stages.
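A minimal sketch of why the synchronized counters matter follows. ToyStage and the rule "fire when step_count is a multiple of every_n_steps" are assumptions for illustration; the library's actual frequency check may differ.

```python
class ToyStage:
    """Hypothetical sub-stage with one periodic hook."""

    def __init__(self, name, every_n_steps):
        self.name = name
        self.every_n_steps = every_n_steps
        self.step_count = 0
        self.fired_at = []

    def maybe_fire(self):
        # Assumed rule: the hook fires when step_count is a multiple of every_n_steps.
        if self.step_count % self.every_n_steps == 0:
            self.fired_at.append(self.step_count)


toy_stages = [ToyStage("relax", every_n_steps=2), ToyStage("md", every_n_steps=3)]
fused_step_count = 0

for _ in range(6):
    for stage in toy_stages:
        stage.maybe_fire()
    # After every fused step, the composite counter and all sub-stage
    # counters advance together, whether or not a sub-stage had samples.
    fused_step_count += 1
    for stage in toy_stages:
        stage.step_count += 1

for stage in toy_stages:
    print(stage.name, stage.step_count, stage.fired_at)
```

If a sub-stage counter advanced only when that sub-stage had matching samples, its periodic hooks would drift relative to the fused step count; keeping the counters in lockstep avoids that.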

param sub_stages:

Ordered (status_code, dynamics) pairs. Status codes are auto-assigned starting from 0 when using the + operator.

type sub_stages:

list[tuple[int, BaseDynamics]]

param entry_status:

Status code assigned to incoming samples (default: 0).

type entry_status:

int

param exit_status:

Status code that triggers graduation to the next pipeline stage. Defaults to -1, in which case it is auto-set to len(sub_stages) (one past the last sub-stage code).

type exit_status:

int

param compile_step:

If True, replace self.step with torch.compile(self.step, **compile_kwargs).

type compile_step:

bool

param compile_kwargs:

Keyword arguments forwarded to torch.compile.

type compile_kwargs:

dict

param **kwargs:

Additional keyword arguments forwarded to BaseDynamics.

sub_stages

Ordered (status_code, dynamics) pairs.

Type:

list[tuple[int, BaseDynamics]]

entry_status

Status code for incoming samples.

Type:

int

exit_status

Status code that triggers graduation.

Type:

int

compile_step

Whether the step method is compiled.

Type:

bool

compile_kwargs

Arguments passed to torch.compile.

Type:

dict

__needs_keys__

Union of all sub-stage __needs_keys__ sets. Populated automatically during __init__.

Type:

set[str]

__provides_keys__

Union of all sub-stage __provides_keys__ sets. Populated automatically during __init__.

Type:

set[str]
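The key aggregation described for __needs_keys__ and __provides_keys__ amounts to a set union over sub-stages. The sketch below uses a hypothetical ToyDynamics class with plain attributes and made-up key names; it only illustrates the union and is not the library's code.

```python
class ToyDynamics:
    """Hypothetical sub-stage declaring the batch keys it reads and writes."""

    def __init__(self, needs, provides):
        self.needs_keys = set(needs)
        self.provides_keys = set(provides)


def union_keys(sub_stages):
    """Collect the union of needs/provides over (status_code, dynamics) pairs."""
    needs, provides = set(), set()
    for _code, dynamics in sub_stages:
        needs |= dynamics.needs_keys
        provides |= dynamics.provides_keys
    return needs, provides


# Key names below are illustrative assumptions.
relax = ToyDynamics(needs={"positions", "forces"}, provides={"positions"})
md = ToyDynamics(
    needs={"positions", "forces", "velocities"},
    provides={"positions", "velocities"},
)
needs, provides = union_keys([(0, relax), (1, md)])
print(sorted(needs), sorted(provides))
```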

Examples

>>> from nvalchemi.dynamics import FusedStage, BaseDynamics
>>> # `model` is assumed to be an already-constructed model accepted by BaseDynamics.
>>> dynamics0 = BaseDynamics(model=model)
>>> dynamics1 = BaseDynamics(model=model)
>>> fused = FusedStage(sub_stages=[(0, dynamics0), (1, dynamics1)])
>>> fused.exit_status  # auto-set to len(sub_stages) because exit_status defaults to -1
2

__init__(sub_stages, *, entry_status=0, exit_status=-1, compile_step=False, compile_kwargs=None, init_fn=None, **kwargs)

Initialize the fused stage.

Parameters:
  • sub_stages (list[tuple[int, BaseDynamics]]) – Ordered (status_code, dynamics) pairs.

  • entry_status (int, optional) – Status code assigned to incoming samples. Default 0.

  • exit_status (int, optional) – Status code that triggers graduation. Auto-set to len(sub_stages) if -1. Default -1.

  • compile_step (bool, optional) – If True, compile the step method with torch.compile. Default False.

  • compile_kwargs (dict[str, Any] | None, optional) – Keyword arguments for torch.compile. Default None.

  • init_fn (Callable[[Batch], None] | None, optional) – Optional callback invoked on the initial batch immediately after sampler.build_initial_batch() returns, before the first step. Use this to populate fields that the sampler does not set, such as velocities or forces. Only called in Mode 2 (inflight batching with batch=None). Default None.

  • **kwargs (Any) – Additional keyword arguments forwarded to BaseDynamics.
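As an illustration of the init_fn parameter, the following sketch shows a callback with the documented shape Callable[[Batch], None]. The batch here is a SimpleNamespace stand-in, and the positions and velocities attribute names are assumptions rather than confirmed Batch fields.

```python
from types import SimpleNamespace


def zero_velocities(batch) -> None:
    """Hypothetical init_fn: fill in velocities that the sampler did not set."""
    if getattr(batch, "velocities", None) is None:
        # One zero 3-vector per position; mutates the batch in place, returns None.
        batch.velocities = [[0.0, 0.0, 0.0] for _ in batch.positions]


# Stand-in for the batch returned by sampler.build_initial_batch().
toy_batch = SimpleNamespace(
    positions=[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]],
    velocities=None,
)
result = zero_velocities(toy_batch)
print(toy_batch.velocities)
```

A callback like this would be passed as init_fn=zero_velocities when constructing the stage; per the parameter description, it is only invoked in Mode 2 (inflight batching with batch=None).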

Raises:

ValueError – If sub-stages have different device_type values.

Return type:

None

Methods

__init__(sub_stages, *[, entry_status, ...])

Initialize the fused stage.

all_complete(batch, exit_status)

Check if all samples have reached the exit status.

compile(**kwargs)

Compile the fused step with torch.compile.

compute(batch)

Perform the model forward pass to compute forces and energies.

masked_update(batch, mask)

Apply pre_update and post_update only to selected samples in the batch.

post_update(batch)

Perform the second half of the integration step.

pre_update(batch)

Perform the first half of the integration step.

refill_check(batch, exit_status)

Replace graduated samples via index-select and append.

register_bookkeeping_key(key, init_fn)

Register a graph-level bookkeeping field to survive refill_check.

register_fused_hook(hook)

Register a hook that fires at the FusedStage level on the full batch.

register_hook(hook)

Register a hook to be executed at its designated stage(s).

run([batch, n_steps])

Run the fused stage until all samples converge or the sampler is exhausted.

step(batch)

Execute one fused step: single forward pass + masked updates.

Attributes

active_batch_has_room

Return whether the active batch can accept more samples.

active_batch_size

Return the number of samples currently in the active batch.

device

Compute the torch device for this rank.

global_rank

Get the global rank for this process.

has_neighbor

Convenience property for checking whether the rank is isolated.

inflight_mode

Return whether inflight batching is enabled.

is_final_stage

Return whether this is the last stage in the pipeline.

is_first_stage

Return whether this is the first stage in the pipeline.

local_rank

Get the node-local rank for this process.

model_is_conservative

Return whether the model uses conservative forces.

room_in_active_batch

Return the number of additional samples the active batch can hold.

stream

Return the active CUDA stream, if any.
