speculative_decoding
User-facing API for converting a model into a modelopt.torch.speculative.MedusaModel.
Functions
convert: Main conversion function to turn a base model into a speculative decoding model.
- convert(model, mode)
Main conversion function to turn a base model into a speculative decoding model.
- Parameters:
model (Module) – The base model to be used.
mode (_ModeDescriptor | str | List[_ModeDescriptor | str] | List[Tuple[str, Dict[str, Any]]]) –
A mode, a list of modes, or a list of (mode, config) tuples indicating the desired mode(s) and their configurations for the convert process. Each mode may be given as a string or as a mode descriptor. Modes set up the model for different model optimization algorithms. The following modes are available:
- "medusa": The model will be converted into a Medusa model with an added Medusa head. The mode’s config is described in MedusaConfig.
- "eagle": The model will be converted into an Eagle model with added Eagle weights. The mode’s config is described in EagleConfig.
If the mode argument is specified as a dictionary, the keys should indicate the mode and the values specify the per-mode configuration.
- Returns:
An instance of MedusaModel or EagleModel, or a subclass of one of them.
- Return type:
Module
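The call described above can be sketched as follows. This is a minimal illustration, not the library's reference usage: it assumes that convert is importable from the modelopt.torch.speculative package, and the config keys medusa_num_heads and medusa_num_layers are assumed here and should be checked against MedusaConfig. The helper names medusa_mode and to_medusa are hypothetical and exist only for this example.

```python
from typing import Any, Dict, List, Tuple

# Assumed config keys for the "medusa" mode; verify them against MedusaConfig.
MEDUSA_CONFIG: Dict[str, Any] = {"medusa_num_heads": 2, "medusa_num_layers": 1}


def medusa_mode(config: Dict[str, Any]) -> List[Tuple[str, Dict[str, Any]]]:
    """Hypothetical helper: build the list-of-tuples form of the `mode` argument."""
    # Copy the config so later changes by the caller do not leak into the mode spec.
    return [("medusa", dict(config))]


def to_medusa(base_model):
    """Hypothetical wrapper: convert a torch.nn.Module into a Medusa model."""
    # Imported lazily so this file can be loaded without ModelOpt installed.
    # Assumption: convert is exposed at the package level.
    import modelopt.torch.speculative as mtsp

    # Per the documentation above, the result is a MedusaModel or a subclass of it.
    return mtsp.convert(base_model, medusa_mode(MEDUSA_CONFIG))
```

Passing the plain string "medusa" as mode instead of the list of tuples should select the same mode with its default configuration, since the parameter description allows a bare string.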