# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pipeline-based model composition.
:class:`PipelineModelWrapper` organizes models into **groups**, where each
group is a mini-pipeline with its own derivative computation strategy.
The top level sums outputs across groups.
Composition is available via the ``+`` operator for simple additive sums,
or via explicit ``PipelineModelWrapper`` construction for dependent
pipelines and custom derivative computation.
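
A minimal additive sum (sketch; ``model_a`` and ``model_b`` stand in for
any two compatible model instances)::

    combined = model_a + model_b  # simple additive composition
    outputs = combined(batch)     # per-key sums of the two models' outputs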

Motivating example (AIMNet2 + Ewald + DFTD3)::

    pipe = PipelineModelWrapper(groups=[
        PipelineGroup(
            steps=[
                aimnet2,
                ewald,
            ],
            use_autograd=True,
        ),
        PipelineGroup(steps=[dftd3]),
    ])

See the class docstrings below or the proposal for full composition
examples.
"""
from __future__ import annotations

from collections import OrderedDict
from collections.abc import Callable
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

import torch
from torch import nn

from nvalchemi._typing import Energy, LatticeVectors, ModelOutputs
from nvalchemi.data import AtomicData, Batch
from nvalchemi.hooks import NeighborListHook
from nvalchemi.models._ops.neighbor_filter import prepare_neighbors_for_model
from nvalchemi.models._utils import (
    autograd_forces,
    autograd_stresses,
    prepare_strain,
    sum_outputs,
)
from nvalchemi.models.base import (
    BaseModelMixin,
    ModelConfig,
    NeighborConfig,
    NeighborListFormat,
)
__all__ = ["PipelineModelWrapper", "PipelineStep", "PipelineGroup"]
# Sentinel for "attribute was not present on the object".
_MISSING = object()
# All neighbor-related attributes that may need saving/restoring when the
# pipeline temporarily adapts neighbor data for a step.
_NEIGHBOR_ATTRS = (
"neighbor_matrix",
"num_neighbors",
"neighbor_matrix_shifts",
"neighbor_list",
"edge_ptr",
"neighbor_list_shifts",
"_neighbor_list_cutoff",
)
# Type alias for the user-provided derivative function.
DerivativeFn = Callable[
[Energy, Batch, set[str]], # (energy, data, requested_keys)
dict[str, torch.Tensor], # computed derivatives
]
@dataclass(eq=False)
class PipelineStep:
"""Wraps a model with an output rename mapping.
Only needed when a model's output key doesn't match the downstream
input key. For models that don't need renaming, pass the bare model
directly — the pipeline normalizes it internally.
Parameters
----------
model : BaseModelMixin
The model to wrap.
wire : dict[str, str]
Output-to-attribute rename mapping. Each entry
``{output_key: data_attribute}`` causes the pipeline to write the
model's ``output_key`` value onto ``data.data_attribute`` before
downstream models execute. Downstream models that declare
``data_attribute`` in their ``required_inputs`` will then receive
it automatically.
Examples
--------
AIMNet2 produces ``"charges"`` (per-atom partial charges), but the
Ewald model expects ``"node_charges"`` as a required input::
PipelineStep(aimnet2, wire={"charges": "node_charges"})
After AIMNet2 runs, the pipeline writes its ``"charges"`` output
onto ``data.node_charges``. When Ewald runs next, its
``adapt_input()`` finds ``data.node_charges`` and uses it.
If a model's output keys already match downstream input keys, no
wire mapping is needed — pass the bare model::
PipelineGroup(steps=[model_a, model_b]) # auto-wired
"""
model: BaseModelMixin
wire: dict[str, str] = field(default_factory=dict)
@dataclass
class PipelineGroup:
"""A group of steps that share a derivative computation strategy.
Steps within a group execute in order (for wiring).
Groups execute in declaration order.
``steps`` accepts bare :class:`BaseModelMixin` instances or
:class:`PipelineStep` wrappers. Bare models are normalized to
``PipelineStep(model, wire={})`` internally.
Parameters
----------
steps : list[BaseModelMixin | PipelineStep]
Ordered list of models (or wrapped models) in this group.
use_autograd : bool
If ``True``, sub-models produce energies only; the group sums
them and calls ``derivative_fn`` to compute forces, stresses,
and any other requested derivatives from the summed energy.
If ``False`` (default), each sub-model computes its own outputs
and the group sums them directly.
derivative_fn : DerivativeFn | None
Custom derivative function called after energy summation in
autograd groups. Receives ``(energy, data, requested)`` where
``energy`` is the summed group energy (on the autograd graph),
``data`` is the batch (with ``positions.requires_grad=True``),
and ``requested`` is the set of output keys that still need to
be computed (e.g. ``{"forces", "stress"}``).
When ``None`` (default), the pipeline uses a built-in function
that computes forces as ``-dE/dr`` and stresses via the affine
strain trick (see :func:`~nvalchemi.models._utils.prepare_strain`).
Only meaningful when ``use_autograd=True``.
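
    Examples
    --------
    A sketch of a custom ``derivative_fn`` that computes forces only
    (``forces_only`` and ``some_model`` are hypothetical names; the body
    uses plain ``torch.autograd.grad``, not a pipeline-specific API)::

        def forces_only(energy, data, requested):
            derivs = {}
            if "forces" in requested:
                # positions has requires_grad=True inside autograd groups
                (grad,) = torch.autograd.grad(energy.sum(), data.positions)
                derivs["forces"] = -grad
            return derivs

        group = PipelineGroup(
            steps=[some_model],
            use_autograd=True,
            derivative_fn=forces_only,
        )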
"""
steps: list[BaseModelMixin | PipelineStep]
use_autograd: bool = False
derivative_fn: DerivativeFn | None = None
class PipelineModelWrapper(nn.Module, BaseModelMixin):
"""Compose multiple models via a grouped pipeline.
Models are organized into :class:`PipelineGroup` instances, where each
group has a derivative computation strategy. Within a group, steps
execute in order so that upstream outputs can wire into downstream
inputs. The pipeline sums outputs across groups using
:func:`~nvalchemi.models._utils.sum_outputs`.
The pipeline's default ``model_config.active_outputs`` is synthesized as the
**union of all sub-model** ``model_config.active_outputs`` **sets** at
construction time, so it honestly reflects what the sub-models are
configured to produce. The user can then expand or narrow it.
Parameters
----------
groups : list[PipelineGroup]
Ordered list of groups. Groups execute in declaration order.
additive_keys : set[str], optional
Keys whose values are summed across groups. Defaults to
``{"energy", "forces", "stress"}``.
Attributes
----------
model_config : ModelConfig
Mutable configuration controlling what the pipeline computes.
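
    Examples
    --------
    A construction sketch mirroring the module-level example
    (``aimnet2``, ``ewald``, ``dftd3``, and ``batch`` are placeholders)::

        pipe = PipelineModelWrapper(
            groups=[
                PipelineGroup(steps=[aimnet2, ewald], use_autograd=True),
                PipelineGroup(steps=[dftd3]),
            ],
        )
        outputs = pipe(batch)  # dict with "energy", "forces", ... (detached)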
"""
def __init__(
self,
groups: list[PipelineGroup],
additive_keys: set[str] | None = None,
) -> None:
super().__init__()
# Normalize bare models to PipelineStep(model, wire={})
self.groups: list[PipelineGroup] = []
for group in groups:
normalized: list[PipelineStep] = []
for step in group.steps:
if isinstance(step, PipelineStep):
normalized.append(step)
else:
normalized.append(PipelineStep(model=step))
self.groups.append(
PipelineGroup(
steps=normalized,
use_autograd=group.use_autograd,
derivative_fn=group.derivative_fn,
)
)
self._models = nn.ModuleList(
s.model
for g in self.groups
for s in g.steps # type: ignore[misc]
)
self.additive_keys = additive_keys or {"energy", "forces", "stress"}
# Check wiring and collect inputs that must come from the batch.
batch_required = self._check_wiring()
# Synthesize a unified ModelConfig from all sub-models.
self.model_config = self._build_model_config(batch_required)
self._configure_sub_models()
# ------------------------------------------------------------------
# ModelConfig synthesis
# ------------------------------------------------------------------
def _build_model_config(
self, batch_required: set[str] | None = None
) -> ModelConfig:
"""Synthesize a unified :class:`ModelConfig` from all sub-model configs.
Merges capability and runtime fields across every sub-model in every
group to produce a single config that honestly represents the full
pipeline.
Parameters
----------
batch_required : set[str] | None
Required inputs that must come from the batch (not produced
by any step in the pipeline). These are added to the
pipeline's ``required_inputs``.
Synthesis rules:
- **outputs**: union of all sub-model ``outputs``. For autograd
groups, ``"forces"`` and ``"stress"`` are added because the
group can derive them from the summed energy.
- **autograd_outputs**: union of per-model ``autograd_outputs`` for
direct groups; ``{"forces", "stress"}`` for autograd groups.
- **required_inputs**: union of all sub-model ``required_inputs``.
- **active_outputs**: union of all sub-model ``active_outputs``.
- **supports_pbc**: ``True`` only if *every* sub-model supports PBC.
- **needs_pbc**: ``True`` if *any* sub-model needs PBC.
- **neighbor_config**: synthesized at the **maximum cutoff** across
all sub-models. Uses ``MATRIX`` format if any sub-model requires
it, ``COO`` otherwise.
All sub-models must agree on ``half_list``.
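
        For example (illustrative values, not from the source): merging a
        5.0 Å ``COO`` sub-model with a 12.0 Å ``MATRIX`` sub-model yields::

            NeighborConfig(
                cutoff=12.0,                       # max cutoff wins
                format=NeighborListFormat.MATRIX,  # MATRIX if any model needs it
                half_list=...,                     # shared value, validated below
                skin=...,                          # max skin, or 0.0 if none set
            )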
"""
all_outputs: set[str] = set()
all_inputs: set[str] = set()
all_autograd_outputs: set[str] = set()
default_active: set[str] = set()
needs_pbc = False
supports_pbc = True
sub_neighbor_configs: list[NeighborConfig] = []
for group in self.groups:
for step in group.steps:
cfg = step.model.model_config
all_outputs |= cfg.outputs
all_inputs |= cfg.required_inputs
default_active |= cfg.active_outputs
if group.use_autograd:
# Group-level autograd can produce forces/stresses
# from the summed energy — add them to outputs.
all_outputs |= {"forces", "stress"}
all_autograd_outputs |= {"forces", "stress"}
else:
all_autograd_outputs |= cfg.autograd_outputs
if cfg.needs_pbc:
needs_pbc = True
if not cfg.supports_pbc:
supports_pbc = False
if cfg.neighbor_config is not None:
sub_neighbor_configs.append(cfg.neighbor_config)
# Synthesize neighbor_config at max cutoff
neighbor_config: NeighborConfig | None = None
if sub_neighbor_configs:
for nc in sub_neighbor_configs:
if nc.half_list != sub_neighbor_configs[0].half_list:
raise ValueError(
"PipelineModelWrapper: sub-models have different half_list "
f"values ({nc.half_list} vs {sub_neighbor_configs[0].half_list}). "
"All sub-models must use the same half_list value."
)
max_cutoff = max(nc.cutoff for nc in sub_neighbor_configs)
has_matrix = any(
nc.format == NeighborListFormat.MATRIX for nc in sub_neighbor_configs
)
chosen_format = (
NeighborListFormat.MATRIX if has_matrix else NeighborListFormat.COO
)
skin_vals = [nc.skin for nc in sub_neighbor_configs if nc.skin is not None]
neighbor_config = NeighborConfig(
cutoff=max_cutoff,
format=chosen_format,
half_list=sub_neighbor_configs[0].half_list,
skin=max(skin_vals) if skin_vals else 0.0,
)
return ModelConfig(
outputs=frozenset(all_outputs),
autograd_outputs=frozenset(all_autograd_outputs),
required_inputs=frozenset(all_inputs | (batch_required or set())),
supports_pbc=supports_pbc,
needs_pbc=needs_pbc,
neighbor_config=neighbor_config,
active_outputs=default_active,
)
@property
def embedding_shapes(self) -> dict[str, tuple[int, ...]]:
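        """Pipeline models expose no unified embeddings; always empty."""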
return {}
def compute_embeddings(
self, data: AtomicData | Batch, **kwargs: Any
) -> AtomicData | Batch:
"""Compute embeddings is not meaningful for pipeline models.
Call compute_embeddings on individual sub-models instead."""
raise NotImplementedError(
"PipelineModelWrapper does not produce unified embeddings. "
"Call compute_embeddings on individual sub-models instead."
)
def export_model(self, path: Path, as_state_dict: bool = False) -> None:
"""Export model is not implemented for pipeline models.
Export individual sub-models instead."""
raise NotImplementedError(
"PipelineModelWrapper does not support direct export. "
"Export individual sub-models instead."
)
# ------------------------------------------------------------------
# Validation and configuration
# ------------------------------------------------------------------
def _check_wiring(self) -> set[str]:
"""Verify that the pipeline's data flow graph is satisfiable.
Walks through all groups and steps in declaration order,
accumulating the set of output keys (after wire renaming) that
each step produces. Inputs that are not produced by any prior
step become **required inputs of the pipeline** — they must be
present on the input batch at runtime.
Returns
-------
set[str]
Required inputs that must come from the batch (not produced
by any step in the pipeline).
"""
# Fields always present on a Batch — no need to wire these.
batch_fields = {
"positions",
"atomic_numbers",
"atomic_masses",
"cell",
"pbc",
"energy",
"forces",
}
available: set[str] = set(batch_fields)
batch_required: set[str] = set()
for group in self.groups:
for step in group.steps:
cfg = step.model.model_config
# Inputs not produced by prior steps must come from
# the batch — propagate them as pipeline required_inputs.
missing = set(cfg.required_inputs) - available
batch_required |= missing
# Build the effective output names (after wire renaming)
renamed_outputs: set[str] = set()
for out_key in cfg.outputs:
if out_key in step.wire:
renamed_outputs.add(step.wire[out_key])
else:
renamed_outputs.add(out_key)
available |= renamed_outputs
return batch_required
def _configure_sub_models(self) -> None:
"""Compute per-step active_output and neighbor overrides.
For autograd groups the pipeline handles forces/stress via autograd,
so sub-models should only produce energy. Rather than permanently
mutating the sub-model's ``model_config`` (which would break reuse
of the same model instance in other pipelines or standalone), we
store the overrides in ``_step_active_overrides`` and apply them
temporarily during the forward pass.
For neighbor adaptation, the pipeline's unified neighbor config may
have a larger cutoff or different format than an individual step's
model. Steps that need adaptation are flagged in
``_step_needs_neighbor_adapt`` and handled in :meth:`_call_step`.
"""
self._step_active_overrides: dict[int, set[str]] = {}
self._step_needs_neighbor_adapt: dict[int, bool] = {}
pipeline_nc = self.model_config.neighbor_config
for group in self.groups:
if group.use_autograd:
for step in group.steps:
new_active = set(step.model.model_config.active_outputs)
# Strip derivatives that the pipeline computes via
# autograd, but keep keys the model produces
# analytically (e.g. Ewald/PME with hybrid_forces=True
# returns detached kernel forces and virial).
direct = step.model.direct_derivative_keys()
new_active -= {"forces", "stress"} - direct
self._step_active_overrides[id(step)] = new_active
for step in group.steps:
step_nc = step.model.model_config.neighbor_config
if pipeline_nc is None or step_nc is None:
self._step_needs_neighbor_adapt[id(step)] = False
else:
needs = (
step_nc.format != pipeline_nc.format
or (pipeline_nc.cutoff - step_nc.cutoff) > 1e-6
)
self._step_needs_neighbor_adapt[id(step)] = needs
def _call_step(
self,
step: PipelineStep,
data: AtomicData | Batch,
**kwargs: Any,
) -> ModelOutputs:
"""Call a step's model, temporarily applying overrides.
Two kinds of temporary overrides are applied and restored:
1. **active_outputs** — for autograd groups, sub-models skip
forces/stress (computed by the group after energy summation).
2. **neighbor data** — when the pipeline's unified neighbor config
differs from the step's model (larger cutoff or different
format), the batch's neighbor tensors are swapped to the
model-specific version for the duration of the call.
"""
override = self._step_active_overrides.get(id(step))
needs_neighbor_adapt = self._step_needs_neighbor_adapt.get(id(step), False)
saved_neighbors: dict[str, Any] | None = None
saved_active: set[str] | None = None
if needs_neighbor_adapt:
saved_neighbors = self._adapt_step_neighbors(step, data)
if override is not None:
cfg = step.model.model_config
saved_active = cfg.active_outputs
cfg.active_outputs = override
try:
return step.model(data, **kwargs)
finally:
if saved_active is not None:
step.model.model_config.active_outputs = saved_active
if saved_neighbors is not None:
self._restore_step_neighbors(data, saved_neighbors)
# ------------------------------------------------------------------
# Neighbor adaptation
# ------------------------------------------------------------------
def _adapt_step_neighbors(
self,
step: PipelineStep,
data: Batch,
) -> dict[str, Any]:
"""Filter/convert neighbor data on *data* for this step's model.
Uses :func:`prepare_neighbors_for_model` to produce neighbor
tensors matching the step's ``neighbor_config`` (cutoff + format),
then writes them onto *data* so the model's ``adapt_input`` sees
the correct data without needing to call the conversion itself.
Returns the saved attribute values for :meth:`_restore_step_neighbors`.
"""
nc = step.model.model_config.neighbor_config
# nc is guaranteed non-None by _step_needs_neighbor_adapt guard.
if nc is None:
raise ValueError(
f"PipelineModelWrapper: step {step} has no neighbor config"
)
adapted = prepare_neighbors_for_model(
data, nc.cutoff, nc.format, data.num_nodes
)
# Trim MATRIX K-dimension to actual max neighbors.
        if nc.format == NeighborListFormat.MATRIX and "neighbor_matrix" in adapted:
            num_nb = adapted["num_neighbors"]  # local name avoids shadowing torch.nn
            max_k = int(num_nb.max()) if num_nb.numel() > 0 else 0
            adapted["neighbor_matrix"] = adapted["neighbor_matrix"][
                :, :max_k
            ].contiguous()
            shifts = adapted.get("neighbor_matrix_shifts")
            if shifts is not None:
                adapted["neighbor_matrix_shifts"] = shifts[:, :max_k].contiguous()
# Save current values (check __dict__ to distinguish
# instance-level attrs from group-stored Batch properties).
saved: dict[str, Any] = {}
for attr in _NEIGHBOR_ATTRS:
if attr in data.__dict__:
saved[attr] = data.__dict__[attr]
else:
saved[attr] = _MISSING
# Write adapted tensors onto data.
for key, value in adapted.items():
data.__dict__[key] = value
# Stamp cutoff so any residual prepare_neighbors_for_model calls
# inside the model are no-ops.
data.__dict__["_neighbor_list_cutoff"] = nc.cutoff
return saved
@staticmethod
def _restore_step_neighbors(
data: Batch,
saved: dict[str, Any],
) -> None:
"""Restore neighbor data on *data* from *saved* state."""
for attr, value in saved.items():
if value is _MISSING:
# Attribute wasn't in __dict__ before — remove the shadow
# so the original group-stored value becomes visible again.
data.__dict__.pop(attr, None)
else:
data.__dict__[attr] = value
# ------------------------------------------------------------------
# Wiring
# ------------------------------------------------------------------
def _resolve_inputs(
self,
step: PipelineStep,
context: dict[PipelineStep, ModelOutputs],
data: Batch | AtomicData,
) -> None:
"""Write resolved upstream outputs onto *data* for this step's model.
For each input the model needs, check if an upstream model produced
it (via *context*). Applies wire renaming. Only writes to *data*
what this step actually needs — *data* is not polluted with all
intermediate tensors.
"""
needed = step.model.model_config.required_inputs
for ctx_step, ctx_out in context.items():
card = ctx_step.model.model_config
for out_key in card.outputs:
value = ctx_out.get(out_key)
if value is None:
continue
data_attr = ctx_step.wire.get(out_key, out_key)
if data_attr in needed:
# Use object.__setattr__ for wired intermediate
# values (e.g. charges [N]) that may not match the
# Batch system-group length validation.
object.__setattr__(data, data_attr, value)
# ------------------------------------------------------------------
# Neighbor hook factory
# ------------------------------------------------------------------
def make_neighbor_hooks(
self, max_neighbors: int | None = None
) -> list[NeighborListHook]:
"""Return a single :class:`NeighborListHook` for the composite neighbor config.
Parameters
----------
max_neighbors : int | None, optional
Maximum neighbors per atom for MATRIX format. When ``None``
(default), auto-estimated from the cutoff at first use.
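
        A minimal usage sketch (the hook runs at the
        ``DynamicsStage.BEFORE_COMPUTE`` stage)::

            hooks = pipe.make_neighbor_hooks(max_neighbors=64)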
"""
from nvalchemi.dynamics.base import DynamicsStage # noqa: PLC0415
nc = self.model_config.neighbor_config
if nc is None:
return []
return [
NeighborListHook(
nc,
skin=nc.skin,
max_neighbors=max_neighbors,
stage=DynamicsStage.BEFORE_COMPUTE,
)
]
# ------------------------------------------------------------------
# Forward pass
# ------------------------------------------------------------------
def forward(self, data: AtomicData | Batch, **kwargs: Any) -> ModelOutputs:
"""Run all sub-models and accumulate outputs.
For groups with ``use_autograd=True``, sub-models produce energies
only. The group sums them and calls the derivative function
(default or user-provided) to compute forces, stresses, and any
other requested derivatives from the summed energy.
What gets computed is driven by ``self.model_config.active_outputs``.
Parameters
----------
data : AtomicData | Batch
Input batch.
Returns
-------
ModelOutputs
Combined outputs across all groups.
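
        For example (sketch; ``pipe`` and ``batch`` assumed to exist)::

            outputs = pipe(batch)
            energy = outputs["energy"]      # summed across groups
            forces = outputs.get("forces")  # detached tensor, if requested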
"""
# Determine what derivatives are requested beyond energies.
if isinstance(data, AtomicData):
data = Batch.from_data_list([data])
requested_derivatives = self.model_config.active_outputs - {"energy"}
# Collect all autograd_inputs that need requires_grad
grad_keys: set[str] = set()
for group in self.groups:
if group.use_autograd:
for step in group.steps:
grad_keys |= step.model.model_config.autograd_inputs
else:
                for step in group.steps:
                    card = step.model.model_config
                    if card.autograd_outputs & card.active_outputs:
                        grad_keys |= card.autograd_inputs
# Forward context: tracks each step's outputs without
# polluting data with all intermediate tensors.
context: dict[PipelineStep, ModelOutputs] = {}
autograd_groups = [g for g in self.groups if g.use_autograd]
group_outputs: list[ModelOutputs] = []
autograd_count = len(autograd_groups)
autograd_idx = 0
for group in self.groups:
if group.use_autograd:
group_out = self._run_autograd_group(
group,
data,
context,
requested_derivatives,
autograd_idx,
autograd_count,
grad_keys,
**kwargs,
)
autograd_idx += 1
else:
group_out = self._run_direct_group(
group,
data,
context,
**kwargs,
)
group_outputs.append(group_out)
result = sum_outputs(*group_outputs, additive_keys=self.additive_keys)
# Detach all tensors from the computation graph.
detached: ModelOutputs = OrderedDict()
for key, value in result.items():
if isinstance(value, torch.Tensor):
detached[key] = value.detach()
else:
detached[key] = value
return detached
def _run_direct_group(
self,
group: PipelineGroup,
data: AtomicData | Batch,
context: dict[PipelineStep, ModelOutputs],
**kwargs: Any,
) -> ModelOutputs:
"""Run a direct group: each model computes its own outputs, summed."""
step_outputs: list[ModelOutputs] = []
for step in group.steps:
self._resolve_inputs(step, context, data)
out = self._call_step(step, data, **kwargs)
step_outputs.append(out)
context[step] = out
return sum_outputs(*step_outputs, additive_keys=self.additive_keys)
def _run_autograd_group(
self,
group: PipelineGroup,
data: AtomicData | Batch,
context: dict[PipelineStep, ModelOutputs],
requested_derivatives: set[str],
autograd_idx: int,
autograd_count: int,
grad_keys: set[str],
**kwargs: Any,
) -> ModelOutputs:
"""Run an autograd group: sum energies, then compute derivatives.
When ``derivative_fn`` is ``None``, the pipeline uses the default
derivative computation (forces + stresses via affine strain).
When ``derivative_fn`` is provided, the user's function receives
the summed energy, the batch, and the set of requested keys.
"""
use_default_derivs = group.derivative_fn is None
need_stresses = (
use_default_derivs
and "stress" in requested_derivatives
and isinstance(data, Batch)
and hasattr(data, "cell")
and data.cell is not None
)
# Enable requires_grad on positions for force computation.
# We detach + clone first to ensure a fresh leaf tensor. Without
# this, positions from a previous step may still carry graph
# references (e.g. from in-place updates by the integrator),
# causing "backward through the graph a second time" errors.
# NOTE: This must happen BEFORE strain preparation so that
# prepare_strain can build a graph through the fresh leaves.
        for key in grad_keys:
            tensor = getattr(data, key, None)
            if isinstance(tensor, torch.Tensor):
                fresh = tensor.detach().clone().requires_grad_(True)
                data[key] = fresh
# Set up strain AFTER detach+clone (if stresses needed in default
# path). This scales positions and cell through a displacement
# tensor so dE/d(displacement) gives the stress. The fresh leaf
# tensors created above ensure the strain graph is not severed.
displacement = None
orig_positions = None
orig_cell = None
if need_stresses:
orig_positions = data.positions
orig_cell = data.cell
scaled_pos, scaled_cell, displacement = prepare_strain(
data.positions,
data.cell,
data.batch_idx,
)
data["positions"] = scaled_pos
data["cell"] = scaled_cell
# Run all models in the group.
step_outputs: list[ModelOutputs] = []
for step in group.steps:
self._resolve_inputs(step, context, data)
out = self._call_step(step, data, **kwargs)
step_outputs.append(out)
context[step] = out
# Sum energies across all steps in the group.
group_energy = None
for o in step_outputs:
e = o.get("energy")
if e is not None:
group_energy = e if group_energy is None else group_energy + e
needs_retain = autograd_idx < (autograd_count - 1)
group_out: ModelOutputs = OrderedDict()
if group_energy is not None:
group_out["energy"] = group_energy
# Compute derivatives from the summed energy.
if group_energy is not None and requested_derivatives:
already_produced = set(group_out.keys())
needed = requested_derivatives - already_produced
if needed:
if group.derivative_fn is not None:
# User override — full control.
derivs = group.derivative_fn(group_energy, data, needed)
else:
# Default: forces + stresses.
derivs = self._default_derivatives(
group_energy,
data,
needed,
displacement=displacement,
orig_cell=orig_cell,
retain_graph=needs_retain,
)
group_out.update(derivs)
# Sum direct additive outputs from step outputs (e.g. hybrid-force
# models that return detached kernel forces and virial/stress)
# alongside the autograd derivatives computed above. For hybrid
# electrostatic models the kernel returns dE/dR|_q (forces) and
# dE/d(strain)|_q (stress) while autograd provides the charge
# chain-rule terms (dE/dq)(dq/dR) and (dE/dq)(dq/d(strain)).
for o in step_outputs:
for key, val in o.items():
if val is not None and key in self.additive_keys and key != "energy":
if key in group_out and group_out[key] is not None:
group_out[key] = group_out[key] + val
else:
group_out[key] = val
# Carry through non-additive keys from step outputs.
for o in step_outputs:
for key, val in o.items():
if (
val is not None
and key not in self.additive_keys
and key not in group_out
):
group_out[key] = val
# Restore original positions/cell if strain was applied.
if orig_positions is not None:
data["positions"] = orig_positions
if orig_cell is not None:
data["cell"] = orig_cell
return group_out
# ------------------------------------------------------------------
# Serialization
# ------------------------------------------------------------------
def save(self, path: str | Path) -> None:
"""Save the full pipeline (topology + model weights) to a file.
The saved file contains:
- ``"config"`` — pipeline topology (groups, wire mappings,
autograd flags, additive keys).
- ``"state_dict"`` — model weights for all sub-models.
- ``"active_outputs"`` — current ``model_config.active_outputs``.
Custom ``derivative_fn`` callables are **not** serialized. When
loading a pipeline that used a custom function, pass it again
via :meth:`load`.
Parameters
----------
path : str | Path
Destination file path.
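
        For example (``"pipeline.pt"`` is an arbitrary destination)::

            pipe.save("pipeline.pt")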
"""
config = []
for group in self.groups:
steps_cfg = [
{
"model_class": f"{type(step.model).__module__}.{type(step.model).__qualname__}",
"wire": step.wire,
}
for step in group.steps
]
config.append(
{
"steps": steps_cfg,
"use_autograd": group.use_autograd,
"has_derivative_fn": group.derivative_fn is not None,
}
)
torch.save(
{
"config": config,
"state_dict": self.state_dict(),
"additive_keys": sorted(self.additive_keys),
"active_outputs": sorted(self.model_config.active_outputs),
},
path,
)
@classmethod
def load(
cls,
path: str | Path,
models: list[BaseModelMixin],
derivative_fns: dict[int, DerivativeFn] | None = None,
) -> "PipelineModelWrapper":
"""Load a pipeline from a file saved with :meth:`save`.
Models must be provided in the same order they appear in the
saved config (flattened across groups). The topology (groups,
wire mappings, autograd flags) is restored from the file.
Parameters
----------
path : str | Path
Path to a file created by :meth:`save`.
models : list[BaseModelMixin]
Pre-constructed model instances, one per step in the
original pipeline (flattened across groups, in order).
derivative_fns : dict[int, DerivativeFn] | None, optional
Mapping from group index to custom derivative function.
Required for groups that were saved with
``has_derivative_fn=True``.
Returns
-------
PipelineModelWrapper
Raises
------
ValueError
If the number of models doesn't match the saved config, or
if a group requires a derivative_fn that wasn't provided.
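
        For example (sketch; ``aimnet2``, ``ewald``, and ``dftd3`` are
        freshly constructed model instances matching the saved topology)::

            pipe = PipelineModelWrapper.load(
                "pipeline.pt", models=[aimnet2, ewald, dftd3]
            )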
"""
checkpoint = torch.load(path, weights_only=True)
config = checkpoint["config"]
derivative_fns = derivative_fns or {}
# Count total steps in config.
total_steps = sum(len(g["steps"]) for g in config)
if len(models) != total_steps:
raise ValueError(
f"Expected {total_steps} models (from saved config), got {len(models)}."
)
# Rebuild groups from config + provided models.
model_iter = iter(models)
groups: list[PipelineGroup] = []
for i, group_cfg in enumerate(config):
steps: list[PipelineStep] = []
for step_cfg in group_cfg["steps"]:
model = next(model_iter)
steps.append(PipelineStep(model=model, wire=step_cfg["wire"]))
dfn = derivative_fns.get(i)
if group_cfg["has_derivative_fn"] and dfn is None:
raise ValueError(
f"Group {i} requires a derivative_fn but none was "
f"provided in derivative_fns[{i}]."
)
groups.append(
PipelineGroup(
steps=steps,
use_autograd=group_cfg["use_autograd"],
derivative_fn=dfn,
)
)
additive_keys = set(checkpoint.get("additive_keys", []))
pipe = cls(groups=groups, additive_keys=additive_keys or None)
pipe.load_state_dict(checkpoint["state_dict"])
# Restore active_outputs.
saved_active = checkpoint.get("active_outputs")
if saved_active is not None:
pipe.model_config.active_outputs = set(saved_active)
return pipe
@staticmethod
def _default_derivatives(
energy: Energy,
data: Batch | AtomicData,
requested: set[str],
*,
displacement: torch.Tensor | None,
orig_cell: LatticeVectors | None,
retain_graph: bool,
) -> dict[str, torch.Tensor]:
"""Built-in derivative computation for autograd groups.
Computes forces as ``-dE/dr`` and stresses via the affine strain
trick (when ``displacement`` is provided). If neither forces nor
stresses are requested, returns an empty dict.
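
        In symbols (sign and normalization conventions follow
        :func:`autograd_forces` and :func:`autograd_stresses`):
        ``F = -dE/dr``, and with positions and cell scaled by the strained
        map ``r -> (I + eps) r``, the per-structure stress is obtained
        from ``dE/d(eps)``.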
"""
result: dict[str, torch.Tensor] = {}
need_stresses = displacement is not None and "stress" in requested
if "forces" in requested:
result["forces"] = autograd_forces(
energy,
data.positions,
retain_graph=retain_graph or need_stresses,
)
if need_stresses:
num_graphs = data.num_graphs if isinstance(data, Batch) else 1
result["stress"] = autograd_stresses(
energy,
displacement,
orig_cell,
num_graphs,
retain_graph=retain_graph,
)
return result