speculative_decoding

User-facing API for converting a model into a modelopt.torch.speculative.MedusaModel.

Functions

convert


convert(model, mode)

Main conversion function to turn a base model into a speculative decoding model.

Parameters:
  • model (Module) – The base model to be used.

  • mode (_ModeDescriptor | str | List[_ModeDescriptor | str] | List[Tuple[str, Dict[str, Any]]]) –

    A mode, or a list of modes, given as strings or mode descriptors, or a list of (mode, config) tuples, specifying the desired mode(s) and their configurations for the conversion. Each mode prepares the model for a particular model optimization algorithm. The following modes are available:

    • "medusa": The model is converted into a Medusa model with added Medusa heads. The mode’s config is described in MedusaConfig.

    If mode is given as (mode, config) pairs, the first element of each pair names the mode and the second is a dictionary with that mode's configuration.
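    The accepted forms of mode can be read as shorthands for one canonical form: a list of (mode name, config) pairs. The sketch below illustrates that reading under this assumption. normalize_mode and ModeSpec are hypothetical names, not part of modelopt, and mode descriptor objects (_ModeDescriptor) are left out for brevity.

```python
from typing import Any, Dict, List, Tuple, Union

# Hypothetical stand-in for one mode entry: a bare mode name or a (name, config) pair.
ModeSpec = Union[str, Tuple[str, Dict[str, Any]]]


def normalize_mode(mode: Union[ModeSpec, List[ModeSpec]]) -> List[Tuple[str, Dict[str, Any]]]:
    """Hypothetical helper: convert any accepted `mode` form into (name, config) pairs.

    A bare mode name gets an empty config, read here as "use the mode's defaults".
    """
    # A single string or a single (name, config) tuple is wrapped in a list first.
    if isinstance(mode, (str, tuple)):
        mode = [mode]
    normalized: List[Tuple[str, Dict[str, Any]]] = []
    for entry in mode:
        if isinstance(entry, str):
            normalized.append((entry, {}))
        else:
            name, config = entry
            # Copy the config so later changes to the caller's dict do not leak in.
            normalized.append((name, dict(config)))
    return normalized


print(normalize_mode("medusa"))  # [('medusa', {})]
```

    Under this reading, convert(model, "medusa") and convert(model, [("medusa", {})]) request the same mode, and the tuple form is what you would use to pass non-default MedusaConfig fields.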

Returns:

An instance of MedusaModel (modelopt.torch.speculative.MedusaModel) or one of its subclasses.

Return type:

Module