patch
Patch manager for NAS.
Classes
PatchManager: A standard interface to handle the monkey patching of a model.
Functions
prep_for_eval: Calibrate model for evaluation and enable eval().
- class PatchManager
Bases: ABC
A standard interface to handle the monkey patching of a model.
The class provides two main interfaces, patch() and unpatch(), which are used to overwrite methods and add new attributes to the model in a reversible way. All overwritten methods and added attributes are removed upon calling unpatch().
- __init__(model)
Constructor.
- Parameters:
model (Module)
- Return type:
None
- call_post_eval(forward_loop=None)
Call post-eval hook explicitly.
- Parameters:
forward_loop (Callable[[Module], None] | None) – A Callable that takes a model as input and runs a forward loop on it.
- Return type:
None
- classmethod get_manager(model)
Return the patch manager of the model using the “correct” class.
- Parameters:
model (Module)
- Return type:
PatchManager
- static hooked__replicate_for_data_parallel(mod)
The _replicate_for_data_parallel method with hooks.
- Parameters:
mod (Module)
- Return type:
Module
- static hooked_forward(mod, *args, **kwargs)
The forward method with hooks.
- Parameters:
mod (Module)
- static hooked_train(mod, mode=True)
Set the module into train or eval mode according to the mode flag.
- Parameters:
mod (Module)
mode (bool)
- Return type:
Module
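The hooked_* entries are static methods that receive the module explicitly as mod. The sketch below illustrates how such a static hook could wrap train(): the original bound method is saved on the instance, and a hook bound to that instance replaces it. Everything here is hypothetical and stdlib-only: ToyModule stands in for an nn.Module, and install(), hook_log and _hook_demo_originals are illustrative names, not part of this API.

```python
import functools
from typing import Any

_ORIGINALS = "_hook_demo_originals"  # hypothetical attribute holding saved methods


class ToyModule:
    """Hypothetical stand-in for the train()/eval() API of nn.Module."""

    def __init__(self) -> None:
        self.training = True
        self.hook_log: list[str] = []

    def train(self, mode: bool = True) -> "ToyModule":
        self.training = mode
        return self

    def eval(self) -> "ToyModule":
        # Like nn.Module.eval(), delegate to self.train(False).
        return self.train(False)


def hooked_train(mod: Any, mode: bool = True) -> Any:
    """Hypothetical hooked train(): call the saved original, then run a post-hook."""
    original = getattr(mod, _ORIGINALS)["train"]
    result = original(mode)
    mod.hook_log.append(f"train(mode={mode})")  # example post-hook
    return result


def install(mod: Any) -> None:
    """Save the original bound train() and shadow it with the hooked version."""
    setattr(mod, _ORIGINALS, {"train": mod.train})
    # functools.partial binds mod as the first argument, mirroring a static
    # hook that receives the module explicitly.
    mod.train = functools.partial(hooked_train, mod)


module = ToyModule()
install(module)
module.eval()   # eval() routes through self.train(False), so the hook runs
module.train()  # direct call also goes through the hook
```

Because the replacement lives on the instance rather than the class, other instances of the same class are unaffected, and deleting the instance attribute restores the original method.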
- classmethod is_patched(model)
Return whether the model is patched.
- Parameters:
model (Module)
- Return type:
bool
- patch()
Patch model in-place to be compatible with subsequent Model Optimizer tasks.
- Return type:
None
- property patch_data: dict[str, Any]
Return the patch data of the model.
- property patch_data_or_empty: dict[str, Any]
Return the patch data of the model or an empty dictionary.
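The two properties differ only in how they treat a model that carries no patch data. A minimal sketch of that contract follows, assuming (not confirmed by this page) that the patch data lives in a model attribute and that the strict accessor raises when it is missing; PatchDataDemo, Model and _demo_patch_data are hypothetical names.

```python
from typing import Any


class PatchDataDemo:
    """Hypothetical illustration of patch_data vs. patch_data_or_empty."""

    _ATTR = "_demo_patch_data"  # hypothetical attribute name holding the patch data

    def __init__(self, model: Any) -> None:
        self.model = model

    @property
    def patch_data(self) -> dict[str, Any]:
        # Strict accessor (assumed behavior): fail loudly on an unpatched model.
        data = getattr(self.model, self._ATTR, None)
        if data is None:
            raise AttributeError("model has no patch data; call patch() first")
        return data

    @property
    def patch_data_or_empty(self) -> dict[str, Any]:
        # Lenient accessor: safe to call on an unpatched model.
        return getattr(self.model, self._ATTR, {})


class Model:
    """Plain object standing in for an nn.Module."""


model = Model()
demo = PatchDataDemo(model)
empty_before = demo.patch_data_or_empty  # {} because nothing is patched yet
setattr(model, PatchDataDemo._ATTR, {"forward": "original forward"})
stored = demo.patch_data["forward"]
```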
- reset_before_sample()
Call reset hook before sample-related operations (sample & select).
- Return type:
None
- unpatch()
Remove and delete patching from the model.
For example:
    PatchManager(model).unpatch()
    model.forward(x)  # no patched (auto-)operations are executed anymore
- Return type:
None
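To make the patch()/unpatch() round trip concrete, here is a simplified, stdlib-only sketch of the monkey-patching pattern the class describes. It is not the library's implementation: MiniPatchManager, ToyModule and _mini_patch_data are hypothetical, and the example assumes all patch state is stored on the model itself, so that a fresh manager instance can undo the patch as in PatchManager(model).unpatch().

```python
from typing import Any


class ToyModule:
    """Hypothetical stand-in for an nn.Module with a forward method."""

    def forward(self, x: int) -> int:
        return x + 1


class MiniPatchManager:
    """Hypothetical, simplified analogue of PatchManager."""

    _ATTR = "_mini_patch_data"  # hypothetical attribute holding the patch data

    def __init__(self, model: Any) -> None:
        self.model = model

    @classmethod
    def is_patched(cls, model: Any) -> bool:
        return hasattr(model, cls._ATTR)

    def patch(self) -> None:
        if self.is_patched(self.model):
            raise RuntimeError("model is already patched")
        # Save the original bound forward so unpatch() can restore it.
        setattr(self.model, self._ATTR, {"forward": self.model.forward, "calls": 0})
        # Shadow forward on this instance only; the class stays untouched.
        self.model.forward = self._hooked_forward

    def _hooked_forward(self, *args: Any, **kwargs: Any) -> Any:
        data = getattr(self.model, self._ATTR)
        data["calls"] += 1  # example hook: count forward calls
        return data["forward"](*args, **kwargs)

    def unpatch(self) -> None:
        if not self.is_patched(self.model):
            raise RuntimeError("model is not patched")
        # Removing the instance attribute re-exposes the class-level forward.
        del self.model.forward
        delattr(self.model, self._ATTR)


model = ToyModule()
MiniPatchManager(model).patch()
patched_out = model.forward(1)  # goes through the hooked forward
calls = getattr(model, MiniPatchManager._ATTR)["calls"]
MiniPatchManager(model).unpatch()  # a fresh manager can undo the patch
restored_out = model.forward(1)  # plain ToyModule.forward again
```

Keeping the state on the model rather than on the manager is what lets any manager instance inspect or undo the patch later.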
- prep_for_eval(model, forward_loop=None)
Calibrate model for evaluation and enable eval().
- Parameters:
model (Module) – A nn.Module that might be dynamic.
forward_loop (Callable[[Module], None] | None) – A Callable that takes a model as input and runs a pre-defined forward loop on it using real data.
- Return type:
Module
Note
This function should only be explicitly called once after conversion when the model is immediately used for evaluation without prior training!
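A forward_loop is any callable that accepts the model and runs forward passes on it. The following stdlib-only sketch shows the calibrate-then-eval flow that prep_for_eval describes, under loudly stated assumptions: ToyDynamicModule, toy_prep_for_eval, my_forward_loop and the calib_max statistic are all hypothetical, and real calibration in the library may collect entirely different state.

```python
from typing import Callable, Optional


class ToyDynamicModule:
    """Hypothetical module that collects a statistic from real data before eval."""

    def __init__(self) -> None:
        self.training = True
        self.calib_max = 0.0  # hypothetical calibration statistic

    def forward(self, x: float) -> float:
        if self.training:
            # Collect statistics only while not yet in eval mode.
            self.calib_max = max(self.calib_max, abs(x))
        return x

    def eval(self) -> "ToyDynamicModule":
        self.training = False
        return self


def toy_prep_for_eval(
    model: ToyDynamicModule,
    forward_loop: Optional[Callable[[ToyDynamicModule], None]] = None,
) -> ToyDynamicModule:
    """Hypothetical analogue of prep_for_eval: calibrate, then switch to eval mode."""
    if forward_loop is not None:
        forward_loop(model)  # run real data through the model
    return model.eval()


def my_forward_loop(model: ToyDynamicModule) -> None:
    # The loop only needs to accept the model and run forward passes on it.
    for batch in [0.5, -2.0, 1.25]:
        model.forward(batch)


model = toy_prep_for_eval(ToyDynamicModule(), my_forward_loop)
```

Calling the sketch a second time would recalibrate nothing, since the model is already in eval mode, which echoes the note that the function is meant to be called once.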