utils
Utils for speculative decoding.
Classes
AcceptanceRateValidation | Base acceptance rate (AR) validation class.
ResBlock | A Residual Block module.
Functions
get_default_attention_mask_and_position_ids | Compute the default attention_mask and position_ids given input_ids.
right_padding | Pad zeros to the right so that the length of padded_input_ids is a multiple of tp.
tree_decode | Decode tokens using the tree.
- class AcceptanceRateValidation
Bases: object
Base acceptance rate (AR) validation class.
This class validates the AR within ModelOpt. self.validate is the main entry point: it computes the AR given either a prompt or input_ids. Note: currently only tensor parallelism (TP) is supported.
- __init__(model, tokenizer, tp)
Initialize the validator with the model, the tokenizer, and tp.
- check_draft(ground_truth, input_ids, draft_tokens, tree=None)
This function checks which draft tokens should be accepted, i.e., which ones match the ground truth.
If tree is None, eager (non-tree) mode is used.
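In eager mode, acceptance can be read as a longest-matching-prefix check: draft tokens are accepted one at a time until the first mismatch with the ground truth. The sketch below shows that idea on plain Python lists. `count_accepted` is a hypothetical helper, not part of this module, and the real check_draft may differ in how it handles tensors, batching, offsets into input_ids, and tree mode.

```python
from typing import Sequence


def count_accepted(ground_truth: Sequence[int], draft_tokens: Sequence[int]) -> int:
    """Hypothetical helper: number of leading draft tokens that match the ground truth."""
    accepted = 0
    for gt, draft in zip(ground_truth, draft_tokens):
        if gt != draft:
            break  # stop at the first mismatch; later draft tokens are discarded
        accepted += 1
    return accepted


# The ground truth continues with 5, 7, 9; the draft gets the first two right.
print(count_accepted([5, 7, 9], [5, 7, 2]))  # 2
```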
- get_ground_truth(input_ids, osl)
This function returns ground-truth token ids from the base model.
It is not implemented in this class; the plugins provide the implementation.
Args:
input_ids (torch.Tensor): the token ids of the input.
attention_mask (torch.Tensor): attention mask of the input.
osl (int): output sequence length.
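Because get_ground_truth is left to the plugins, here is only a minimal sketch of one plausible approach: greedy (argmax) decoding of osl tokens from the base model. It assumes `model(input_ids)` returns logits of shape [batch, seq_len, vocab] directly; real Hugging Face or Megatron models usually return an output object and may need an attention mask. `greedy_ground_truth` and `ToyModel` are hypothetical names used only for illustration.

```python
import torch


def greedy_ground_truth(model, input_ids: torch.Tensor, osl: int) -> torch.Tensor:
    """Greedily decode `osl` tokens and return only the newly generated ids."""
    with torch.no_grad():
        for _ in range(osl):
            logits = model(input_ids)  # assumed shape: [batch, seq_len, vocab]
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            input_ids = torch.cat([input_ids, next_token], dim=-1)
    return input_ids[:, -osl:]


class ToyModel(torch.nn.Module):
    """Hypothetical stand-in model: always predicts (token + 1) % vocab_size."""

    def __init__(self, vocab_size: int = 10):
        super().__init__()
        self.vocab_size = vocab_size

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        next_ids = (input_ids + 1) % self.vocab_size
        return torch.nn.functional.one_hot(next_ids, self.vocab_size).float()


print(greedy_ground_truth(ToyModel(), torch.tensor([[3, 4]]), osl=3))
# tensor([[5, 6, 7]])
```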
- tokenize(prompt)
Apply the chat template to the prompt and return the resulting input_ids.
- validate(osl, prompt=None, input_ids=None, ground_truth=None, tree=None, steps=1)
This function validates the AR of the model given the input sequence.
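A common way to summarize speculative decoding as a single number is the average number of tokens produced per step, where each step yields the accepted draft tokens plus one token verified by the base model. The sketch below assumes that definition and eager (non-tree) drafting; whether validate reports exactly this quantity is an assumption, and `acceptance_rate` and `draft_fn` are hypothetical names.

```python
from typing import Callable, List, Sequence


def acceptance_rate(
    ground_truth: Sequence[int],
    draft_fn: Callable[[Sequence[int]], List[int]],
) -> float:
    """Average number of tokens emitted per step over the whole ground truth."""
    pos, steps = 0, 0
    while pos < len(ground_truth):
        draft = draft_fn(ground_truth[:pos])  # guesses for positions pos, pos+1, ...
        accepted = 0
        for gt, tok in zip(ground_truth[pos:], draft):
            if gt != tok:
                break  # the first mismatch ends acceptance for this step
            accepted += 1
        # Each step also yields one token from the base model,
        # capped so we never run past the end of the ground truth.
        pos += min(accepted + 1, len(ground_truth) - pos)
        steps += 1
    return pos / steps if steps else 0.0


# A drafter that always guesses the next two consecutive integers correctly.
perfect = lambda prefix: [prefix[-1] + 1, prefix[-1] + 2] if prefix else [1, 2]
print(acceptance_rate([1, 2, 3, 4, 5, 6], perfect))  # 3.0
```

With a drafter that is always wrong, every step yields only the base-model token and the result is 1.0, the non-speculative baseline.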
- class ResBlock
Bases: Module
A Residual Block module.
This module performs a linear transformation followed by a SiLU activation, and then adds the result to the original input, creating a residual connection.
- Parameters:
hidden_size (int) – The size of the hidden layers in the block.
- __init__(hidden_size, bias=True)
Initialize the ResBlock.
Args:
hidden_size (int): The size of the hidden layers in the block.
- forward(x)
Forward pass of the ResBlock.
- Parameters:
x (torch.Tensor) – Input tensor.
- Returns:
Output after the residual connection and activation.
- Return type:
torch.Tensor
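The description above amounts to y = x + SiLU(Linear(x)). The following is a minimal PyTorch sketch of that structure. `ResBlockSketch` is a hypothetical name, passing `bias` through to `nn.Linear` is an assumption, and the actual ResBlock may differ in details such as weight initialization.

```python
import torch
from torch import nn


class ResBlockSketch(nn.Module):
    """Illustrative residual block: y = x + SiLU(Linear(x))."""

    def __init__(self, hidden_size: int, bias: bool = True):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size, bias=bias)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Residual connection around the linear + SiLU branch.
        return x + self.act(self.linear(x))


block = ResBlockSketch(hidden_size=8)
# With zeroed weight and bias the branch outputs SiLU(0) = 0, so the block acts as identity.
nn.init.zeros_(block.linear.weight)
nn.init.zeros_(block.linear.bias)
x = torch.randn(2, 8)
print(torch.equal(block(x), x))  # True
```

The residual form means the block can start out close to an identity mapping and only learn a correction on top of the input.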
- get_default_attention_mask_and_position_ids(input_ids)
Compute the default attention_mask and position_ids given input_ids.
- Parameters:
input_ids (Tensor) – the token ids of the input.
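A "default" mask and position ids for a plain left-to-right sequence are typically a causal mask plus positions 0..seq_len-1. The sketch below assumes a [batch, seq_len] input, a [batch, 1, seq_len, seq_len] boolean mask, and the convention that True marks positions a token must not attend to (as in Megatron-style attention). These shapes and conventions are assumptions, and `default_mask_and_position_ids` is a hypothetical name.

```python
import torch


def default_mask_and_position_ids(input_ids: torch.Tensor):
    """Causal mask and 0..seq_len-1 position ids for a [batch, seq_len] input."""
    batch_size, seq_len = input_ids.shape
    # Assumed convention: True = masked out (future tokens above the diagonal).
    attention_mask = torch.triu(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=input_ids.device), diagonal=1
    )
    attention_mask = attention_mask.view(1, 1, seq_len, seq_len).expand(batch_size, 1, -1, -1)
    position_ids = (
        torch.arange(seq_len, device=input_ids.device).unsqueeze(0).expand(batch_size, -1)
    )
    return attention_mask, position_ids


mask, pos = default_mask_and_position_ids(torch.tensor([[11, 12, 13]]))
print(pos)  # tensor([[0, 1, 2]])
```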
- right_padding(input_ids, tp, hidden_states=None)
Pad zeros to the right so that the length of padded_input_ids is a multiple of tp.
- Parameters:
input_ids (Tensor) –
tp (int) –
hidden_states (Tensor) –
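The padding length is the smallest non-negative number that makes the sequence length divisible by tp, i.e. `(-seq_len) % tp`. The sketch below applies this to input_ids only, assuming the sequence is the last dimension; how right_padding treats hidden_states is not covered here, and `pad_to_multiple` is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def pad_to_multiple(input_ids: torch.Tensor, tp: int) -> torch.Tensor:
    """Right-pad the last (sequence) dimension with zeros up to a multiple of tp."""
    seq_len = input_ids.shape[-1]
    pad_len = (-seq_len) % tp  # 0 when seq_len is already a multiple of tp
    return F.pad(input_ids, (0, pad_len), value=0)


print(pad_to_multiple(torch.tensor([[1, 2, 3, 4, 5]]), tp=4))
# tensor([[1, 2, 3, 4, 5, 0, 0, 0]])
```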
- tree_decode(draft_logits, tree)
Decode tokens using the tree.
- Parameters:
draft_logits (List[torch.Tensor]) – a list of logits. Each element holds the logits for one future position.
tree (List[List[int]]) – a tree for decoding. Each sublist is a branch from the root, where each number represents the topk index.
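Given the tree format described above, each branch [i0, i1, ...] can be turned into a candidate token sequence by taking the i0-th most likely token at depth 0, the i1-th at depth 1, and so on. The sketch below assumes each draft_logits entry is a 1-D [vocab] tensor for a single sequence; `tree_candidates` is a hypothetical name, and the real tree_decode may return a different structure.

```python
from typing import List

import torch


def tree_candidates(draft_logits: List[torch.Tensor], tree: List[List[int]]) -> List[List[int]]:
    """Map each branch of top-k indices to concrete draft token ids.

    draft_logits[d] holds the vocabulary logits for future position d (shape [vocab]).
    """
    max_k = max(max(branch) for branch in tree) + 1
    # Top-k token ids per depth, ordered from most to least likely.
    topk_ids = [torch.topk(logits, k=max_k).indices.tolist() for logits in draft_logits]
    return [[topk_ids[depth][k] for depth, k in enumerate(branch)] for branch in tree]


logits = [
    torch.tensor([0.1, 2.0, 0.5, 1.0, -1.0]),  # depth 0: best token 1, then 3
    torch.tensor([3.0, 0.0, 1.0, 2.0, -2.0]),  # depth 1: best token 0, then 3
]
print(tree_candidates(logits, [[0], [1], [0, 0], [0, 1]]))
# [[1], [3], [1, 0], [1, 3]]
```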