patch

Patch manager for NAS.

Classes

PatchManager

A standard interface to handle the monkey patching of a model.

Functions

prep_for_eval

Calibrate the model for evaluation and put it into eval mode via eval().

class PatchManager

Bases: ABC

A standard interface to handle the monkey patching of a model.

The class provides two main interfaces, patch() and unpatch(). patch() properly overwrites methods and adds new attributes to the model; all of these overwritten methods and added attributes are removed again upon calling unpatch().
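The patch/unpatch contract can be illustrated with a minimal sketch. It assumes a plain Python object in place of an nn.Module, and the names ToyModel, ToyPatchManager, _set and calls are hypothetical; this is not the Model Optimizer implementation. The manager records every attribute it overwrites or adds, so unpatch() can restore the original state exactly.

```python
from typing import Any


class ToyModel:
    """Stand-in for an nn.Module (assumption: a plain Python object, not torch)."""

    def forward(self, x: int) -> int:
        return x + 1


class ToyPatchManager:
    """Hypothetical minimal patch manager, not the Model Optimizer implementation."""

    def __init__(self, model: Any) -> None:
        self.model = model
        self._saved: dict[str, Any] = {}  # overwritten instance attributes -> old values
        self._added: list[str] = []  # attributes that did not exist on the instance

    def _set(self, name: str, value: Any) -> None:
        # Record how to undo the assignment before making it.
        if name in vars(self.model):
            self._saved.setdefault(name, vars(self.model)[name])
        elif name not in self._added:
            self._added.append(name)
        setattr(self.model, name, value)

    def patch(self) -> None:
        original_forward = self.model.forward  # bound class method, captured once

        def patched_forward(x: int) -> int:
            self.model.calls += 1  # bookkeeping introduced by the patch
            return original_forward(x)

        self._set("calls", 0)
        self._set("forward", patched_forward)

    def unpatch(self) -> None:
        # Restore overwritten attributes, then delete the ones the patch added.
        for name, value in self._saved.items():
            setattr(self.model, name, value)
        for name in self._added:
            delattr(self.model, name)
        self._saved.clear()
        self._added.clear()


model = ToyModel()
manager = ToyPatchManager(model)
manager.patch()
model.forward(1)  # runs patched_forward, so model.calls becomes 1
manager.unpatch()  # model.calls is deleted and ToyModel.forward is used again
```

Deleting the instance-level forward (rather than re-assigning the saved bound method) lets normal class lookup take over again, which leaves no trace of the patch on the object.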

__init__(model)

Create a patch manager for the given model.

Parameters:

model (Module)

Return type:

None

call_post_eval(forward_loop=None)

Call the post-eval hook explicitly.

Parameters:

forward_loop (Callable[[Module], None] | None) – A Callable that takes a model as input and runs a forward loop on it.

Return type:

None
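A forward_loop only has to satisfy Callable[[Module], None]: it receives the model and runs forward passes, and its return value is ignored. The sketch below uses a hypothetical helper, make_forward_loop, and treats the model as any callable; in real PyTorch code the loop would typically iterate a DataLoader and may be wrapped in torch.no_grad().

```python
from typing import Any, Callable, Iterable


def make_forward_loop(batches: Iterable[Any]) -> Callable[[Any], None]:
    """Build a forward loop matching Callable[[Module], None] (hypothetical helper)."""
    data = list(batches)  # materialize so the loop can safely be called more than once

    def forward_loop(model: Any) -> None:
        for batch in data:
            model(batch)  # outputs are discarded; only the forward passes matter

    return forward_loop
```

Materializing the batches up front matters if the same loop is passed to several calibration steps, since a one-shot iterator would be empty on the second call.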

classmethod get_manager(model)

Return the patch manager of the model using the “correct” class.

Parameters:

model (Module)

Return type:

PatchManager
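How get_manager picks the "correct" class is not spelled out here. One plausible design, shown purely as an assumption, is a registry from model type to manager subclass, resolved along the model's MRO so that the most specific registration wins. BaseManager, register, Net, ConvNet and their managers are all hypothetical names.

```python
from typing import Any


class BaseManager:
    """Hypothetical base class; subclasses register the model types they handle."""

    _registry: dict[type, type] = {}  # model class -> manager class

    def __init__(self, model: Any) -> None:
        self.model = model

    @classmethod
    def register(cls, model_type: type):
        """Class decorator that maps model_type to the decorated manager class."""

        def decorator(manager_cls: type) -> type:
            cls._registry[model_type] = manager_cls
            return manager_cls

        return decorator

    @classmethod
    def get_manager(cls, model: Any) -> "BaseManager":
        # Walk the MRO so the most specific registered model class wins.
        for klass in type(model).__mro__:
            if klass in cls._registry:
                return cls._registry[klass](model)
        raise TypeError(f"no patch manager registered for {type(model).__name__}")


# Example model hierarchy and managers (all hypothetical).
class Net:
    pass


class ConvNet(Net):
    pass


@BaseManager.register(Net)
class NetManager(BaseManager):
    pass


@BaseManager.register(ConvNet)
class ConvNetManager(BaseManager):
    pass
```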

static hooked__replicate_for_data_parallel(mod)

The _replicate_for_data_parallel method with hooks.

Parameters:

mod (Module)

Return type:

Module

static hooked_forward(mod, *args, **kwargs)

The forward method with hooks.

Parameters:

mod (Module)
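The hooked_* static methods take the module as their first argument, which suggests they are bound onto a specific instance. The sketch below assumes that pattern: a hypothetical hooked_forward runs a list of pre-hooks and then the class's original forward, and is attached with types.MethodType. ToyModule and pre_hooks are stand-ins, not the library's attributes.

```python
import types
from typing import Any, Callable


class ToyModule:
    """Stand-in for nn.Module with a plain list of pre-forward hooks (assumption)."""

    def __init__(self) -> None:
        self.pre_hooks: list[Callable[["ToyModule"], None]] = []
        self.log: list[str] = []

    def forward(self, x: int) -> int:
        self.log.append("forward")
        return 2 * x


def hooked_forward(mod: ToyModule, *args: Any, **kwargs: Any) -> Any:
    """Hypothetical hooked forward: run the hooks, then the class's original forward."""
    for hook in mod.pre_hooks:
        hook(mod)
    # Look the original up on the class, not the instance, so the patch cannot recurse.
    return type(mod).forward(mod, *args, **kwargs)


# Install the wrapper on one instance only; types.MethodType binds `mod`.
module = ToyModule()
module.pre_hooks.append(lambda mod: mod.log.append("hook"))
module.forward = types.MethodType(hooked_forward, module)
```

Binding per instance leaves other instances of the same class untouched, and removing the instance attribute is enough to undo the patch.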

static hooked_train(mod, mode=True)

Set the model to train or eval mode according to the mode flag.

Parameters:
  • mod (Module)

  • mode (bool)

Return type:

Module
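In PyTorch, nn.Module.eval() is implemented as self.train(False), so patching train() also intercepts eval(). The sketch below assumes, without confirmation from this page, that the hook of interest fires on the train-to-eval transition; ToyModule, post_eval_hook and this hooked_train are hypothetical stand-ins.

```python
import types
from typing import Callable, Optional


class ToyModule:
    """Stand-in for nn.Module: a training flag plus child modules (assumption)."""

    def __init__(self, *children: "ToyModule") -> None:
        self.training = True
        self.children = list(children)
        self.post_eval_hook: Optional[Callable[["ToyModule"], None]] = None  # hypothetical

    def train(self, mode: bool = True) -> "ToyModule":
        self.training = mode
        for child in self.children:
            child.train(mode)
        return self

    def eval(self) -> "ToyModule":
        # Like nn.Module.eval(), delegate to train(False) so a patched train() sees it.
        return self.train(False)


def hooked_train(mod: ToyModule, mode: bool = True) -> ToyModule:
    """Hypothetical hooked train(): original behaviour plus a hook on train -> eval."""
    was_training = mod.training
    result = type(mod).train(mod, mode)  # unpatched class method, also updates children
    if was_training and not mode and mod.post_eval_hook is not None:
        mod.post_eval_hook(mod)
    return result
```

Usage: assign `root.train = types.MethodType(hooked_train, root)`; afterwards `root.eval()` switches the whole tree to eval mode and calls the hook once, while repeated eval() calls do not fire it again.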

classmethod is_patched(model)

Return whether the model is patched.

Parameters:

model (Module)

Return type:

bool

patch()

Patch the model in place so that it is compatible with subsequent Model Optimizer tasks.

Return type:

None

property patch_data: dict[str, Any]

Return the patch data of the model.

property patch_data_or_empty: dict[str, Any]

Return the patch data of the model or an empty dictionary.

reset_before_sample()

Call the reset hook before sample-related operations (sample and select).

Return type:

None

unpatch()

Remove all patching from the model.

For example:

PatchManager.get_manager(model).unpatch()  # get_manager returns the manager of the correct class
model.forward(x)  # patched (auto-) operations are no longer executed

Return type:

None

prep_for_eval(model, forward_loop=None)

Calibrate the model for evaluation and put it into eval mode via eval().

Parameters:
  • model (Module) – An nn.Module that might be dynamic.

  • forward_loop (Callable[[Module], None] | None) – A Callable that takes a model as input and runs a pre-defined forward loop on it using real data.

Return type:

Module

Note

Call this function explicitly only once, after conversion, and only when the model is used for evaluation immediately without prior training!
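The calibrate-then-evaluate flow of prep_for_eval can be sketched as below. This is a hypothetical analogue, not the library function: ToyModel stands in for a converted model, the run-once guard is an addition that mirrors the Note (the real function may not enforce it), and calibration_loop uses placeholder batches where real code would use real data.

```python
from typing import Any, Callable, Optional


class ToyModel:
    """Stand-in for a converted model (assumption: a plain object, not torch)."""

    def __init__(self) -> None:
        self.training = True
        self.batches_seen = 0

    def __call__(self, batch: Any) -> None:
        self.batches_seen += 1

    def eval(self) -> "ToyModel":
        self.training = False
        return self


def prep_for_eval_sketch(
    model: ToyModel, forward_loop: Optional[Callable[[ToyModel], None]] = None
) -> ToyModel:
    """Hypothetical analogue of prep_for_eval: calibrate, then switch to eval mode."""
    # Added guard reflecting the Note; the real function may not enforce this.
    if getattr(model, "_prepped_for_eval", False):
        raise RuntimeError("prep_for_eval_sketch should only be called once per model")
    if forward_loop is not None:
        forward_loop(model)  # calibration pass over (here: placeholder) data
    model._prepped_for_eval = True
    return model.eval()


def calibration_loop(model: ToyModel) -> None:
    for batch in range(4):  # placeholder batches; real code would use real data
        model(batch)
```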