utils

Utilities for EAGLE speculative-decoding models.

Classes

EagleOfflineDataCollator

Data collator that truncates or pads data for offline training.

OfflineSupervisedDataset

Offline dataset for supervised fine-tuning with pre-dumped hidden states.

Functions

expand_mask

Expands attention_mask from [bsz, seq_len] to [bsz, 1, tgt_seq_len, src_seq_len].

make_causal_mask

Make the causal mask used for autoregressive (uni-directional) self-attention.

class EagleOfflineDataCollator

Bases: object

Data collator that truncates or pads data for offline training.

__init__(train_len)

Initialize with the target sequence length for truncation/padding.
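The truncate-or-pad behavior can be illustrated with a minimal sketch. The helper name pad_or_truncate, its pad_value argument, and the choice to pad at the end with zeros are assumptions made for illustration; they are not taken from the library, whose collator may handle padding and batching differently.

```python
import torch


def pad_or_truncate(x: torch.Tensor, train_len: int, pad_value: int = 0) -> torch.Tensor:
    """Hypothetical helper: force the first dimension of `x` to `train_len`."""
    seq_len = x.shape[0]
    if seq_len >= train_len:
        # Sequence is long enough: keep only the first `train_len` positions.
        return x[:train_len]
    # Sequence is too short: append padding rows (assumed: right padding).
    pad_shape = (train_len - seq_len, *x.shape[1:])
    padding = x.new_full(pad_shape, pad_value)
    return torch.cat([x, padding], dim=0)


ids_long = torch.arange(10)   # longer than train_len, gets truncated
ids_short = torch.arange(3)   # shorter than train_len, gets padded
batch = torch.stack([pad_or_truncate(t, 6) for t in (ids_long, ids_short)])
print(batch.shape)
```

Because every sample ends up with the same length, the results can be stacked into a single batch tensor.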

class OfflineSupervisedDataset

Bases: Dataset

Offline dataset for supervised fine-tuning with pre-dumped hidden states.

This dataset loads samples on the fly from pre-processed .pt files generated by examples/speculative_decoding/main.py --mode dump_offline_data. Each .pt file contains a dict with the following keys:

  • input_ids: token IDs of shape (seq_len,)

  • hidden_states: base model last hidden states of shape (seq_len, hidden_size)

  • aux_hidden_states: auxiliary hidden states of shape (seq_len, hidden_size)

  • base_model_input_embeds: input embeddings of shape (seq_len, hidden_size)

Parameters:

dumped_files (list) – A list of file paths to the dumped .pt files.

__init__(dumped_files)

Initialize with a list of .pt file paths.
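As a rough sketch of the on-the-fly loading pattern described above, the following stand-in class reads one .pt file per index. TinyOfflineDataset, the file name sample_0.pt, and the toy sizes are hypothetical; the real class may post-process the loaded dict or return different fields.

```python
import os
import tempfile

import torch
from torch.utils.data import Dataset


class TinyOfflineDataset(Dataset):
    """Hypothetical stand-in for illustration, not the library class."""

    def __init__(self, dumped_files):
        self.dumped_files = dumped_files

    def __len__(self):
        return len(self.dumped_files)

    def __getitem__(self, idx):
        # Load lazily so that only the requested sample is read from disk.
        return torch.load(self.dumped_files[idx])


# Write one fake dump file using the keys and shapes listed above.
seq_len, hidden_size = 5, 8
sample = {
    "input_ids": torch.arange(seq_len),
    "hidden_states": torch.randn(seq_len, hidden_size),
    "aux_hidden_states": torch.randn(seq_len, hidden_size),
    "base_model_input_embeds": torch.randn(seq_len, hidden_size),
}
tmp_dir = tempfile.mkdtemp()
path = os.path.join(tmp_dir, "sample_0.pt")
torch.save(sample, path)

dataset = TinyOfflineDataset([path])
item = dataset[0]
print({k: tuple(v.shape) for k, v in item.items()})
```

Loading inside __getitem__ keeps memory use proportional to the batch size rather than to the number of dumped files.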

expand_mask(mask, dtype, tgt_len=None)

Expands attention_mask from [bsz, seq_len] to [bsz, 1, tgt_seq_len, src_seq_len].

Parameters:
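  • mask (Tensor) – Attention mask of shape [bsz, seq_len].

  • dtype (dtype) – Data type of the expanded mask.

  • tgt_len (int | None) – Target sequence length (tgt_seq_len) of the expanded mask.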
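The shape expansion can be sketched as follows. The function expand_mask_sketch is hypothetical and follows a pattern commonly used in transformer codebases: it assumes tgt_len falls back to the source length when None, and that the result is an additive mask (0 for visible positions, the most negative value of dtype for masked ones). The library function may differ in these details.

```python
import torch


def expand_mask_sketch(mask: torch.Tensor, dtype: torch.dtype, tgt_len=None) -> torch.Tensor:
    """Hypothetical sketch: [bsz, seq_len] -> [bsz, 1, tgt_len, src_len]."""
    bsz, src_len = mask.size()
    # Assumption: the target length defaults to the source length.
    tgt_len = tgt_len if tgt_len is not None else src_len
    # Add two singleton dims, then broadcast across heads and target positions.
    expanded = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    # Assumption: convert the 1/0 keep-mask into an additive mask.
    inverted = 1.0 - expanded
    return inverted.masked_fill(inverted.to(torch.bool), torch.finfo(dtype).min)


mask = torch.tensor([[1, 1, 0], [1, 1, 1]])  # last token of sample 0 is padding
out = expand_mask_sketch(mask, torch.float32)
print(out.shape)
```

Adding such a mask to the attention scores before the softmax drives the weight of padded positions to effectively zero.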

make_causal_mask(input_ids_shape, dtype, device, past_key_values_length=0)

Make the causal mask used for autoregressive (uni-directional) self-attention.

Parameters:
  • input_ids_shape (Size)

  • dtype (dtype)

  • device (device)

  • past_key_values_length (int)
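A causal mask of this kind can be sketched as below. The function causal_mask_sketch is hypothetical; it assumes input_ids_shape is (bsz, tgt_len), that the mask is additive, and that positions covered by past_key_values_length are always visible. The library implementation may differ.

```python
import torch


def causal_mask_sketch(input_ids_shape, dtype, device, past_key_values_length=0):
    """Hypothetical sketch: returns [bsz, 1, tgt_len, tgt_len + past_len]."""
    bsz, tgt_len = input_ids_shape
    min_val = torch.finfo(dtype).min
    # Start fully masked, then zero out the diagonal and everything below it,
    # so position i can attend only to positions j <= i.
    mask = torch.full((tgt_len, tgt_len), min_val, dtype=dtype, device=device)
    mask = torch.triu(mask, diagonal=1)
    if past_key_values_length > 0:
        # Assumption: cached past positions are visible to every query.
        past = torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device)
        mask = torch.cat([past, mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)


m = causal_mask_sketch(torch.Size([2, 3]), torch.float32, torch.device("cpu"),
                       past_key_values_length=2)
print(m.shape)
```

With past_key_values_length=2, the first two columns correspond to cached tokens and the remaining three form the lower-triangular causal pattern.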