utils

Utils for speculative decoding.

Classes

AcceptanceRateValidation

Base acceptance rate (AR) validation class.

ResBlock

A Residual Block module.

Tree

A tree structure for speculative decoding that defines valid token prediction paths.

TreeNode

A node in the speculative decoding tree structure.

Functions

calibrate_frequent_vocab

Given a calibration text, find the most frequent vocabulary tokens and return the mapping.

get_default_attention_mask_and_position_ids

Compute default attention_mask and position_ids given input_ids.

class AcceptanceRateValidation

Bases: object

Base acceptance rate (AR) validation class.

This class validates the AR within ModelOpt. self.validate is the main entry point: it validates the AR given a prompt or input_ids. Note: currently only TP (tensor parallelism) is supported.

__init__(model, tokenizer)

Init function to take in the model and tokenizer.

check_data_consistancy_across_ranks(data, group=None, fail_when_mismatch=True)

This function checks the data consistency across all ranks in the group.

Rank 0's data is used as the golden set and broadcast to all ranks. Each rank then compares its own data against it and throws an error if they differ.

check_draft(ground_truth, input_ids, draft_tokens)

This function checks if the draft tokens should be accepted (same as ground truth).

Parameters:
  • ground_truth – the ground truth token ids

  • input_ids – the input token ids

  • draft_tokens – the draft tokens

Returns:

the updated input token ids

Return type:

input_ids
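The acceptance check can be pictured with a small sketch. It assumes token ids are plain Python lists, that ground_truth begins with input_ids, and that acceptance means taking the longest prefix of draft_tokens that matches the ground truth. The helper name accept_draft is hypothetical, and the actual check_draft may behave differently (for example, it operates on tensors).

```python
def accept_draft(ground_truth, input_ids, draft_tokens):
    """Append the longest prefix of draft_tokens that matches ground_truth.

    Hypothetical helper: all arguments are plain lists of token ids, and
    ground_truth is assumed to start with input_ids.
    """
    start = len(input_ids)
    accepted = []
    for offset, token in enumerate(draft_tokens):
        position = start + offset
        # Stop at the first mismatch or when the ground truth runs out.
        if position >= len(ground_truth) or ground_truth[position] != token:
            break
        accepted.append(token)
    return input_ids + accepted


# Two of the three draft tokens match the ground truth, so two are accepted.
updated = accept_draft([1, 2, 3, 4, 5, 6], [1, 2, 3], [4, 5, 9])
```
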

get_ground_truth(input_ids, osl)

This function returns ground truth token ids from the base model.

This function will be implemented in the plugins.

Parameters:
  • input_ids (Tensor) – the token ids of the input

  • osl (int) – output sequence length

tokenize(prompt)

Apply chat template to the prompt and get input_ids.

validate(osl, prompt=None, input_ids=None, ground_truth=None, steps=1, tree_paths=None)

This function validates the AR of the model given the input sequence.

class ResBlock

Bases: Module

A Residual Block module.

This module performs a linear transformation followed by a SiLU activation, and then adds the result to the original input, creating a residual connection.

__init__(hidden_size, bias=True)

Init function of ResBlock.

Parameters:
  • hidden_size (int) – The size of the hidden layers in the block.

  • bias (bool) – Whether the linear transformation includes a bias term.

forward(x)

Forward pass of the ResBlock.

Parameters:

x (Tensor) – Input tensor.

Returns:

Output after the residual connection and activation.

Return type:

Tensor
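The arithmetic described for ResBlock, output = x + SiLU(W x + b), can be written out without any framework. The following is a minimal sketch for a single input vector; the names silu and res_block_forward are hypothetical, and the real module operates on tensors with learned weights.

```python
import math


def silu(v):
    """SiLU (swish) activation: v * sigmoid(v)."""
    return v / (1.0 + math.exp(-v))


def res_block_forward(x, weight, bias):
    """Compute x + SiLU(weight @ x + bias) for a single vector x.

    weight is a hidden_size x hidden_size list of rows and bias has
    hidden_size entries (illustrative values, not learned parameters).
    """
    linear = [
        sum(w * xi for w, xi in zip(row, x)) + b
        for row, b in zip(weight, bias)
    ]
    # Residual connection: add the activated projection back onto the input.
    return [xi + silu(h) for xi, h in zip(x, linear)]


# With zero weights and bias, SiLU(0) = 0, so the block returns its input.
out = res_block_forward([1.0, -2.0], [[0.0, 0.0], [0.0, 0.0]], [0.0, 0.0])
```

Because the input is added back unchanged, the block starts out close to an identity mapping when the linear layer's output is small.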

class Tree

Bases: object

A tree structure for speculative decoding that defines valid token prediction paths.

This class implements a tree-based structure used in speculative decoding to represent multiple possible token prediction paths. The tree is constructed from a list of paths, where each path is a sequence of token positions.

__init__(tree_paths)

Initialize a Tree.

Parameters:

tree_paths (list[list[int]]) – a list of tree paths

create_attention_mask()

Create the attention mask for the tree.

This function constructs the attention mask for the tree based on the tree structure. It ensures that each token can only attend to its valid predecessors according to the tree.

create_tree(tree_paths)

Create the tree structure from the list of tree paths.

This function builds the tree by iterating through each path in the tree_paths list. For each path, it traverses the tree, creating nodes and updating the number of children at each level.
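The two steps described above, building a tree from paths and deriving an attention mask from it, can be sketched as follows. This is an illustration under stated assumptions, not the library's implementation: Node, build_tree, and build_attention_mask are hypothetical names, nodes are numbered breadth-first, and the mask uses True to mean "may attend".

```python
from collections import deque


class Node:
    """Hypothetical stand-in for TreeNode: a value plus a dict of children."""

    def __init__(self, value, children=None):
        self.value = value
        self.children = children if children is not None else {}


def build_tree(tree_paths):
    """Insert every path below a synthetic root, creating nodes on demand."""
    root = Node(-1)  # the root value -1 is an arbitrary placeholder
    for path in tree_paths:
        node = root
        for value in path:
            if value not in node.children:
                node.children[value] = Node(value)
            node = node.children[value]
    return root


def build_attention_mask(root):
    """Return (order, mask); mask[i][j] is True if node i may attend to node j.

    Nodes are numbered breadth-first (root excluded) and identified by their
    path from the root. A node attends to itself and its ancestors only.
    """
    order = []
    queue = deque([((), root)])
    while queue:
        prefix, node = queue.popleft()
        for value, child in sorted(node.children.items()):
            path = prefix + (value,)
            order.append(path)
            queue.append((path, child))
    index = {path: i for i, path in enumerate(order)}
    mask = [[False] * len(order) for _ in order]
    for path, i in index.items():
        for depth in range(1, len(path) + 1):
            mask[i][index[path[:depth]]] = True
    return order, mask


order, mask = build_attention_mask(build_tree([[0], [1], [0, 0]]))
```

In this example the node reached by path [0, 0] attends to itself and to its parent [0], but not to its sibling branch [1].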

class TreeNode

Bases: object

A node in the speculative decoding tree structure.

Each node represents a token position in the sequence and maintains a dictionary of child nodes.

__init__(value, children=None)

Initialize a TreeNode.

Parameters:
  • value (int) – the value of the node

  • children (dict) – a dictionary of children nodes

calibrate_frequent_vocab(tokenizer, text, target_vocab_size, output_file=None)

Given a calibration text, find the most frequent vocabulary tokens and return the mapping.
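The idea of selecting a frequent sub-vocabulary can be sketched with a token counter. This page does not specify the format of the returned mapping, so the sketch assumes a dictionary from compact indices to original token ids; frequent_vocab_mapping is a hypothetical helper that takes already-tokenized ids rather than a tokenizer and text.

```python
from collections import Counter


def frequent_vocab_mapping(token_ids, target_vocab_size):
    """Map compact indices 0..k-1 to the k most frequent token ids.

    Hypothetical helper: token_ids is a flat list of ids produced by any
    tokenizer, and the mapping format is an assumption for illustration.
    """
    counts = Counter(token_ids)
    most_common = [tok for tok, _ in counts.most_common(target_vocab_size)]
    # Sort so the compact vocabulary preserves the original id ordering.
    return dict(enumerate(sorted(most_common)))


# Token 7 appears four times and token 5 three times, so both are kept.
mapping = frequent_vocab_mapping([5, 5, 5, 9, 9, 2, 7, 7, 7, 7], 2)
```
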

get_default_attention_mask_and_position_ids(input_ids)

Compute default attention_mask and position_ids given input_ids.

Parameters:

input_ids (Tensor)
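As a rough illustration, the sketch below assumes that "default" means a causal mask and positions counted from zero, for one sequence given as a Python list. The helper name default_mask_and_positions is hypothetical; the real function works on tensors and may use a different shape or mask polarity (for example, True meaning "masked out").

```python
def default_mask_and_positions(input_ids):
    """Return (attention_mask, position_ids) for one sequence of token ids.

    Assumed convention: attention_mask[i][j] is True when position i may
    attend to position j, which for a causal mask means j <= i.
    """
    seq_len = len(input_ids)
    position_ids = list(range(seq_len))
    attention_mask = [[j <= i for j in range(seq_len)] for i in range(seq_len)]
    return attention_mask, position_ids


attention_mask, position_ids = default_mask_and_positions([10, 11, 12])
```
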