conversion
Module to handle model converting and restoring for optimization methods.
When applying a model optimization algorithm, we usually need to modify the model in each step (mode) of the algorithm. This module provides the state manager, which is a standardized interface (class) to record and store state information in the model.
On top of the state manager, this module provides utilities to save a history of these modifications (“modelopt state dict”) and to restore an unmodified model to the state indicated in the state dict.
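The record-and-replay idea can be made concrete with a minimal pure-Python sketch. It is not the modelopt implementation: a plain dict stands in for the model, and the mode names, the `MODES` registry, and the helpers `apply_mode_toy`, `state_of`, and `replay_history` are all hypothetical. The recorded history holds only mode names and configs, never weights, and replaying it on a fresh model reproduces the modified architecture.

```python
from typing import Any, Callable, Dict, List, Tuple

# Hypothetical toy "model": a plain dict standing in for an nn.Module.
Model = Dict[str, Any]
History = List[Tuple[str, Dict[str, Any]]]

# Hypothetical mode registry: mode name -> function that modifies the model in place.
MODES: Dict[str, Callable[[Model, Dict[str, Any]], None]] = {
    "prune": lambda m, cfg: m.update(width=m["width"] // cfg["factor"]),
    "quantize": lambda m, cfg: m.update(bits=cfg["bits"]),
}


def apply_mode_toy(model: Model, mode: str, config: Dict[str, Any]) -> Model:
    """Apply one mode and record (mode, config) in the model's history."""
    MODES[mode](model, config)
    model.setdefault("_history", []).append((mode, dict(config)))
    return model


def state_of(model: Model) -> History:
    """Return a copy of the recorded history; no weights are included."""
    return list(model.get("_history", []))


def replay_history(model: Model, history: History) -> Model:
    """Re-apply the recorded modes, in order, to an unmodified model."""
    for mode, config in history:
        apply_mode_toy(model, mode, config)
    return model


original = apply_mode_toy({"width": 64}, "prune", {"factor": 2})
original = apply_mode_toy(original, "quantize", {"bits": 8})
rebuilt = replay_history({"width": 64}, state_of(original))
print(rebuilt == original)  # True
```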
Classes
ModeloptStateManager: A class to handle the modelopt state stored for each mode corresponding to a task/mode.
Functions
apply_mode: Apply the provided modes to the model, record the changes, and return the model.
modelopt_state: Return the modelopt state dict describing the modifications to the model.
save: Save a model's state dict together with the modelopt state dict to restore its architecture.
restore_from_modelopt_state: Restore the model architecture from the modelopt state dictionary based on the user-provided model.
restore: Load the checkpoint, restore the modelopt model modifications, and load the model's weights.
- class ModeloptStateManager
Bases: object
A class to handle the modelopt state stored for each mode corresponding to a task/mode.
- __init__(model=None, init_state=False)
Initialize state manager.
- Parameters:
model (Module | None) – Module that has modelopt_state stored. If None, a fake module is created to store any state that might be added with the manager.
init_state (bool) – Whether to initialize the modelopt state for the model if it does not exist.
- Return type:
None
- add_mode(mode, config, metadata)
Add mode and update state in-place.
Note that self._state is a list, which preserves the insertion order of modes, so the order in which the modes were applied can be recalled.
- Parameters:
mode (_ModeDescriptor | str) –
config (ModeloptBaseConfig) –
metadata (Dict[str, Any]) –
- Return type:
None
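The ordering note can be sketched with a hypothetical `ToyStateManager` that keeps `(mode, state)` entries in a plain list; the `"config"`/`"metadata"` keys of each entry are an assumption, not the library's documented layout. Appending to the list is what lets the full application order, and the last applied mode (compare `last_mode`), be recovered later.

```python
from typing import Any, Dict, List, Optional, Tuple


class ToyStateManager:
    """Hypothetical stand-in for ModeloptStateManager, not the library code."""

    def __init__(self) -> None:
        # Ordered list of (mode_name, {"config": ..., "metadata": ...}) entries.
        self._state: List[Tuple[str, Dict[str, Dict[str, Any]]]] = []

    def add_mode(self, mode: str, config: Dict[str, Any], metadata: Dict[str, Any]) -> None:
        # Appending preserves the order in which modes were applied.
        self._state.append((mode, {"config": dict(config), "metadata": dict(metadata)}))

    @property
    def last_mode(self) -> Optional[str]:
        # Most recently added mode, or None if no mode has been applied yet.
        return self._state[-1][0] if self._state else None

    def modes(self) -> List[str]:
        # Mode names in application order.
        return [mode for mode, _ in self._state]


manager = ToyStateManager()
manager.add_mode("prune", {"factor": 2}, {})
manager.add_mode("quantize", {"bits": 8}, {"calibrated": False})
print(manager.modes())  # ['prune', 'quantize']
print(manager.last_mode)  # quantize
```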
- check_mode(mode)
Check if the proposed mode is compatible with the current state.
- Parameters:
mode (_ModeDescriptor | str) –
- Return type:
None
- static get_config_class(mode, config)
Standardize the provided config to the corresponding config class.
- Parameters:
mode (_ModeDescriptor | str) –
config (Dict[str, Any]) –
- Return type:
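As a rough illustration of standardizing a config, the sketch below maps a mode name to a config class and builds an instance from a plain dict. Dataclasses stand in for ModeloptBaseConfig subclasses here, and `CONFIG_CLASSES`, `ToyQuantizeConfig`, `ToyPruneConfig`, and `get_config_class_toy` are hypothetical names.

```python
from dataclasses import asdict, dataclass
from typing import Any, Dict, Type, Union


@dataclass
class ToyQuantizeConfig:
    bits: int = 8
    per_channel: bool = True


@dataclass
class ToyPruneConfig:
    factor: int = 2


# Hypothetical mapping from mode name to its config class.
CONFIG_CLASSES: Dict[str, Type[Any]] = {"quantize": ToyQuantizeConfig, "prune": ToyPruneConfig}


def get_config_class_toy(mode: str, config: Union[Dict[str, Any], Any]) -> Any:
    """Return `config` as an instance of the mode's config class."""
    cls = CONFIG_CLASSES[mode]
    if isinstance(config, cls):
        return config  # already standardized
    # Missing keys take the defaults; unknown keys make the constructor raise TypeError.
    return cls(**config)


cfg = get_config_class_toy("quantize", {"bits": 4})
print(asdict(cfg))  # {'bits': 4, 'per_channel': True}
```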
- property has_state: bool
Return whether the model has a non-trivial modelopt state.
- classmethod is_converted(model, is_root=False)
Check if model is converted.
- Parameters:
model (Module) – A model to be checked for state/metadata from the convert process.
is_root (bool) – Additionally check whether the module with state is the root module.
- Returns:
True if the model contains modelopt state indicating that it has been converted.
- Return type:
bool
This method raises an AssertionError when multiple modelopt_states are detected, or when is_root is set to True but the module with state is not the root module.
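The checks described above can be sketched on a hypothetical `ToyModule` tree, where a non-`None` `state` attribute plays the role of a stored modelopt_state and `modules()` yields the module itself first, as `nn.Module.modules()` does. This is an illustration under those assumptions, not the library's code.

```python
from typing import Any, Iterator, List, Optional


class ToyModule:
    """Hypothetical module tree node; `state` mimics a stored modelopt_state."""

    def __init__(self, children: Optional[List["ToyModule"]] = None, state: Any = None) -> None:
        self.children = children or []
        self.state = state

    def modules(self) -> Iterator["ToyModule"]:
        # Depth-first traversal that yields the module itself first.
        yield self
        for child in self.children:
            yield from child.modules()


def is_converted_toy(model: ToyModule, is_root: bool = False) -> bool:
    holders = [m for m in model.modules() if m.state is not None]
    assert len(holders) <= 1, "multiple modelopt states detected"
    if is_root and holders:
        assert holders[0] is model, "the module with state is not the root module"
    return bool(holders)


print(is_converted_toy(ToyModule([ToyModule()])))  # False
print(is_converted_toy(ToyModule([ToyModule()], state=[]), is_root=True))  # True
```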
- property last_mode: _ModeDescriptor | None
Return the last mode applied to the model (last stored mode).
- load_state_dict(state_dict)
Load the provided state_dict to the modelopt_state.
- Parameters:
state_dict (List[Tuple[str, Dict[str, Dict[str, Any]]]]) –
- Return type:
None
- modes_with_states()
Yield the mode together with the full config and metadata from the state.
- Return type:
Iterator[Tuple[_ModeDescriptor, ModeloptBaseConfig, Dict[str, Any]]]
- state_dict()
Return the metadata of the model.
- Return type:
List[Tuple[str, Dict[str, Dict[str, Any]]]]
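A sketch of the state list round trip follows. It assumes each entry is `(mode_name, {"config": ..., "metadata": ...})`; that inner layout is an assumption consistent with the type `List[Tuple[str, Dict[str, Dict[str, Any]]]]`, and plain strings and dicts replace `_ModeDescriptor` and `ModeloptBaseConfig`. `ToyStateStore` is hypothetical, and its deep copies are a design choice of the sketch that keeps the returned and loaded lists independent of the caller's objects.

```python
import copy
from typing import Any, Dict, Iterator, List, Tuple

StateList = List[Tuple[str, Dict[str, Dict[str, Any]]]]


class ToyStateStore:
    """Hypothetical holder of an ordered modelopt-style state list."""

    def __init__(self) -> None:
        self._state: StateList = []

    def state_dict(self) -> StateList:
        # Return a deep copy so callers cannot mutate the stored state by accident.
        return copy.deepcopy(self._state)

    def load_state_dict(self, state_dict: StateList) -> None:
        # Replace the stored state with a deep copy of the provided list.
        self._state = copy.deepcopy(state_dict)

    def modes_with_states(self) -> Iterator[Tuple[str, Dict[str, Any], Dict[str, Any]]]:
        # Yield (mode, config, metadata) in the order the modes were recorded.
        for mode, state in self._state:
            yield mode, state["config"], state["metadata"]


saved: StateList = [
    ("prune", {"config": {"factor": 2}, "metadata": {}}),
    ("quantize", {"config": {"bits": 8}, "metadata": {"calibrated": True}}),
]
store = ToyStateStore()
store.load_state_dict(saved)
for mode, config, metadata in store.modes_with_states():
    print(mode, config, metadata)
```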
- classmethod transfer_state_dict(model_from, model_to)
Transfer the state (same instance) from one model to another.
- Parameters:
model_from (Module) –
model_to (Module) –
- Return type:
None
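The "same instance" wording means the target model receives a reference to the very state object, not a copy. The sketch below shows that with a hypothetical `Holder` class and a hypothetical attribute name `STATE_ATTR`; removing the state from the source afterwards is an assumption, chosen so that only one model ends up holding state.

```python
from typing import Any

STATE_ATTR = "_toy_modelopt_state"  # hypothetical attribute name


class Holder:
    """Hypothetical stand-in for an nn.Module."""


def transfer_state_toy(model_from: Any, model_to: Any) -> None:
    state = getattr(model_from, STATE_ATTR)
    # Hand over the very same object (no copy), so the target holds the original state.
    setattr(model_to, STATE_ATTR, state)
    # Assumption: the source drops its reference so only one model holds state.
    delattr(model_from, STATE_ATTR)


source, target = Holder(), Holder()
setattr(source, STATE_ATTR, [("quantize", {"config": {}, "metadata": {}})])
shared = getattr(source, STATE_ATTR)
transfer_state_toy(source, target)
print(getattr(target, STATE_ATTR) is shared)  # True
print(hasattr(source, STATE_ATTR))  # False
```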
- update_last_state_before_new_mode(model)
Update the metadata and config of the last mode applied to the model before a new mode is applied.
- Parameters:
model (Module) –
- Return type:
None
- update_last_state_before_save(model)
Update the metadata and config of the last mode applied to the model before the model is saved.
- Parameters:
model (Module) –
- Return type:
None
- apply_mode(model, mode, registry=None, init_state=None)
Apply the provided modes to the model, record the changes, and return the model.
- Parameters:
model (Module | Type[Module] | Tuple | Callable) – A model-like object. Can be an nn.Module, a model class type, or a tuple. Tuple must be of the form (model_cls,), (model_cls, args), or (model_cls, args, kwargs). Model will be initialized as model_cls(*args, **kwargs).
mode (_ModeDescriptor | str | List[_ModeDescriptor | str] | List[Tuple[str, Dict[str, Any]]]) – A mode, a list of modes, or a list of tuples containing the mode and its config. The mode may be specified as a string or as the actual _ModeDescriptor class, such as the QuantizeModeDescriptor class.
registry (_ModeRegistryCls | None) – An optional mode registry from which to retrieve the mode. If not provided, all registries will be searched.
init_state (bool | None) – Flag indicating whether we should initialize the state manager for the model. If not provided, it will be inferred from the model. This flag can be used to enforce a certain behavior. For example, for init_state=True the state manager will raise an error if the model already contains state.
- Returns:
The converted model after applying the desired modes.
- Return type:
Module
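The accepted forms of `model` and `mode` can be normalized roughly as in the sketch below. Both helpers, `build_model_toy` and `normalize_modes_toy`, are hypothetical; they handle only strings for modes (not `_ModeDescriptor` classes) and leave out the `Callable` case for `model`, which the text above does not describe. Treating a bare mode name as an empty config is also an assumption.

```python
from typing import Any, Dict, List, Tuple, Union

ModeSpec = Union[str, List[Union[str, Tuple[str, Dict[str, Any]]]]]


def build_model_toy(model_like: Any) -> Any:
    """Hypothetical helper: turn a model-like object into a model instance."""
    if isinstance(model_like, tuple):
        if not 1 <= len(model_like) <= 3:
            raise ValueError("expected (model_cls,), (model_cls, args) or (model_cls, args, kwargs)")
        model_cls = model_like[0]
        args = model_like[1] if len(model_like) > 1 else ()
        kwargs = model_like[2] if len(model_like) > 2 else {}
        return model_cls(*args, **kwargs)
    if isinstance(model_like, type):
        return model_like()  # a bare class is instantiated without arguments
    return model_like  # assume it is already an instance


def normalize_modes_toy(mode: ModeSpec) -> List[Tuple[str, Dict[str, Any]]]:
    """Hypothetical helper: turn an accepted mode spec into [(name, config), ...]."""
    items = [mode] if isinstance(mode, str) else mode
    normalized: List[Tuple[str, Dict[str, Any]]] = []
    for item in items:
        if isinstance(item, str):
            normalized.append((item, {}))  # assumption: no config means defaults
        else:
            name, config = item
            normalized.append((name, dict(config)))
    return normalized


class Net:
    def __init__(self, width: int = 8, depth: int = 1) -> None:
        self.width, self.depth = width, depth


net = build_model_toy((Net, (16,), {"depth": 3}))
print(net.width, net.depth)  # 16 3
print(normalize_modes_toy([("prune", {"factor": 2}), "quantize"]))  # [('prune', {'factor': 2}), ('quantize', {})]
```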
- modelopt_state(model)
Return the modelopt state dict describing the modifications to the model.
Note that the returned modelopt_state does not contain the model parameters such as weights and biases. modelopt_state is useful for saving and loading various modelopt optimization states separately from the model parameters. For example:
    import torch
    import modelopt.torch.opt as mto
    # Save the modelopt state and model weights separately
    torch.save(mto.modelopt_state(model), "modelopt_state.pt")  # Save the modelopt state
    torch.save(model.state_dict(), "model_weights.pt")  # Save the model weights
If you want to save the model weights and the modelopt state together, please use mto.save().
- Parameters:
model (Module) – the modelopt-modified model.
- Returns:
A modelopt state dictionary describing the modifications to the model.
- Return type:
Dict[str, Any]
- restore(model, f, **kwargs)
Load the checkpoint, restore the modelopt model modifications, and load the model’s weights.
- Parameters:
model (Module | Type[Module] | Tuple | Callable) – A model-like object. Can be an nn.Module, a model class type, or a tuple. Tuple must be of the form (model_cls,), (model_cls, args), or (model_cls, args, kwargs). Model will be initialized as model_cls(*args, **kwargs).
f (str | PathLike | BinaryIO) – Target file location generated by mto.save().
**kwargs – additional args for torch.load().
- Returns:
The model with original weights and stored architecture.
- Return type:
Module
Note
Note that wrappers such as DistributedDataParallel are not supported during the restore process. Please wrap the model after the restore process.
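The reason the architecture has to be restored before the weights can be seen in a small sketch. `ToyNet`, `save_toy`, `restore_toy`, and the checkpoint keys `"modelopt_state"` and `"model_state_dict"` are hypothetical, and `pickle` stands in for `torch.save()`/`torch.load()`. Pruning changes the number of weights, so the saved weights only fit after the recorded modes are replayed on the fresh model.

```python
import io
import pickle
from typing import Any, BinaryIO, Dict, List, Tuple


class ToyNet:
    """Hypothetical model whose weight count depends on its architecture."""

    def __init__(self, width: int = 4) -> None:
        self.width = width
        self.weights: List[float] = [0.0] * width
        self.history: List[Tuple[str, Dict[str, Any]]] = []

    def prune(self, factor: int) -> None:
        # An architecture change: fewer weights afterwards, recorded in the history.
        self.width //= factor
        self.weights = self.weights[: self.width]
        self.history.append(("prune", {"factor": factor}))

    def load_weights(self, weights: List[float]) -> None:
        if len(weights) != self.width:
            raise ValueError("weights do not match the current architecture")
        self.weights = list(weights)


def save_toy(model: ToyNet, f: BinaryIO) -> None:
    # Hypothetical checkpoint layout: architecture history and weights side by side.
    pickle.dump({"modelopt_state": list(model.history), "model_state_dict": list(model.weights)}, f)


def restore_toy(model: ToyNet, f: BinaryIO) -> ToyNet:
    ckpt = pickle.load(f)
    # 1) Replay the recorded architecture changes on the fresh model ...
    for mode, config in ckpt["modelopt_state"]:
        getattr(model, mode)(**config)
    # 2) ... and only then load the weights, which now fit the architecture.
    model.load_weights(ckpt["model_state_dict"])
    return model


trained = ToyNet(4)
trained.prune(2)
trained.load_weights([1.5, -2.0])
buffer = io.BytesIO()
save_toy(trained, buffer)
buffer.seek(0)
restored = restore_toy(ToyNet(4), buffer)
print(restored.width, restored.weights)  # 2 [1.5, -2.0]
```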
- restore_from_modelopt_state(model, modelopt_state)
Restore the model architecture from the modelopt state dictionary based on the user-provided model.
This method does not restore the model parameters such as weights and biases. Please load the weights and biases with the original checkpoint loading method after restoring modelopt states with restore_from_modelopt_state. For example:
    import torch
    import modelopt.torch.opt as mto
    model = ...  # Create the model-like object
    # Restore the previously saved modelopt state followed by model weights
    model = mto.restore_from_modelopt_state(model, torch.load("modelopt_state.pt"))  # Restore modelopt state
    model.load_state_dict(torch.load("model_weights.pt"), ...)  # Load the model weights
If you want to restore the model weights and the modelopt state together, please use mto.restore().
- Parameters:
model (Module | Type[Module] | Tuple | Callable) – A model-like object. Can be an nn.Module, a model class type, or a tuple. Tuple must be of the form (model_cls,), (model_cls, args), or (model_cls, args, kwargs). Model will be initialized as model_cls(*args, **kwargs).
modelopt_state (Dict[str, Any]) – The modelopt state dict describing the modelopt modifications to the model. The modelopt_state can be generated via mto.modelopt_state().
- Returns:
A modified model architecture based on the restored modifications with the unmodified weights as stored in the provided model argument.
- Return type:
Module
Note
Note that wrappers such as DistributedDataParallel are not supported during the restore process. Please wrap the model after the restore process.
- save(model, f, **kwargs)
Save a model’s state dict together with the modelopt state dict to restore its architecture.
- Parameters:
model (Module) – Any model.
f (str | PathLike | BinaryIO) – Target file location.
**kwargs – additional args for torch.save().
- Return type:
None
Note
If model is a wrapper such as DistributedDataParallel, it will be unwrapped for saving.
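A minimal sketch of the unwrapping step, assuming wrappers expose the wrapped model as `.module` (as PyTorch's `DistributedDataParallel` and `DataParallel` do). `ToyWrapper` and `unwrap_toy` are hypothetical; with PyTorch one would pass the real wrapper classes as `wrapper_types`.

```python
from typing import Any, Tuple


class ToyWrapper:
    """Hypothetical wrapper that, like DistributedDataParallel, exposes `.module`."""

    def __init__(self, module: Any) -> None:
        self.module = module


def unwrap_toy(model: Any, wrapper_types: Tuple[type, ...] = (ToyWrapper,)) -> Any:
    # Peel off known wrappers, possibly nested, to reach the underlying model.
    while isinstance(model, wrapper_types):
        model = model.module
    return model


inner = object()
print(unwrap_toy(ToyWrapper(ToyWrapper(inner))) is inner)  # True
```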