# Source code for nvalchemi.models.composable

# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Composable model composition.

:class:`ComposableModelWrapper` combines two or more
:class:`~nvalchemi.models.base.BaseModelMixin`-compatible models whose
composable outputs (energies, forces, stresses) are summed element-wise.
Non-composable outputs produced by sub-models are written back to the batch
on a last-write-wins basis.

Typical usage via the ``+`` operator (when models support it)::

    combined = model_a + model_b

Or directly::

    from nvalchemi.models.composable import ComposableModelWrapper

    combined = ComposableModelWrapper(lj_model, mlip_model)
    combined.model_config.compute_stresses = True

The composite model synthesises a :class:`~nvalchemi.models.base.ModelCard`
from all sub-model cards, picking the most permissive neighbor configuration
(maximum cutoff, full list, MATRIX if any sub-model uses MATRIX).
"""

from __future__ import annotations

import warnings
from collections import OrderedDict
from pathlib import Path
from typing import TYPE_CHECKING, Any

from torch import Tensor, nn

from nvalchemi.data import AtomicData, Batch
from nvalchemi.models.base import (
    BaseModelMixin,
    ModelCard,
    ModelConfig,
    NeighborConfig,
    NeighborListFormat,
)

if TYPE_CHECKING:
    from nvalchemi.dynamics.hooks import NeighborListHook

__all__ = ["ComposableModelWrapper"]

_COMPOSABLE_KEYS: frozenset[str] = frozenset({"energies", "forces", "stresses"})


class ComposableModelWrapper(nn.Module, BaseModelMixin):
    """Compose multiple models by summing their composable outputs.

    Parameters
    ----------
    *models : BaseModelMixin
        Two or more model wrappers to compose. Any nested
        :class:`ComposableModelWrapper` instances are flattened into the
        top-level list so that the composition is always a single flat
        layer.

    Attributes
    ----------
    models : nn.ModuleList
        Flat list of constituent model wrappers.
    """

    def __init__(self, *models: BaseModelMixin) -> None:
        super().__init__()
        # Flatten any nested ComposableModelWrappers.
        flat: list[BaseModelMixin] = []
        for m in models:
            if isinstance(m, ComposableModelWrapper):
                flat.extend(list(m.models))
            else:
                flat.append(m)

        # Guard: energy-first composition for multiple autograd models is not yet
        # implemented. Summing pre-computed forces is memory-inefficient for two
        # autograd models.
        n_autograd = sum(
            1
            for m in flat
            if getattr(m, "model_card", None) is not None
            and m.model_card.forces_via_autograd  # type: ignore[union-attr]
        )
        if n_autograd > 1:
            raise NotImplementedError(
                "Composing two or more autograd-forces models is not yet supported. "
                "Energy-first composition (sum energies, single autograd pass) is "
                "required for memory correctness but not yet implemented. "
            )

        self.models: nn.ModuleList = nn.ModuleList(flat)  # type: ignore[arg-type]

        # Use the property setter so all sub-models share the same ModelConfig
        # instance from construction; in-place mutations (e.g.
        # wrapper.model_config.compute_stresses = True) then propagate
        # automatically because every sub-model holds a reference to the same
        # object.
        self.model_config = ModelConfig()

        self._model_card: ModelCard = self._build_model_card()

    # ------------------------------------------------------------------
    # model_config property (shadows class-level attr from BaseModelMixin)
    # ------------------------------------------------------------------
    @property
    def model_config(self) -> ModelConfig:  # type: ignore[override]
        """Mutable configuration controlling which outputs are computed."""
        return self._model_config

    @model_config.setter
    def model_config(self, config: ModelConfig) -> None:
        # Share one ModelConfig object across all sub-models so that
        # in-place edits on the wrapper's config reach every sub-model.
        self._model_config = config
        for model in self.models:
            model.model_config = config  # type: ignore[assignment]

    # ------------------------------------------------------------------
    # BaseModelMixin required properties
    # ------------------------------------------------------------------
    def _build_model_card(self) -> ModelCard:
        """Synthesise a composite :class:`ModelCard` from all sub-model cards.

        Capability flags are AND-ed (a feature is supported only if every
        sub-model supports it); requirement flags are OR-ed (a feature is
        required if any sub-model requires it). The neighbor configuration
        picks the most permissive settings: maximum cutoff, MATRIX format
        if any sub-model uses MATRIX, maximum ``max_neighbors``.

        Raises
        ------
        ValueError
            If sub-models disagree on ``half_list`` in their neighbor
            configurations.
        """
        cards = [m.model_card for m in self.models]  # type: ignore[union-attr]
        forces_via_autograd = any(c.forces_via_autograd for c in cards)
        supports_energies = all(c.supports_energies for c in cards)
        supports_forces = all(c.supports_forces for c in cards)
        supports_stresses = all(c.supports_stresses for c in cards)
        supports_pbc = all(c.supports_pbc for c in cards)
        needs_pbc = any(c.needs_pbc for c in cards)
        supports_non_batch = all(c.supports_non_batch for c in cards)

        # Synthesise neighbor_config from sub-models that have one.
        sub_configs = [
            c.neighbor_config for c in cards if c.neighbor_config is not None
        ]
        if sub_configs:
            # Validate that all sub-models have the same half_list value.
            for nc in sub_configs:
                if nc.half_list != sub_configs[0].half_list:
                    raise ValueError(
                        "ComposableModelWrapper: a sub-model has a different half_list value in its "
                        "NeighborConfig. All sub-models must use the same half_list value when "
                        f"composed. Got {nc.half_list} and {sub_configs[0].half_list}."
                    )
            half_list = sub_configs[0].half_list
            max_cutoff = max(nc.cutoff for nc in sub_configs)
            has_matrix = any(
                nc.format == NeighborListFormat.MATRIX for nc in sub_configs
            )
            chosen_format = (
                NeighborListFormat.MATRIX if has_matrix else NeighborListFormat.COO
            )
            max_neighbors_vals = [
                nc.max_neighbors for nc in sub_configs if nc.max_neighbors is not None
            ]
            max_neighbors = max(max_neighbors_vals) if max_neighbors_vals else None
            neighbor_config: NeighborConfig | None = NeighborConfig(
                cutoff=max_cutoff,
                format=chosen_format,
                half_list=half_list,
                max_neighbors=max_neighbors,
            )
        else:
            neighbor_config = None

        # Warn if two sub-models both claim to include the same physics term,
        # which would cause double-counting in the composable composition.
        n_dispersion = sum(1 for c in cards if c.includes_dispersion)
        if n_dispersion > 1:
            warnings.warn(
                "ComposableModelWrapper: two or more sub-models have includes_dispersion=True. "
                "This may double-count dispersion interactions. Verify your model checkpoints.",
                UserWarning,
                stacklevel=3,
            )
        n_elec = sum(1 for c in cards if c.includes_long_range_electrostatics)
        if n_elec > 1:
            warnings.warn(
                "ComposableModelWrapper: two or more sub-models have "
                "includes_long_range_electrostatics=True. "
                "This may double-count long-range electrostatic interactions.",
                UserWarning,
                stacklevel=3,
            )

        return ModelCard(
            forces_via_autograd=forces_via_autograd,
            supports_energies=supports_energies,
            supports_forces=supports_forces,
            supports_stresses=supports_stresses,
            supports_pbc=supports_pbc,
            needs_pbc=needs_pbc,
            supports_non_batch=supports_non_batch,
            neighbor_config=neighbor_config,
        )

    @property
    def model_card(self) -> ModelCard:
        """Synthesised :class:`ModelCard` derived from all sub-model cards."""
        return self._model_card

    @property
    def embedding_shapes(self) -> dict[str, tuple[int, ...]]:
        # Composite models do not have a unified embedding space.
        return {}

    # ------------------------------------------------------------------
    # Methods that are not meaningful for composite models
    # ------------------------------------------------------------------
    def compute_embeddings(
        self, data: AtomicData | Batch, **kwargs: Any
    ) -> AtomicData | Batch:
        """Compute embeddings is not meaningful for composite models.

        Call compute_embeddings on individual sub-models instead."""
        raise NotImplementedError(
            "ComposableModelWrapper does not produce unified embeddings. "
            "Call compute_embeddings on individual sub-models instead."
        )

    def export_model(self, path: Path, as_state_dict: bool = False) -> None:
        """Export model is not implemented for composite models.

        Export individual sub-models instead."""
        raise NotImplementedError(
            "ComposableModelWrapper does not support direct export. "
            "Export individual sub-models instead."
        )

    # ------------------------------------------------------------------
    # Neighbor hook factory
    # ------------------------------------------------------------------
    def make_neighbor_hooks(self) -> list[NeighborListHook]:
        """Return a single :class:`NeighborListHook` for the composite neighbor config.

        A single composite hook at the maximum cutoff is used for all
        sub-models. This avoids running multiple neighbor-list algorithms
        per dynamics step. The cost of reformatting (e.g. neighbor matrix →
        neighbor list) at a synchronization point is preferable to computing
        separate neighbor lists.

        The import is deferred to avoid circular imports between
        ``nvalchemi.models`` and ``nvalchemi.dynamics``.
        """
        from nvalchemi.dynamics.hooks import NeighborListHook

        nc = self.model_card.neighbor_config
        if nc is None:
            return []
        return [NeighborListHook(nc)]

    # ------------------------------------------------------------------
    # Forward pass
    # ------------------------------------------------------------------
    def forward(self, data: Batch, **kwargs: Any) -> OrderedDict[str, Tensor]:
        """Run all sub-models left-to-right and accumulate composable outputs.

        Composable outputs (``"energies"``, ``"forces"``, ``"stresses"``)
        are summed across models. All other outputs are written back to
        *data* on a last-write-wins basis.

        Parameters
        ----------
        data : Batch
            Input batch. Neighbor data must already be populated (e.g. by a
            :class:`~nvalchemi.dynamics.hooks.NeighborListHook`).

        Returns
        -------
        OrderedDict[str, Tensor]
            Accumulated composable outputs in canonical order (energies →
            forces → stresses), containing only the keys that are present
            in at least one sub-model's output.
        """
        accumulated: dict[str, Tensor] = {}
        for model in self.models:
            result = model(data, **kwargs)  # type: ignore[operator]
            if result is None:
                continue
            for key, val in result.items():
                if val is None:
                    continue
                if key in _COMPOSABLE_KEYS:
                    if key in accumulated:
                        accumulated[key] = accumulated[key] + val
                    else:
                        accumulated[key] = val
                else:
                    # Non-additive: write back to batch (last-write-wins).
                    # Use object.__setattr__ to bypass Batch's custom __setattr__
                    # which tries to route tensors into data groups and requires
                    # tensors to have a well-defined len().
                    object.__setattr__(data, key, val)

        # Return in canonical key order.
        out: OrderedDict[str, Tensor] = OrderedDict()
        for key in ("energies", "forces", "stresses"):
            if key in accumulated:
                out[key] = accumulated[key]
        return out