conversion

Module to handle model conversion and restoration for optimization methods.

When applying a model optimization algorithm, we usually need to modify the model in each step (mode) of the algorithm. This module provides the state manager, which is a standardized interface (class) to record and store state information in the model.

On top of the state manager, this module provides utilities to save a history of these modifications (the “modelopt state dict”) and to restore an unmodified model to the state recorded in the state dict.
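The idea of restoring from a history can be illustrated with a small sketch. This is not the library's implementation: each applied mode is recorded as a `(name, config)` entry, and a fresh, unmodified object is brought to the same state by replaying the history in order. The mode names, the `MODES` table and the dict-based "model" are all hypothetical.

```python
from typing import Any, Callable, Dict, List, Tuple

# Hypothetical "modes": each one modifies a toy model (a plain dict) in place.
MODES: Dict[str, Callable[[Dict[str, Any], Dict[str, Any]], None]] = {
    "prune": lambda model, cfg: model.update(width=model["width"] - cfg["amount"]),
    "quantize": lambda model, cfg: model.update(bits=cfg["bits"]),
}


def apply_and_record(model: Dict[str, Any], history: List[Tuple[str, Dict[str, Any]]],
                     name: str, config: Dict[str, Any]) -> None:
    """Apply one mode and append it to the ordered history."""
    MODES[name](model, config)
    history.append((name, dict(config)))


def replay(fresh_model: Dict[str, Any],
           history: List[Tuple[str, Dict[str, Any]]]) -> Dict[str, Any]:
    """Restore an unmodified model by re-applying the recorded modes in order."""
    for name, config in history:
        MODES[name](fresh_model, config)
    return fresh_model


history: List[Tuple[str, Dict[str, Any]]] = []
model = {"width": 64, "bits": 32}
apply_and_record(model, history, "prune", {"amount": 16})
apply_and_record(model, history, "quantize", {"bits": 8})

restored = replay({"width": 64, "bits": 32}, history)
print(restored == model)  # True: same modifications, rebuilt from the history alone
```

Only the history is needed to rebuild the modified structure, which is why the modelopt state can be stored separately from the weights.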

Classes

ModeloptStateManager

A class to handle the modelopt state stored for each mode corresponding to a task/mode.

Functions

apply_mode

Apply the provided modes to the model, record the changes, and return the model.

modelopt_state

Return the modelopt state dict describing the modifications to the model.

save

Save a model's state dict together with the modelopt state dict to restore its architecture.

restore_from_modelopt_state

Restore the model architecture from the modelopt state dictionary based on the user-provided model.

restore

Load the checkpoint, restore the modelopt model modifications, and load the model's weights.

class ModeloptStateManager

Bases: object

A class to handle the modelopt state stored for each mode corresponding to a task/mode.

__init__(model=None, init_state=False)

Initialize state manager.

Parameters:
  • model (Module | None) – Module that has modelopt_state stored. If None, a fake module is created to store any state that might be added with the manager.

  • init_state (bool) – Whether to initialize the modelopt state for the model if it does not exist.

Return type:

None

add_mode(mode, config, metadata)

Add mode and update state in-place.

Note that self._state is a list, which preserves insertion order, so the order in which the modes were added can be recalled.

Parameters:
  • mode (_ModeDescriptor | str) –

  • config (ModeloptBaseConfig) –

  • metadata (Dict[str, Any]) –

Return type:

None
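Because the state is kept in a list, entries come back in the order the modes were added. The following is a minimal stand-in, not the library's class. It assumes each entry is a `(mode_name, {"config": ..., "metadata": ...})` pair, consistent with the `List[Tuple[str, Dict[str, Dict[str, Any]]]]` type documented for `state_dict()`. The class name and the inner key names are assumptions.

```python
from typing import Any, Dict, Iterator, List, Optional, Tuple

Entry = Tuple[str, Dict[str, Dict[str, Any]]]


class ToyStateManager:
    """Hypothetical, simplified stand-in for an ordered per-mode state store."""

    def __init__(self) -> None:
        self._state: List[Entry] = []  # a list keeps the order in which modes were added

    def add_mode(self, mode: str, config: Dict[str, Any], metadata: Dict[str, Any]) -> None:
        self._state.append((mode, {"config": dict(config), "metadata": dict(metadata)}))

    @property
    def has_state(self) -> bool:
        return len(self._state) > 0

    @property
    def last_mode(self) -> Optional[str]:
        return self._state[-1][0] if self._state else None

    def modes_with_states(self) -> Iterator[Tuple[str, Dict[str, Any], Dict[str, Any]]]:
        for mode, entry in self._state:
            yield mode, entry["config"], entry["metadata"]

    def state_dict(self) -> List[Entry]:
        return list(self._state)

    def load_state_dict(self, state_dict: List[Entry]) -> None:
        self._state = list(state_dict)


manager = ToyStateManager()
manager.add_mode("quantize", {"bits": 8}, {"calibrated": False})
manager.add_mode("sparsify", {"ratio": 0.5}, {})
print([m for m, _, _ in manager.modes_with_states()])  # ['quantize', 'sparsify']
print(manager.last_mode)  # sparsify
```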

check_mode(mode)

Check if the proposed mode is compatible with the current state.

Parameters:

mode (_ModeDescriptor | str) –

Return type:

None

static get_config_class(mode, config)

Standardize the provided config to the corresponding config class.

Parameters:
  • mode (_ModeDescriptor | str) –

  • config (Dict[str, Any]) –

Return type:

ModeloptBaseConfig
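One plausible reading of "standardize" is to look up the config class registered for the mode and build an instance from a plain dict, filling in defaults. Objects that are already of that class pass through unchanged. The sketch uses a dataclass and a hypothetical `CONFIG_CLASSES` table. The real `ModeloptBaseConfig` is a different class, and this is not its implementation.

```python
from dataclasses import dataclass
from typing import Any, Dict, Type, Union


@dataclass
class ToyQuantizeConfig:  # hypothetical config class
    bits: int = 8
    per_channel: bool = True


# Hypothetical mapping from mode name to its config class.
CONFIG_CLASSES: Dict[str, Type[Any]] = {"quantize": ToyQuantizeConfig}


def get_config_class(mode: str, config: Union[Dict[str, Any], Any]) -> Any:
    """Return an instance of the mode's config class, filling in defaults."""
    cls = CONFIG_CLASSES[mode]
    if isinstance(config, cls):
        return config  # already standardized
    return cls(**config)  # unknown keys raise TypeError


print(get_config_class("quantize", {"bits": 4}))
# ToyQuantizeConfig(bits=4, per_channel=True)
```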

property has_state: bool

Return whether the model has a non-trivial modelopt state.

classmethod is_converted(model, is_root=False)

Check if model is converted.

Parameters:
  • model (Module) – A model to be checked for state/metadata from the convert process.

  • is_root (bool) – Additionally check whether the module with state is the root module.

Returns:

True if the model contains modelopt state indicating that it has been converted.

Return type:

bool

This method raises an AssertionError when multiple modelopt_states are detected, or when is_root is True but the module holding the state is not the root module.
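The checking rules above can be illustrated on a toy module tree. This sketch assumes the state is stored as an attribute (here the hypothetical `_toy_state`) on at most one module. The `Node` class and the attribute name are illustrative and not taken from the library.

```python
from typing import Iterator, List, Optional


class Node:
    """Hypothetical stand-in for an nn.Module with child modules."""

    def __init__(self, children: Optional[List["Node"]] = None) -> None:
        self.children = children or []

    def modules(self) -> Iterator["Node"]:
        yield self  # like nn.Module.modules(), the root comes first
        for child in self.children:
            yield from child.modules()


def is_converted(model: Node, is_root: bool = False) -> bool:
    """Return True if exactly one module in the tree carries the state."""
    holders = [m for m in model.modules() if hasattr(m, "_toy_state")]
    assert len(holders) <= 1, "multiple states detected"
    if not holders:
        return False
    if is_root:
        assert holders[0] is model, "state is not on the root module"
    return True


child = Node()
root = Node([child])
print(is_converted(root))  # False
child._toy_state = []
print(is_converted(root))  # True
```

With the state on `child`, `is_converted(root, is_root=True)` fails its second assertion, mirroring the documented behavior.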

property last_mode: _ModeDescriptor | None

Return the last mode applied to the model (last stored mode).

load_state_dict(state_dict)

Load the provided state_dict to the modelopt_state.

Parameters:

state_dict (List[Tuple[str, Dict[str, Dict[str, Any]]]]) –

Return type:

None

modes_with_states()

Yield the mode together with the full config and metadata from the state.

Return type:

Iterator[Tuple[_ModeDescriptor, ModeloptBaseConfig, Dict[str, Any]]]

state_dict()

Return the stored state of the model, that is, the config and metadata recorded for each applied mode, in order.

Return type:

List[Tuple[str, Dict[str, Dict[str, Any]]]]

classmethod transfer_state_dict(model_from, model_to)

Transfer the state (same instance) from one model to another.

Parameters:
  • model_from (Module) –

  • model_to (Module) –

Return type:

None

update_last_state_before_new_mode(model)

Update the metadata and config of the last mode applied to the model.

Parameters:

model (Module) –

Return type:

None

update_last_state_before_save(model)

Update the metadata and config of the last mode applied to the model.

Parameters:

model (Module) –

Return type:

None

apply_mode(model, mode, registry=None, init_state=None)

Apply the provided modes to the model, record the changes, and return the model.

Parameters:
  • model (Module | Type[Module] | Tuple | Callable) – A model-like object. Can be an nn.Module, a model class type, or a tuple. Tuple must be of the form (model_cls,) or (model_cls, args) or (model_cls, args, kwargs). Model will be initialized as model_cls(*args, **kwargs).

  • mode (_ModeDescriptor | str | List[_ModeDescriptor | str] | List[Tuple[str, Dict[str, Any]]]) – A mode, a list of modes or a list of tuples containing the mode and its config. The mode may be specified as a string or as the actual _ModeDescriptor class such as QuantizeModeDescriptor class.

  • registry (_ModeRegistryCls | None) – An optional mode registry from which to retrieve the mode. If not provided, all registries will be searched.

  • init_state (bool | None) – Flag indicating whether we should initialize the state manager for the model. If not provided, it will be inferred from the model. This flag can be used to enforce a certain behavior. For example, for init_state=True the state manager will raise an error if the model already contains state.

Returns:

The converted model after applying the desired modes.

Return type:

Module
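The `model` and `mode` arguments accept several shapes. The helpers below sketch how those shapes could be normalized, following the tuple rule stated above (`model_cls(*args, **kwargs)`). The helper names `init_model_like` and `normalize_modes`, the `ToyModule`/`Linear` classes, and the empty default config are hypothetical. Modes are handled as plain strings only.

```python
from typing import Any, Dict, List, Tuple, Union


class ToyModule:
    """Hypothetical stand-in for torch.nn.Module."""


class Linear(ToyModule):
    """Hypothetical layer used only to demonstrate construction."""

    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        self.shape = (in_features, out_features)
        self.bias = bias


def init_model_like(model: Any) -> ToyModule:
    """Build a model from a module, a class/factory, or a (cls[, args[, kwargs]]) tuple."""
    if isinstance(model, ToyModule):
        return model  # already constructed: use as-is
    if isinstance(model, tuple):
        model_cls = model[0]
        args = model[1] if len(model) > 1 else ()
        kwargs = model[2] if len(model) > 2 else {}
        return model_cls(*args, **kwargs)
    if callable(model):
        return model()  # a class or factory called without arguments
    raise TypeError(f"unsupported model-like object: {model!r}")


def normalize_modes(mode: Union[str, List[Any]]) -> List[Tuple[str, Dict[str, Any]]]:
    """Turn a mode, a list of modes, or a list of (mode, config) pairs into pairs."""
    modes = mode if isinstance(mode, list) else [mode]
    # Assumption: a mode given without a config gets an empty (default) config.
    return [m if isinstance(m, tuple) else (m, {}) for m in modes]


layer = init_model_like((Linear, (4, 2), {"bias": False}))
print(layer.shape, layer.bias)  # (4, 2) False
print(normalize_modes(["prune", ("quantize", {"bits": 8})]))
# [('prune', {}), ('quantize', {'bits': 8})]
```

The instance check comes first because module instances are themselves callable (`nn.Module` defines `__call__`). A plain `callable()` test alone could therefore not tell a model from a factory.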

modelopt_state(model)

Return the modelopt state dict describing the modifications to the model.

Note that the returned modelopt_state does not contain the model parameters such as weights and biases. modelopt_state is useful for saving and loading various modelopt optimization states separately from the model parameters. For example:

import torch
import modelopt.torch.opt as mto

# Save the modelopt state and model weights separately
torch.save(mto.modelopt_state(model), "modelopt_state.pt") # Save the modelopt state
torch.save(model.state_dict(), "model_weights.pt") # Save the model weights

If you want to save the model weights and the modelopt state together, please use mto.save().

Parameters:

model (Module) – the modelopt-modified model.

Returns:

A modelopt state dictionary describing the modifications to the model.

Return type:

Dict[str, Any]

restore(model, f, **kwargs)

Load the checkpoint, restore the modelopt model modifications, and load the model’s weights.

Parameters:
  • model (Module | Type[Module] | Tuple | Callable) – A model-like object. Can be an nn.Module, a model class type, or a tuple. Tuple must be of the form (model_cls,) or (model_cls, args) or (model_cls, args, kwargs). Model will be initialized as model_cls(*args, **kwargs).

  • f (str | PathLike | BinaryIO) – Target file location generated by mto.save().

  • **kwargs – Additional arguments passed to torch.load().

Returns:

The model with original weights and stored architecture.

Return type:

Module

Note

Wrappers such as DistributedDataParallel are not supported during the restore process. Wrap the model only after it has been restored.
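The three steps in the description above (load the checkpoint, restore the modifications, load the weights) can be sketched with plain Python objects. The checkpoint keys `modelopt_state` and `model_state_dict`, the `ToyModel` class, and the `restore_architecture` and `toy_restore` functions are hypothetical. `pickle` stands in for `torch.load()`.

```python
import io
import pickle
from typing import Any, Dict


class ToyModel:
    """Hypothetical model whose 'architecture' is a list of layer widths."""

    def __init__(self) -> None:
        self.widths = [64, 64]
        self.weights: Dict[str, Any] = {}

    def state_dict(self) -> Dict[str, Any]:
        return dict(self.weights)

    def load_state_dict(self, sd: Dict[str, Any]) -> None:
        self.weights = dict(sd)


def restore_architecture(model: ToyModel, modelopt_state: Dict[str, Any]) -> ToyModel:
    """Hypothetical step: re-apply the recorded architecture change."""
    model.widths = list(modelopt_state["widths"])
    return model


def toy_restore(model: ToyModel, f: io.BytesIO) -> ToyModel:
    checkpoint = pickle.load(f)                                        # 1. load the checkpoint
    model = restore_architecture(model, checkpoint["modelopt_state"])  # 2. restore modifications
    model.load_state_dict(checkpoint["model_state_dict"])              # 3. load the weights
    return model


# Simulate a checkpoint that a save step would have written.
buf = io.BytesIO()
pickle.dump({"modelopt_state": {"widths": [32, 48]},
             "model_state_dict": {"layer0": [0.1, 0.2]}}, buf)
buf.seek(0)

restored = toy_restore(ToyModel(), buf)
print(restored.widths, restored.weights)  # [32, 48] {'layer0': [0.1, 0.2]}
```

The weights are loaded last because they were saved from the modified architecture. Loading them into the unmodified model would generally fail on mismatched shapes.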

restore_from_modelopt_state(model, modelopt_state)

Restore the model architecture from the modelopt state dictionary based on the user-provided model.

This method does not restore the model parameters such as weights and biases. Please load the weights and biases with the original checkpoint loading method after restoring modelopt states with restore_from_modelopt_state. For example:

import torch
import modelopt.torch.opt as mto

model = ...  # Create the model-like object

# Restore the previously saved modelopt state followed by model weights
model = mto.restore_from_modelopt_state(
    model, torch.load("modelopt_state.pt")
)  # Restore modelopt state; keep the returned (modified) model
model.load_state_dict(torch.load("model_weights.pt"), ...)  # Load the model weights

If you want to restore the model weights and the modelopt state together, please use mto.restore().

Parameters:
  • model (Module | Type[Module] | Tuple | Callable) – A model-like object. Can be an nn.Module, a model class type, or a tuple. Tuple must be of the form (model_cls,) or (model_cls, args) or (model_cls, args, kwargs). Model will be initialized as model_cls(*args, **kwargs).

  • modelopt_state (Dict[str, Any]) – The modelopt state dict describing the modelopt modifications to the model. The modelopt_state can be generated via mto.modelopt_state().

Returns:

A modified model architecture based on the restored modifications with the unmodified weights as stored in the provided model argument.

Return type:

Module

Note

Wrappers such as DistributedDataParallel are not supported during the restore process. Wrap the model only after it has been restored.

save(model, f, **kwargs)

Save a model’s state dict together with the modelopt state dict to restore its architecture.

Parameters:
  • model (Module) – Any model.

  • f (str | PathLike | BinaryIO) – Target file location.

  • **kwargs – Additional arguments passed to torch.save().

Return type:

None

Note

If model is a wrapper such as DistributedDataParallel, it will be unwrapped for saving.
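The save side can be sketched under the same kind of hypothetical checkpoint layout: unwrap any wrappers, then store the weights and the modelopt state in a single object. The sketch relies on the fact that `torch.nn.parallel.DistributedDataParallel` exposes the wrapped model as its `.module` attribute. The `ToyWrapper` and `ToyModel` classes, the `toy_save` function and the checkpoint keys are illustrative assumptions. `pickle` stands in for `torch.save()`.

```python
import io
import pickle
from typing import Any, Dict


class ToyModel:
    """Hypothetical model with weights and a recorded modelopt-like state."""

    def __init__(self) -> None:
        self.weights: Dict[str, Any] = {"layer0": [0.1, 0.2]}
        self.opt_state: Dict[str, Any] = {"widths": [32, 48]}

    def state_dict(self) -> Dict[str, Any]:
        return dict(self.weights)


class ToyWrapper:
    """Stand-in for a DDP-style wrapper that keeps the real model in `.module`."""

    def __init__(self, module: Any) -> None:
        self.module = module


def toy_save(model: Any, f: io.BytesIO) -> None:
    while isinstance(model, ToyWrapper):  # unwrap (possibly nested) wrappers
        model = model.module
    pickle.dump({"modelopt_state": model.opt_state,
                 "model_state_dict": model.state_dict()}, f)


buf = io.BytesIO()
toy_save(ToyWrapper(ToyModel()), buf)
buf.seek(0)
print(sorted(pickle.load(buf)))  # ['model_state_dict', 'modelopt_state']
```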