# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Size-aware sampler for inflight batching in dynamics simulations.
This module provides :class:`SizeAwareSampler`, which manages dataset access,
capacity budgets, and bin-packing logic for efficient GPU utilization during
dynamics simulations.
"""
from __future__ import annotations
import random
from collections import deque
from collections.abc import Iterator
from typing import Any
import torch
from torch.utils.data import Sampler
from nvalchemi.data.atomic_data import AtomicData
from nvalchemi.data.batch import Batch
[docs]
class SizeAwareSampler(Sampler[int]):
"""Size-aware sampler for inflight batching.
Manages dataset access, capacity budgets, and bin-packing logic for
efficient GPU utilization during dynamics simulations. Ensures every
replacement sample fits within the memory envelope of the graduated
sample it replaces.
When CUDA is available, the sampler uses a heuristic to estimate the
maximum number of atoms that fit in GPU memory. This estimate is combined
with user-specified ``max_atoms`` — the more restrictive constraint wins.
The GPU memory heuristic is **best-effort** and **conservative**; users
who need tighter control should set ``max_atoms`` explicitly.
Parameters
----------
dataset : Any
Dataset with ``__len__``, ``__getitem__``, and ``get_metadata(idx)``
methods. ``get_metadata`` must return ``(num_atoms, num_edges)``.
max_atoms : int | None
Maximum total atoms across all samples in a batch. ``None`` disables
the atom count constraint (GPU memory estimate may still apply).
max_edges : int | None
Maximum total edges across all samples in a batch. ``None`` disables
the edge count constraint.
max_batch_size : int
Maximum number of samples (graphs) in a batch.
bin_width : int
Atom-count bin width for grouping samples. Default 1.
shuffle : bool
Whether to shuffle within bins. Default False.
max_gpu_memory_fraction : float
Fraction of GPU memory to use when estimating atom capacity. Default
0.8 (80%), leaving 20% headroom for model parameters and CUDA context.
Only used when CUDA is available.
Raises
------
RuntimeError
If any sample in the dataset has ``num_atoms > max_atoms`` or
``num_edges > max_edges`` — such samples can never be placed into
any batch and indicate a configuration error.
ValueError
If ``max_batch_size < 1``, ``bin_width < 1``, or
``max_gpu_memory_fraction`` is not in ``(0.0, 1.0]``.
Examples
--------
>>> sampler = SizeAwareSampler(dataset, max_atoms=100, max_edges=500, max_batch_size=10)
>>> batch = sampler.build_initial_batch()
>>> replacement = sampler.request_replacement(num_atoms=5, num_edges=20)
"""
[docs]
def __init__(
self,
dataset: Any,
max_atoms: int | None,
max_edges: int | None,
max_batch_size: int,
bin_width: int = 1,
shuffle: bool = False,
max_gpu_memory_fraction: float = 0.8,
) -> None:
"""Initialize the size-aware sampler.
Parameters
----------
dataset : Any
Dataset with ``__len__``, ``__getitem__``, and ``get_metadata(idx)``
methods. ``get_metadata`` must return ``(num_atoms, num_edges)``.
max_atoms : int | None
Maximum total atoms across all samples in a batch. ``None`` disables
the atom count constraint (GPU memory estimate may still apply).
max_edges : int | None
Maximum total edges across all samples in a batch. ``None`` disables
the edge count constraint.
max_batch_size : int
Maximum number of samples (graphs) in a batch.
bin_width : int
Atom-count bin width for grouping samples. Default 1.
shuffle : bool
Whether to shuffle within bins. Default False.
max_gpu_memory_fraction : float
Fraction of GPU memory to use when estimating atom capacity. Default
0.8 (80%), leaving 20% headroom for model parameters and CUDA context.
Only used when CUDA is available.
Raises
------
RuntimeError
If any sample exceeds ``max_atoms`` or ``max_edges`` constraints.
ValueError
If ``max_batch_size < 1``, ``bin_width < 1``, or
``max_gpu_memory_fraction`` is not in ``(0.0, 1.0]``.
TypeError
If dataset does not implement required interface.
"""
# Validate parameters
if max_batch_size < 1:
raise ValueError(f"max_batch_size must be >= 1, got {max_batch_size}")
if bin_width < 1:
raise ValueError(f"bin_width must be >= 1, got {bin_width}")
if not 0.0 < max_gpu_memory_fraction <= 1.0:
raise ValueError(
f"max_gpu_memory_fraction must be in (0.0, 1.0], got {max_gpu_memory_fraction}"
)
# Runtime validation of dataset interface
if not hasattr(dataset, "__len__"):
raise TypeError("dataset must implement __len__")
if not hasattr(dataset, "__getitem__"):
raise TypeError("dataset must implement __getitem__")
if not hasattr(dataset, "get_metadata"):
raise TypeError(
"dataset must implement get_metadata(idx) -> (num_atoms, num_edges)"
)
self._dataset = dataset
self._max_atoms = max_atoms
self._max_edges = max_edges
self._max_batch_size = max_batch_size
self._bin_width = bin_width
self._shuffle = shuffle
self._max_gpu_memory_fraction = max_gpu_memory_fraction
# Pre-scan dataset and build bins
self._sample_meta: list[tuple[int, int]] = [] # (num_atoms, num_edges) per idx
self._bins: dict[int, deque[int]] = {} # bin_key -> deque of unconsumed indices
self._consumed: set[int] = set()
# GPU-resident metadata for vectorized constraint checking (lazily initialized)
self._metadata_tensor: torch.Tensor | None = None # (N, 2) int64 on device
self._consumed_mask: torch.BoolTensor | None = None # (N,) on device
# Monotonically increasing counter for stable per-system IDs.
# Stamped onto each AtomicData as a "system_id" graph-level tensor.
self._next_system_id: int = 0
self._prescan_dataset()
def _prescan_dataset(self) -> None:
"""Pre-scan all samples to extract metadata and organize into bins.
Raises
------
RuntimeError
If any sample exceeds atom or edge constraints.
"""
for idx in range(len(self._dataset)):
num_atoms, num_edges = self._dataset.get_metadata(idx)
# Validate sample fits within constraints
if self._max_atoms is not None and num_atoms > self._max_atoms:
raise RuntimeError(
f"Sample {idx} has {num_atoms} atoms, exceeding max_atoms={self._max_atoms}. "
"This sample can never fit in any batch."
)
if self._max_edges is not None and num_edges > self._max_edges:
raise RuntimeError(
f"Sample {idx} has {num_edges} edges, exceeding max_edges={self._max_edges}. "
"This sample can never fit in any batch."
)
self._sample_meta.append((num_atoms, num_edges))
# Assign to bin based on atom count
bin_key = num_atoms // self._bin_width
if bin_key not in self._bins:
self._bins[bin_key] = deque()
self._bins[bin_key].append(idx)
# Optionally shuffle within bins
if self._shuffle:
for bin_indices in self._bins.values():
random.shuffle(bin_indices)
def _estimate_max_atoms_from_gpu(self) -> int | None:
"""Estimate the maximum number of atoms that fit in GPU memory.
Uses ``torch.cuda.get_device_properties`` to query total GPU memory
and estimates the per-atom memory footprint. Returns ``None`` if CUDA
is not available.
This is a **heuristic** and **best-effort** estimate. The per-atom
memory footprint (~300 bytes) is conservative for typical MLIP
workloads. Users who need tighter control should set ``max_atoms``
explicitly.
Returns
-------
int | None
Estimated max atoms, or ``None`` if CUDA is unavailable.
"""
if not torch.cuda.is_available():
return None
props = torch.cuda.get_device_properties(torch.cuda.current_device())
total_mem = props.total_memory
usable_mem = int(total_mem * self._max_gpu_memory_fraction)
# Estimate per-atom memory: each atom needs storage for
# positions (3 * 4 bytes float32), atomic_numbers (8 bytes long),
# forces (3 * 4 bytes), velocities (3 * 4 bytes),
# atomic_masses (4 bytes), batch index (8 bytes),
# plus model hidden states (estimate ~256 bytes per atom for embeddings)
# Conservative estimate: ~300 bytes per atom
bytes_per_atom = 300
# Also account for model parameters and CUDA overhead (~20% of memory)
model_overhead = int(total_mem * 0.2)
available_for_data = max(usable_mem - model_overhead, 0)
return max(available_for_data // bytes_per_atom, 1)
def build_initial_batch(self) -> Batch:
"""Build an initial batch using greedy bin packing.
Iterates bins in ascending order (smallest atom counts first) and adds
samples that fit within all capacity constraints until no more samples
can be added. When CUDA is available, the effective ``max_atoms`` is
the minimum of the user-specified value and the GPU memory estimate.
Returns
-------
Batch
A batch with ``status`` and ``fmax`` attributes initialized.
Raises
------
RuntimeError
If no samples can be added to the batch (e.g., all consumed or
constraints too tight).
"""
# Determine effective max_atoms from user constraint and GPU estimate
gpu_max_atoms = self._estimate_max_atoms_from_gpu()
effective_max_atoms = self._max_atoms
if gpu_max_atoms is not None:
if effective_max_atoms is not None:
effective_max_atoms = min(effective_max_atoms, gpu_max_atoms)
else:
effective_max_atoms = gpu_max_atoms
data_list: list[AtomicData] = []
total_atoms = 0
total_edges = 0
# Iterate bins in ascending order (smallest first)
sorted_bin_keys = sorted(self._bins.keys())
for bin_key in sorted_bin_keys:
if bin_key not in self._bins:
continue
# Iterate through samples in this bin (lazy tombstone eviction)
bin_deque = self._bins[bin_key]
# Evict already-consumed entries from the front
while bin_deque and bin_deque[0] in self._consumed:
bin_deque.popleft()
for idx in list(bin_deque):
if idx in self._consumed:
continue
num_atoms, num_edges = self._sample_meta[idx]
# Check capacity constraints
if len(data_list) >= self._max_batch_size:
break
if (
effective_max_atoms is not None
and total_atoms + num_atoms > effective_max_atoms
):
continue
if (
self._max_edges is not None
and total_edges + num_edges > self._max_edges
):
continue
# Sample fits, load it
data, _ = self._dataset[idx]
data.add_system_property(
"system_id",
torch.tensor([[self._next_system_id]], dtype=torch.long),
)
self._next_system_id += 1
data_list.append(data)
total_atoms += num_atoms
total_edges += num_edges
self._consumed.add(idx)
# Stop if batch is full
if len(data_list) >= self._max_batch_size:
break
if not data_list:
raise RuntimeError(
"Cannot build initial batch: no samples available or constraints too tight."
)
# Create batch
batch = Batch.from_data_list(data_list, device=data_list[0].device)
# Initialize status and fmax attributes
batch["status"] = torch.zeros(
batch.num_graphs, 1, dtype=torch.long, device=batch.device
)
batch["fmax"] = torch.full(
(batch.num_graphs, 1),
float("inf"),
dtype=torch.float32,
device=batch.device,
)
return batch
def request_replacement(self, num_atoms: int, num_edges: int) -> AtomicData | None:
"""Request a replacement sample that fits within the given constraints.
Searches for an unconsumed sample with at most ``num_atoms`` atoms and
``num_edges`` edges, starting from the target bin and progressively
searching smaller bins.
Parameters
----------
num_atoms : int
Maximum number of atoms the replacement can have.
num_edges : int
Maximum number of edges the replacement can have.
Returns
-------
AtomicData | None
A replacement sample if found, or ``None`` if no suitable sample
is available.
"""
target_bin = num_atoms // self._bin_width
# Search from target bin downward (smaller sizes)
for bin_key in range(target_bin, -1, -1):
if bin_key not in self._bins:
continue
bin_deque = self._bins[bin_key]
# Lazy tombstone eviction from the front
while bin_deque and bin_deque[0] in self._consumed:
bin_deque.popleft()
for idx in list(bin_deque):
if idx in self._consumed:
continue
cand_atoms, cand_edges = self._sample_meta[idx]
# Check if candidate fits in the slot
if cand_atoms <= num_atoms and cand_edges <= num_edges:
# Found a match, load and mark consumed
data, _ = self._dataset[idx]
data.add_system_property(
"system_id",
torch.tensor([[self._next_system_id]], dtype=torch.long),
)
self._next_system_id += 1
self._consumed.add(idx)
return data
return None
@property
def exhausted(self) -> bool:
"""Check if all samples have been consumed.
Returns
-------
bool
``True`` if all bins are empty or contain only consumed indices, ``False`` otherwise.
"""
for bin_deque in self._bins.values():
for idx in bin_deque:
if idx not in self._consumed:
return False
return True
def _ensure_gpu_state(self, device: torch.device) -> None:
"""Lazily initialize GPU-resident metadata tensors for vectorized constraint checking.
Parameters
----------
device : torch.device
Device to place metadata tensors on.
"""
if self._metadata_tensor is not None and self._metadata_tensor.device == device:
return
meta = torch.tensor(
self._sample_meta, dtype=torch.long, device=device
) # (N, 2)
self._metadata_tensor = meta
self._consumed_mask = torch.zeros(
len(self._sample_meta), dtype=torch.bool, device=device
)
# Mark already-consumed indices in the GPU mask
if self._consumed:
consumed_indices = torch.tensor(
list(self._consumed), dtype=torch.long, device=device
)
self._consumed_mask[consumed_indices] = True
def request_replacements(
self,
node_counts: torch.Tensor,
edge_counts: torch.Tensor,
) -> list[AtomicData | None]:
"""Request replacement samples for multiple graduated systems using GPU-native constraint checking.
Eliminates the ``.tolist()`` D→H syncs from ``_refill_check``. Constraint
checking is fully vectorized on GPU. M scalar ``item()`` calls remain
(unavoidable: Python dataset indexing requires CPU integers).
Parameters
----------
node_counts : torch.Tensor
Shape ``(M,)`` int64 tensor on GPU. Maximum atoms each replacement can have.
edge_counts : torch.Tensor
Shape ``(M,)`` int64 tensor on GPU. Maximum edges each replacement can have.
Returns
-------
list[AtomicData | None]
Length-M list of replacement samples, or ``None`` where no suitable
sample is available.
"""
device = node_counts.device
self._ensure_gpu_state(device)
M = len(node_counts)
if self._metadata_tensor is None:
raise RuntimeError("GPU metadata tensor not initialized")
if self._consumed_mask is None:
raise RuntimeError("GPU consumed mask not initialized")
# Vectorized (M, N) fit matrix: does dataset sample j fit in graduated slot i?
fits = (
(
self._metadata_tensor[:, 0].unsqueeze(0) <= node_counts.unsqueeze(1)
) # atoms fit
& (
self._metadata_tensor[:, 1].unsqueeze(0) <= edge_counts.unsqueeze(1)
) # edges fit
& ~self._consumed_mask.unsqueeze(0) # not yet consumed
) # (M, N) bool, computed on GPU
results: list[AtomicData | None] = []
available = ~self._consumed_mask.clone() # (N,) running availability
for i in range(M):
# Candidates for this slot that are still available after prior assignments
slot_fits = fits[i] & available
candidates = slot_fits.nonzero(as_tuple=False)
if candidates.numel() == 0:
results.append(None)
continue
# ONE item() per slot — unavoidable to index into the Python dataset
chosen_idx = int(candidates[0, 0].item())
available[chosen_idx] = False
self._consumed_mask[chosen_idx] = True
# Sync CPU _consumed set for exhausted() and __len__ correctness
self._consumed.add(chosen_idx)
data, _ = self._dataset[chosen_idx]
data.add_system_property(
"system_id", torch.tensor([[self._next_system_id]], dtype=torch.long)
)
self._next_system_id += 1
results.append(data)
return results
def __iter__(self) -> Iterator[int]:
"""Yield all remaining unconsumed indices in size-grouped order.
Yields
------
int
Dataset indices in ascending bin order.
"""
sorted_bin_keys = sorted(self._bins.keys())
for bin_key in sorted_bin_keys:
if bin_key not in self._bins:
continue
for idx in self._bins[bin_key]:
if idx not in self._consumed:
yield idx
def __len__(self) -> int:
"""Return the number of unconsumed samples.
Returns
-------
int
Number of samples remaining in the sampler.
"""
return len(self._dataset) - len(self._consumed)