BaseTensorAwareStateDict

TensorAwareStateDict defines an interface for managing different kinds of state dicts within CheckpointManager. Its primary feature is the ability to distinguish tensor objects from all other elements. It can also be converted to and from the original state dicts.

class nvidia_resiliency_ext.checkpointing.local.base_state_dict.TensorAwareStateDict

Bases: ABC

Base class that defines the interface between the user state dict and the checkpoint manager.

The primary goal is to differentiate tensor content from non-tensor content, enabling efficient migration of the state dict during checkpoint save and load.
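As a rough illustration of separating tensor content from non-tensor content, the sketch below walks a nested state dict, swaps every tensor for a placeholder, and collects the tensors in a flat list. It is a hypothetical, dependency-free sketch and not the library's implementation: FakeTensor stands in for torch.Tensor, and Placeholder and split_tensors are illustrative names.

```python
from dataclasses import dataclass
from typing import Any, List


@dataclass
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor (keeps the sketch dependency-free)."""
    data: list
    device: str = "cpu"


@dataclass(frozen=True)
class Placeholder:
    """Marks where a tensor was removed; `index` points into the tensor list."""
    index: int


def split_tensors(obj: Any, out: List[FakeTensor]) -> Any:
    """Return a copy of `obj` with every tensor replaced by a Placeholder.

    Extracted tensors are appended to `out` in traversal order.
    """
    if isinstance(obj, FakeTensor):
        out.append(obj)
        return Placeholder(len(out) - 1)
    if isinstance(obj, dict):
        return {k: split_tensors(v, out) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(split_tensors(v, out) for v in obj)
    return obj  # non-tensor leaf (int, str, ...) is kept as-is


state_dict = {"step": 10, "model": {"w": FakeTensor([1.0, 2.0]), "b": FakeTensor([0.5])}}
tensors: List[FakeTensor] = []
skeleton = split_tensors(state_dict, tensors)
print(skeleton)
# {'step': 10, 'model': {'w': Placeholder(index=0), 'b': Placeholder(index=1)}}
```

The skeleton is small and cheap to serialize on its own, while the flat tensor list can be moved or written in bulk.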

abstract copy_tensors_to_cpu(non_blocking=False)

Replaces the tensors in the state dict with CPU copies, without destroying the originals. The original devices are recorded so that restore_tensor_device() can move the tensors back.

Parameters:

non_blocking (bool) – if True, allows asynchronous copying.
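A minimal sketch of the copy_tensors_to_cpu() contract, assuming a flat dict: each tensor is replaced by a CPU copy, the original object stays alive, and its device is recorded. FakeTensor (a dependency-free stand-in for torch.Tensor whose to() mimics Tensor.to()) and FlatStateDictSketch are hypothetical names; this is not the library's implementation.

```python
from dataclasses import dataclass
from typing import Dict


@dataclass
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor with a device attribute."""
    data: list
    device: str = "cpu"

    def to(self, device: str, non_blocking: bool = False) -> "FakeTensor":
        # Like Tensor.to(): return a copy on `device`; self is left untouched.
        # non_blocking is accepted for signature parity but has no effect here.
        return FakeTensor(list(self.data), device)


class FlatStateDictSketch:
    """Hypothetical flat-dict wrapper showing the copy_tensors_to_cpu() contract."""

    def __init__(self, state: Dict[str, object]):
        self.state = state
        self.orig_devices: Dict[str, str] = {}

    def copy_tensors_to_cpu(self, non_blocking: bool = False) -> None:
        for key, value in list(self.state.items()):
            if isinstance(value, FakeTensor):
                # Remember where the tensor lived so it can be restored later.
                self.orig_devices[key] = value.device
                # Replace the entry with a CPU copy; the original object survives.
                self.state[key] = value.to("cpu", non_blocking=non_blocking)


gpu_w = FakeTensor([1.0, 2.0], device="cuda:0")
sd = FlatStateDictSketch({"w": gpu_w, "step": 3})
sd.copy_tensors_to_cpu()
print(sd.state["w"].device, gpu_w.device, sd.orig_devices)  # cpu cuda:0 {'w': 'cuda:0'}
```

With real PyTorch tensors the copy would be value.to("cpu", non_blocking=non_blocking); an asynchronous device-to-host copy has to be synchronized (for example with torch.cuda.synchronize()) before the CPU data is read.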

abstract init_tensors()

Initializes empty tensors with the same properties as the original tensors.

This method should only be called after the original tensors have been popped. It creates empty tensors that match the shape, dtype, and device of the originals but carry none of their data.
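init_tensors() can be pictured as re-allocating one buffer per recorded metadata entry. The sketch below is hypothetical: TensorMeta, FakeTensor (a stand-in for torch.Tensor) and the free function init_tensors are illustrative names, and the buffers are filled with None to mark the absence of data.

```python
from dataclasses import dataclass
from math import prod
from typing import List, Tuple


@dataclass(frozen=True)
class TensorMeta:
    """Metadata retained after pop_tensors(): enough to re-allocate a tensor."""
    shape: Tuple[int, ...]
    dtype: str
    device: str


@dataclass
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor; `data` is a flat buffer."""
    data: list
    shape: Tuple[int, ...]
    dtype: str
    device: str


def init_tensors(metas: List[TensorMeta]) -> List[FakeTensor]:
    """Allocate one empty buffer per TensorMeta, matching shape, dtype and device."""
    # None marks "no data yet"; the number of slots equals the element count.
    return [FakeTensor([None] * prod(m.shape), m.shape, m.dtype, m.device) for m in metas]


metas = [TensorMeta((2, 3), "float32", "cuda:0"), TensorMeta((), "int64", "cpu")]
empties = init_tensors(metas)
print([len(t.data) for t in empties])  # [6, 1]
```

With PyTorch, each allocation would likely be torch.empty(meta.shape, dtype=..., device=...), which returns a tensor of the requested shape, dtype and device with uninitialized contents.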

abstract insert_tensors(tensor_data)

Reverse of pop_tensors(): replaces tensor placeholders with actual values. The value of self is considered unchanged after:

self.insert_tensors(self.pop_tensors())
Parameters:

tensor_data (Iterable[Tensor]) – An iterable containing the tensor data to be inserted.
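The round-trip identity self.insert_tensors(self.pop_tensors()) can be illustrated on a flat dict. In this hypothetical sketch (FakeTensor, Placeholder and FlatStateDictSketch are illustrative names, not library classes), pop_tensors() swaps each tensor for a placeholder and insert_tensors() fills the placeholders back in the same order.

```python
from dataclasses import dataclass
from typing import Dict, Iterable, List


@dataclass
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor."""
    data: list


@dataclass(frozen=True)
class Placeholder:
    """Marks the slot of a removed tensor."""
    key: str


class FlatStateDictSketch:
    """Hypothetical flat-dict wrapper illustrating pop_tensors()/insert_tensors()."""

    def __init__(self, state: Dict[str, object]):
        self.state = state

    def pop_tensors(self) -> List[FakeTensor]:
        popped = []
        for key, value in list(self.state.items()):
            if isinstance(value, FakeTensor):
                popped.append(value)
                self.state[key] = Placeholder(key)
        return popped

    def insert_tensors(self, tensor_data: Iterable[FakeTensor]) -> None:
        tensors = iter(tensor_data)
        # Dicts keep insertion order, so placeholders are visited in the same
        # order in which pop_tensors() extracted the tensors.
        for key, value in list(self.state.items()):
            if isinstance(value, Placeholder):
                tensor = next(tensors, None)
                if tensor is None:
                    raise ValueError(f"no tensor supplied for placeholder {key!r}")
                self.state[key] = tensor


original = {"w": FakeTensor([1.0]), "step": 7, "b": FakeTensor([0.5])}
sd = FlatStateDictSketch(dict(original))
sd.insert_tensors(sd.pop_tensors())  # the round trip leaves the value unchanged
print(sd.state == original)  # True
```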

abstract property is_hollow: bool

True if and only if the tensors have been extracted and not yet inserted back.

abstract pop_tensors()

Extracts the tensor data from the wrapped state dict, preserving metadata.

Removes the tensor data while retaining the metadata (e.g., shape, dtype, device) needed to recreate empty tensors. After this operation, the state dictionary is “hollow”: it contains no tensor data. Further calls to pop_tensors will raise an error.

Returns:

List of extracted tensors

Return type:

Sequence[Tensor]
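To illustrate pop_tensors() together with is_hollow, the hypothetical sketch below keeps only shape, dtype and device in place of each tensor, marks the dict as hollow, and refuses a second pop. FakeTensor, TensorMeta and FlatStateDictSketch are assumed names for illustration; the real class may store its metadata differently.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor."""
    data: list
    dtype: str = "float32"
    device: str = "cpu"


@dataclass(frozen=True)
class TensorMeta:
    """What survives pop_tensors(): enough to re-create an empty tensor."""
    shape: Tuple[int, ...]
    dtype: str
    device: str


class FlatStateDictSketch:
    """Hypothetical flat-dict wrapper illustrating pop_tensors() and is_hollow."""

    def __init__(self, state: Dict[str, object]):
        self.state = state
        self._hollow = False

    @property
    def is_hollow(self) -> bool:
        return self._hollow

    def pop_tensors(self) -> List[FakeTensor]:
        if self._hollow:
            raise RuntimeError("tensors were already popped")
        popped = []
        for key, value in list(self.state.items()):
            if isinstance(value, FakeTensor):
                popped.append(value)
                # Keep only metadata in place of the tensor data.
                self.state[key] = TensorMeta((len(value.data),), value.dtype, value.device)
        self._hollow = True
        return popped


sd = FlatStateDictSketch({"w": FakeTensor([1.0, 2.0], device="cuda:0"), "step": 7})
tensors = sd.pop_tensors()
print(sd.is_hollow, sd.state["w"])
# True TensorMeta(shape=(2,), dtype='float32', device='cuda:0')
```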

abstract restore_tensor_device(non_blocking=True)

Restores all tensors to their original devices, if a move is required.

Parameters:

non_blocking (bool) – if True, allows asynchronous copying.
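A rough sketch of how copy_tensors_to_cpu() and restore_tensor_device() pair up, assuming a flat dict and a hypothetical FakeTensor stand-in for torch.Tensor (FlatStateDictSketch is likewise an illustrative name). Devices are recorded during the copy, and on restore a tensor is moved only when it is not already on its original device, in line with the "if a move is required" wording of restore_tensor_device().

```python
from dataclasses import dataclass
from typing import Dict


@dataclass
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor with a device attribute."""
    data: list
    device: str = "cpu"

    def to(self, device: str, non_blocking: bool = False) -> "FakeTensor":
        # Like Tensor.to(): return a copy on `device`; non_blocking is ignored here.
        return FakeTensor(list(self.data), device)


class FlatStateDictSketch:
    """Hypothetical flat-dict wrapper: offload to CPU, then move back."""

    def __init__(self, state: Dict[str, object]):
        self.state = state
        self.orig_devices: Dict[str, str] = {}

    def copy_tensors_to_cpu(self, non_blocking: bool = False) -> None:
        for key, value in list(self.state.items()):
            if isinstance(value, FakeTensor):
                self.orig_devices[key] = value.device
                self.state[key] = value.to("cpu", non_blocking=non_blocking)

    def restore_tensor_device(self, non_blocking: bool = True) -> None:
        for key, device in self.orig_devices.items():
            tensor = self.state[key]
            # Only move tensors that are not already on their original device.
            if tensor.device != device:
                self.state[key] = tensor.to(device, non_blocking=non_blocking)
        self.orig_devices.clear()


sd = FlatStateDictSketch({"w": FakeTensor([1.0], "cuda:0"), "b": FakeTensor([0.5], "cpu")})
sd.copy_tensors_to_cpu()
sd.restore_tensor_device()
print({k: t.device for k, t in sd.state.items()})  # {'w': 'cuda:0', 'b': 'cpu'}
```

The defaults mirror the signatures documented here: the offload defaults to a blocking copy, while the restore defaults to non_blocking=True.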

abstract property tensors: Iterable[Tensor]

Gets the tensor data from the wrapped state dict.

values()
Returns:

The values from the state dictionary.

Return type:

ValuesView[Any]
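Since the return type is ValuesView[Any], values() plausibly delegates to the wrapped dict's own values(). The sketch below assumes exactly that (FlatStateDictSketch is a hypothetical name). A dict values view is live, so later changes to the wrapped dict, such as tensors being swapped for placeholders, are visible through a view obtained earlier.

```python
from typing import Any, Dict, ValuesView


class FlatStateDictSketch:
    """Hypothetical wrapper whose values() delegates to the wrapped dict."""

    def __init__(self, state: Dict[str, Any]):
        self.state = state

    def values(self) -> ValuesView[Any]:
        # dict.values() returns a live view, not a snapshot.
        return self.state.values()


sd = FlatStateDictSketch({"step": 7, "lr": 0.1})
view = sd.values()
sd.state["step"] = 8  # mutate the wrapped dict after taking the view
print(list(view))  # [8, 0.1]
```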