utils

Utils for speculative decoding.

Classes

AcceptanceRateValidation

Base acceptance rate (AR) validation class.

ResBlock

A Residual Block module.

Functions

get_default_attention_mask_and_position_ids

Compute default attention_mask and position_ids given input_ids.

right_padding

Pad zeros to the right so that the length of padded_input_ids is a multiple of tp.

tree_decode

Decode tokens using the tree.

class AcceptanceRateValidation

Bases: object

Base acceptance rate (AR) validation class.

This class is used to validate the AR within ModelOpt. self.validate is the main function to validate the AR given a prompt or input_ids. Note: currently it only supports TP.

__init__(model, tokenizer, tp)

Init function to take in the model and tokenizer.

check_draft(ground_truth, input_ids, draft_tokens, tree=None)

This function checks whether the draft tokens should be accepted, i.e. whether they match the ground truth.

If tree is None, it is eager mode.
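
In eager mode, acceptance can be thought of as counting the leading draft tokens that match the ground-truth continuation; once a mismatch occurs, the remaining drafts are rejected. A minimal sketch of that check, assuming 1-D tensors of token ids (the helper name and shapes are illustrative, not the ModelOpt implementation):

    import torch

    def count_accepted_eager(ground_truth: torch.Tensor, draft_tokens: torch.Tensor) -> int:
        """Count leading draft tokens that match the ground truth (eager mode)."""
        matches = (draft_tokens == ground_truth[: draft_tokens.numel()]).int()
        # cumprod turns [1, 1, 0, 1] into [1, 1, 0, 0]: only the leading run of matches counts.
        return int(torch.cumprod(matches, dim=0).sum().item())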

get_ground_truth(input_ids, osl)

This function returns ground truth token ids from the base model.

This function will be implemented in the plugins.

Args:

  • input_ids (torch.Tensor) – the token ids of the input

  • attention_mask (torch.Tensor) – attention mask of the input

  • osl (int) – output sequence length
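
Since the concrete behavior lives in the plugins, the following is only a hedged sketch of what an implementation might do: greedy-decode osl tokens from the base model. It assumes a Hugging Face-style model whose forward call returns an object with a .logits field:

    import torch

    @torch.no_grad()
    def greedy_ground_truth(model, input_ids: torch.Tensor, osl: int) -> torch.Tensor:
        """Greedy-decode osl tokens from the base model (illustrative only)."""
        ids = input_ids
        for _ in range(osl):
            logits = model(input_ids=ids).logits          # assumed HF-style output
            next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            ids = torch.cat([ids, next_id], dim=-1)
        return ids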

tokenize(prompt)

Apply chat template to the prompt and get input_ids.

validate(osl, prompt=None, input_ids=None, ground_truth=None, tree=None, steps=1)

This function validates the AR of the model given the input sequence.
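
A hedged usage sketch; the argument values below are illustrative examples, and the exact form of the return value is not specified here:

    # Illustrative usage only; tp=1, osl=128, and the prompt are example values.
    validator = AcceptanceRateValidation(model, tokenizer, tp=1)
    result = validator.validate(osl=128, prompt="Explain speculative decoding.", steps=1)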

class ResBlock

Bases: Module

A Residual Block module.

This module performs a linear transformation followed by a SiLU activation, and then adds the result to the original input, creating a residual connection.

Parameters:

hidden_size (int) – The size of the hidden layers in the block.

__init__(hidden_size, bias=True)

Init function of ResBlock.

Args:

  • hidden_size (int) – The size of the hidden layers in the block.

forward(x)

Forward pass of the ResBlock.

Parameters:

x (torch.Tensor) – Input tensor.

Returns:

Output after the residual connection and activation.

Return type:

torch.Tensor
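
A minimal sketch of the computation described above (linear transformation, SiLU activation, then a residual add back to the input); it is not necessarily the exact ModelOpt implementation:

    import torch
    import torch.nn as nn

    class ResBlockSketch(nn.Module):
        """Linear -> SiLU, added back to the input (residual connection)."""

        def __init__(self, hidden_size: int, bias: bool = True):
            super().__init__()
            self.linear = nn.Linear(hidden_size, hidden_size, bias=bias)
            self.act = nn.SiLU()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x + self.act(self.linear(x))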

get_default_attention_mask_and_position_ids(input_ids)

Compute default attention_mask and position_ids given input_ids.

Parameters:

input_ids (Tensor) –
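
A common default is a causal (lower-triangular) attention mask together with sequential position ids. The sketch below assumes a 2-D input_ids of shape (batch, seq_len); exact shapes and dtypes in ModelOpt may differ:

    import torch

    def default_mask_and_positions(input_ids: torch.Tensor):
        """Causal attention mask and 0..seq_len-1 position ids (illustrative)."""
        batch, seq_len = input_ids.shape
        attention_mask = torch.tril(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=input_ids.device)
        ).unsqueeze(0).expand(batch, -1, -1)
        position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0).expand(batch, -1)
        return attention_mask, position_ids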

right_padding(input_ids, tp, hidden_states=None)

Pad zeros to the right so that the length of padded_input_ids is a multiple of tp.

Parameters:
  • input_ids (Tensor) –

  • tp (int) –

  • hidden_states (Tensor) –
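
The padding amounts to rounding the sequence length up to the next multiple of tp and zero-filling on the right. A sketch, assuming input_ids of shape (batch, seq_len) and hidden_states of shape (batch, seq_len, hidden); the helper name is illustrative, not the ModelOpt implementation:

    import torch
    import torch.nn.functional as F

    def right_pad_to_multiple(input_ids, tp, hidden_states=None):
        """Zero-pad the sequence dimension on the right to a multiple of tp (illustrative)."""
        pad = (-input_ids.shape[-1]) % tp
        if pad:
            input_ids = F.pad(input_ids, (0, pad), value=0)
            if hidden_states is not None:
                # (0, 0) keeps the hidden dim unchanged; (0, pad) extends the sequence dim.
                hidden_states = F.pad(hidden_states, (0, 0, 0, pad), value=0.0)
        return input_ids, hidden_states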

tree_decode(draft_logits, tree)

Decode tokens using the tree.

Parameters:
  • draft_logits (List[torch.Tensor]) – a list of logits. Each logits tensor represents a future position.

  • tree (List[List[int]]) – a tree for decoding. Each sublist is a branch from the root where the number represents the topk index.
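
To make the tree convention concrete, the sketch below reads each branch as a path of top-k ranks, one per future position. It assumes each entry of draft_logits is a 1-D vocab-sized logits tensor and is illustrative rather than the ModelOpt implementation:

    import torch
    from typing import List

    def tree_decode_sketch(draft_logits: List[torch.Tensor], tree: List[List[int]]) -> List[List[int]]:
        """Pick the draft token along each tree branch (illustrative)."""
        branches = []
        for branch in tree:
            tokens = []
            for depth, rank in enumerate(branch):
                # The rank-th most likely token from the logits at this future position.
                top_ids = torch.topk(draft_logits[depth], k=rank + 1).indices
                tokens.append(int(top_ids[rank].item()))
            branches.append(tokens)
        return branches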