Speculative Decoding

Introduction

ModelOpt’s Speculative Decoding module (modelopt.torch.speculative) enables your model to generate multiple tokens in each generation step. This can reduce the latency of your model and speed up inference.

Below are the speculative decoding algorithms supported by ModelOpt:

  - Medusa

Follow the steps described below to obtain a model with Medusa speculative decoding using ModelOpt’s Speculative Decoding module modelopt.torch.speculative:

  1. Convert your model via mtsp.convert: add Medusa heads to your model and enable Medusa speculative decoding in inference.

  2. Fine-tune the Medusa heads: update the compute_loss function of your trainer and fine-tune the Medusa heads. The base model can either be frozen or fine-tuned together with the Medusa heads.

  3. Checkpoint and re-load: save the model via mto.save and restore it via mto.restore.

Convert

You can convert your model to a speculative decoding model using mtsp.convert().

Example usage:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import modelopt.torch.speculative as mtsp

# User-defined model
model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
tokenizer.pad_token_id = tokenizer.eos_token_id

# Configure and convert to medusa
medusa_config = {
    "medusa_num_heads": 4,
    "medusa_num_layers": 1,
}
medusa_model = mtsp.convert(model, [("medusa", medusa_config)])
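For intuition about what the conversion attaches, here is a minimal pure-Python sketch of a single Medusa head, modeled on the head design from the original Medusa paper: a stack of residual layers h + SiLU(W h) followed by a projection to vocabulary logits. This is not ModelOpt's implementation. The helper names are hypothetical, and reading medusa_num_layers as the number of residual layers per head (and medusa_num_heads as the number of heads) is an assumption.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(w: Matrix, x: Vector) -> Vector:
    """Dense matrix-vector product."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def silu(z: float) -> float:
    """SiLU activation: z * sigmoid(z)."""
    return z / (1.0 + math.exp(-z))


def res_block(w: Matrix, h: Vector) -> Vector:
    """One residual layer: h + SiLU(W h)."""
    return [hi + silu(zi) for hi, zi in zip(h, matvec(w, h))]


def medusa_head(block_weights: List[Matrix], lm_weight: Matrix, h: Vector) -> Vector:
    """Apply len(block_weights) residual layers, then project to vocab logits."""
    for w in block_weights:
        h = res_block(w, h)
    return matvec(lm_weight, h)


# Hypothetical sizes: hidden size 2, vocabulary size 3, 1 residual layer per head.
hidden = [1.0, -2.0]
zero_block = [[0.0, 0.0], [0.0, 0.0]]
lm_weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

# 4 heads (cf. medusa_num_heads=4); in practice each head has its own trained weights.
heads = [([zero_block], lm_weight) for _ in range(4)]
logits = [medusa_head(blocks, lm, hidden) for blocks, lm in heads]
print(logits[0])  # [1.0, -2.0, -1.0]
```

With zero residual weights each layer is an identity (SiLU(0) = 0), so a freshly initialized head behaves like a plain projection of the base model's hidden state and only drifts from it as fine-tuning proceeds.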

Fine-tune Medusa model and save

After converting to a Medusa model, you need to fine-tune the Medusa head:

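from transformers import Trainer, TrainingArguments
import modelopt.torch.opt as mto

# training_args, data_module, and checkpoint are user-defined, for example:
training_args = TrainingArguments(output_dir="<path to the output directory>")
data_module = {"train_dataset": train_dataset}  # your tokenized training data
checkpoint = None  # or a checkpoint path to resume training from

trainer = Trainer(model=medusa_model, tokenizer=tokenizer, args=training_args, **data_module)
trainer._move_model_to_device(medusa_model, trainer.args.device)
# Replace the trainer's loss with the Medusa loss; medusa_only_heads=True freezes the base model
mtsp.plugins.transformers.replace_medusa_compute_loss(trainer, medusa_only_heads=True)

# Save the ModelOpt (Medusa) state together with the HuggingFace checkpoint
mto.enable_huggingface_checkpointing()

trainer.train(resume_from_checkpoint=checkpoint)
trainer.save_state()
trainer.save_model("<path to the output directory>")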

Note

If medusa_only_heads is set to True, the base model is frozen and only the Medusa heads are fine-tuned. Set it to False to fine-tune the base model together with the Medusa heads.
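To make the loss replacement in step 2 more concrete, the sketch below computes a Medusa-style objective in pure Python, following the formulation in the Medusa paper: the k-th head at position t is trained with cross-entropy against token t + k + 1 (the base LM head predicts token t + 1), and the per-head losses are summed with a decaying weight. The function names, the 0.8 decay, and the way the base-model loss is added when the base model is not frozen are assumptions for illustration; replace_medusa_compute_loss may weight or combine the terms differently.

```python
import math
from typing import List, Optional, Sequence

Logits = Sequence[float]


def cross_entropy(logits: Logits, target: int) -> float:
    """Numerically stable -log softmax(logits)[target]."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[target]


def mean_ce(logits_per_pos: Sequence[Logits], tokens: Sequence[int], shift: int) -> float:
    """Average cross-entropy of predicting tokens[t + shift] from position t."""
    terms = [cross_entropy(logits_per_pos[t], tokens[t + shift])
             for t in range(len(tokens) - shift)]
    return sum(terms) / len(terms) if terms else 0.0


def medusa_loss(head_logits: List[Sequence[Logits]], tokens: Sequence[int],
                decay: float = 0.8,
                base_logits: Optional[Sequence[Logits]] = None) -> float:
    """head_logits[i][t]: logits of Medusa head k = i + 1 at position t (hypothetical layout)."""
    total = 0.0
    for i, logits_per_pos in enumerate(head_logits):
        k = i + 1
        # Head k at position t predicts token t + k + 1; later heads get smaller weights.
        total += decay ** k * mean_ce(logits_per_pos, tokens, shift=k + 1)
    if base_logits is not None:
        # Base model is fine-tuned too (medusa_only_heads=False): add its next-token loss.
        total += mean_ce(base_logits, tokens, shift=1)
    return total


# Uniform logits over a 4-token vocabulary cost log(4) per prediction.
tokens = [0, 1, 2, 3]
uniform = [[0.0] * 4 for _ in tokens]
print(medusa_loss([uniform, uniform], tokens))  # approximately (0.8 + 0.64) * log(4)
```

Passing base_logits mimics training the base model jointly with the heads; leaving it out corresponds to the frozen-base setting, where only the head terms produce gradients.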

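To restore the saved speculative model:

import modelopt.torch.opt as mto
from transformers import AutoModelForCausalLM

# With HuggingFace checkpointing enabled, from_pretrained() also restores the
# ModelOpt (Medusa) state that was saved alongside the weights
mto.enable_huggingface_checkpointing()
model = AutoModelForCausalLM.from_pretrained("<path to the output directory>")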

Speculative Decoding Concepts

Below, we provide an overview of ModelOpt’s speculative decoding feature and its basic concepts and terminology.

Speculative decoding

The standard way of generating text from a language model is autoregressive decoding: one token is generated per step and appended to the input context for the next step. Generating K tokens therefore takes K serial runs of the model, which makes inference from large autoregressive models such as Transformers slow and expensive. Various speculative decoding algorithms have been proposed to accelerate text generation, especially in latency-critical applications.

Typically, a short draft of K tokens is generated by a faster model, called the draft model. The draft can be produced either by a model that predicts all K tokens in parallel or by calling an autoregressive draft model K times. Then a larger and more powerful model, called the target model, scores the whole draft in a single forward pass. Finally, a sampling scheme decides which draft tokens the target model accepts, recovering the target model's output distribution in the process.
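The draft-then-verify loop can be sketched in pure Python with hypothetical toy draft and target models. For simplicity this sketch uses greedy verification (a draft token is accepted only if it equals the target model's argmax) instead of a sampling scheme; under greedy decoding this reproduces the target model's output exactly. The verification helper loops over prefixes for readability, whereas a real target model scores all of them in one forward pass, which is where the speedup comes from.

```python
from typing import Callable, List, Tuple

Model = Callable[[List[int]], int]  # context -> greedy next token


def target_predictions(target: Model, context: List[int], draft: List[int]) -> List[int]:
    """Target's greedy token after context + draft[:i], for i = 0..len(draft).

    A real target model computes all len(draft) + 1 predictions in ONE forward
    pass; the loop here is only for readability.
    """
    return [target(context + draft[:i]) for i in range(len(draft) + 1)]


def speculative_decode(target: Model, draft_model: Model, prompt: List[int],
                       num_tokens: int, k: int = 4) -> Tuple[List[int], int]:
    """Greedy draft-and-verify decoding; returns (new tokens, target passes)."""
    context = list(prompt)
    target_passes = 0
    while len(context) - len(prompt) < num_tokens:
        # 1. The cheap draft model proposes k tokens autoregressively.
        draft: List[int] = []
        for _ in range(k):
            draft.append(draft_model(context + draft))
        # 2. The target model scores the whole draft at once.
        preds = target_predictions(target, context, draft)
        target_passes += 1
        # 3. Accept the longest prefix the target agrees with, then append the
        #    target's own token (a correction, or a bonus token if all k matched).
        n = 0
        while n < k and draft[n] == preds[n]:
            n += 1
        context += draft[:n] + [preds[n]]
    return context[len(prompt):len(prompt) + num_tokens], target_passes


def toy_target(ctx: List[int]) -> int:
    return (ctx[-1] + 1) % 10


def toy_draft(ctx: List[int]) -> int:
    # Agrees with the target except right after token 3.
    return 0 if ctx[-1] == 3 else (ctx[-1] + 1) % 10


tokens, passes = speculative_decode(toy_target, toy_draft, [0], num_tokens=8)
print(tokens, passes)  # [1, 2, 3, 4, 5, 6, 7, 8] 2
```

Plain autoregressive decoding would need 8 target runs for the same 8 tokens; here the target runs twice because most drafted tokens are accepted, and each round always adds at least one target-approved token.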

Medusa algorithm

There are many ways to implement speculative decoding. A popular approach is Medusa: instead of using a separate draft model, it adds a few extra decoding heads to the target model that predict multiple future tokens simultaneously. During generation, each head proposes several likely tokens for its position. These options are combined into candidate sequences and processed in a single forward pass using a tree-based attention mechanism. Finally, a typical acceptance scheme picks the longest plausible prefix from the candidates for further decoding. Because the heads sit on top of the target model itself, no separate draft model has to be trained or served, and every accepted token is checked by the target model. When candidates are accepted only if they match the target model's greedy choice, the output is identical to greedy decoding with the target model; typical acceptance relaxes this exact match in exchange for longer accepted prefixes.
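One Medusa decoding step can be sketched in pure Python as follows. The candidate tree is the Cartesian product of each head's top proposals, and each candidate is checked with the typical acceptance criterion described in the Medusa paper: accept a token if its target probability exceeds min(epsilon, delta * exp(-H(p))), where H is the entropy of the target distribution. The toy target, the helper names, and the epsilon and delta values are hypothetical, and the loop over candidates replaces the single tree-attention forward pass a real implementation uses.

```python
import itertools
import math
from typing import Callable, Dict, List

Dist = Dict[int, float]  # token -> probability
Target = Callable[[List[int]], Dist]  # context -> next-token distribution


def entropy(p: Dist) -> float:
    return -sum(q * math.log(q) for q in p.values() if q > 0)


def typical_accept(p: Dist, token: int, epsilon: float, delta: float) -> bool:
    """Accept if p(token) > min(epsilon, delta * exp(-H(p)))."""
    return p.get(token, 0.0) > min(epsilon, delta * math.exp(-entropy(p)))


def medusa_step(target: Target, context: List[int], head_topk: List[List[int]],
                epsilon: float = 0.3, delta: float = 0.09) -> List[int]:
    """One decoding step; returns the newly accepted tokens (at least one).

    head_topk[i] holds the top candidates of Medusa head i + 1, i.e. guesses for
    the token i + 1 positions after the base model's next token.
    """
    # The base LM head's greedy token is always kept, so every step makes progress.
    p0 = target(context)
    first = max(p0, key=p0.get)
    best: List[int] = []
    # Each candidate is one root-to-leaf path in the tree of head proposals.
    for candidate in itertools.product(*head_topk):
        accepted: List[int] = []
        for tok in candidate:
            p = target(context + [first] + accepted)
            if not typical_accept(p, tok, epsilon, delta):
                break
            accepted.append(tok)
        if len(accepted) > len(best):
            best = accepted
    return [first] + best


def toy_target(ctx: List[int]) -> Dist:
    # Hypothetical target: usually last + 1, sometimes last + 2.
    return {(ctx[-1] + 1) % 10: 0.9, (ctx[-1] + 2) % 10: 0.1}


print(medusa_step(toy_target, [0], [[2, 7], [3, 4]]))  # [1, 2, 3]
```

In this example the candidate (2, 4) is also fully accepted even though 4 is not the target's argmax after 2, which shows how typical acceptance trades exact distribution matching for longer accepted prefixes. Candidates starting with 7 are rejected immediately because the target assigns 7 zero probability.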