utils

Eagle model utils.

Classes

EagleOfflineDataCollator

Data collator that truncates or pads data for offline training.

OfflineSupervisedDataset

Offline dataset for supervised fine-tuning with pre-dumped hidden states.

Functions

expand_mask

Expands attention_mask from [bsz, seq_len] to [bsz, 1, tgt_seq_len, src_seq_len].

make_causal_mask

Make the causal mask used for uni-directional (autoregressive) self-attention.

class EagleOfflineDataCollator

Bases: object

Data collator that truncates or pads data for offline training.

__init__(train_len)

Initialize with the target sequence length for truncation/padding.
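The truncate-or-pad behavior can be sketched as follows. This is a minimal illustration, not the collator's actual code: the helper name pad_or_truncate and the zero pad value are assumptions, and the real collator may handle additional keys or use different padding.

```python
import torch


def pad_or_truncate(x: torch.Tensor, length: int, pad_value: int = 0) -> torch.Tensor:
    """Hypothetical helper: force the first dimension of x to exactly `length`."""
    if x.shape[0] >= length:
        # Too long (or exact): keep only the first `length` positions.
        return x[:length]
    # Too short: append rows of pad_value with matching trailing dims.
    pad_shape = (length - x.shape[0], *x.shape[1:])
    padding = torch.full(pad_shape, pad_value, dtype=x.dtype, device=x.device)
    return torch.cat([x, padding], dim=0)


train_len = 6  # plays the role of the collator's train_len argument
short_ids = torch.tensor([11, 12, 13, 14])  # shorter than train_len, gets padded
long_ids = torch.arange(10)                 # longer than train_len, gets truncated

# Once every sample has the same length, they can be stacked into a batch.
batch = torch.stack([pad_or_truncate(t, train_len) for t in (short_ids, long_ids)])
print(batch.shape)
```

The same helper works for 2-D tensors such as hidden_states, because only the first (sequence) dimension is changed.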

class OfflineSupervisedDataset

Bases: Dataset

Offline dataset for supervised fine-tuning with pre-dumped hidden states.

This dataset loads data on-the-fly from pre-processed .pt data files generated by the compute_hidden_states_* scripts under examples/speculative_decoding/collect_hidden_states/. Each .pt file contains a dict with the following keys:

  • input_ids: token IDs of shape (seq_len,)

  • hidden_states: base model last hidden states of shape (seq_len, hidden_size)

  • aux_hidden_states: auxiliary hidden states of shape (seq_len, hidden_size)

  • loss_mask (optional): per-token assistant mask of shape (seq_len,), present when the dump was produced with --answer-only-loss.

Parameters:
  • dumped_files (list) – A list of file paths to the dumped .pt files.

  • answer_only_loss (bool) – If True, use the loss_mask stored in each .pt file so that only assistant-produced tokens contribute to the loss; __getitem__ raises ValueError if a file lacks loss_mask. If False (default), use a uniform all-ones mask regardless of what the file stores (backward compatible).

__init__(dumped_files, answer_only_loss=False)

Initialize with a list of .pt file paths.

Parameters:
  • dumped_files (list)

  • answer_only_loss (bool)
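The file layout and the answer_only_loss behavior described above can be illustrated with a small stand-alone sketch. It does not import the library: load_sample is a hypothetical stand-in for what __getitem__ is documented to do, and the tensor sizes, dtypes, and file name are made up for the example.

```python
import os
import tempfile

import torch


def load_sample(path, answer_only_loss=False):
    """Hypothetical stand-in for the documented __getitem__ behavior."""
    data = torch.load(path)
    seq_len = data["input_ids"].shape[0]
    if answer_only_loss:
        # Documented behavior: a missing loss_mask is an error in this mode.
        if "loss_mask" not in data:
            raise ValueError(f"{path} has no loss_mask")
        loss_mask = data["loss_mask"]
    else:
        # Documented default: uniform all-ones mask, stored mask ignored.
        loss_mask = torch.ones(seq_len, dtype=torch.long)
    return {**data, "loss_mask": loss_mask}


seq_len, hidden_size = 5, 8  # illustrative sizes only
sample = {
    "input_ids": torch.arange(seq_len),
    "hidden_states": torch.randn(seq_len, hidden_size),
    "aux_hidden_states": torch.randn(seq_len, hidden_size),
    "loss_mask": torch.tensor([0, 0, 1, 1, 1]),  # optional key
}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "sample.pt")
    torch.save(sample, path)
    uniform = load_sample(path)                        # all-ones mask
    masked = load_sample(path, answer_only_loss=True)  # stored mask

    # A dump without loss_mask is rejected when answer_only_loss=True.
    bare_path = os.path.join(tmp, "bare.pt")
    torch.save({k: v for k, v in sample.items() if k != "loss_mask"}, bare_path)
    try:
        load_sample(bare_path, answer_only_loss=True)
        missing_mask_rejected = False
    except ValueError:
        missing_mask_rejected = True
```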

expand_mask(mask, dtype, tgt_len=None)

Expands attention_mask from [bsz, seq_len] to [bsz, 1, tgt_seq_len, src_seq_len].

Parameters:
  • mask (Tensor)

  • dtype (dtype)

  • tgt_len (int | None)
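The shape change can be sketched as below. The function sketch_expand_mask is a hypothetical re-implementation for illustration only. It assumes the common Hugging Face-style convention in which the expanded mask is additive (0 for visible positions, the dtype's minimum value for padded ones); the library's own expand_mask may differ in detail.

```python
import torch


def sketch_expand_mask(mask, dtype, tgt_len=None):
    """Hypothetical sketch: [bsz, src_len] -> [bsz, 1, tgt_len, src_len]."""
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len
    # Broadcast the per-token mask over a new head axis and the target axis.
    expanded = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    # Assumed convention: 1 (keep) becomes 0.0, 0 (pad) becomes a large negative.
    inverted = 1.0 - expanded
    return inverted.masked_fill(inverted.to(torch.bool), torch.finfo(dtype).min)


attention_mask = torch.tensor([[1, 1, 1, 0]])  # last token is padding
additive = sketch_expand_mask(attention_mask, torch.float32)
print(additive.shape)
```

Under this convention the result is added to the attention scores before the softmax, so padded source positions receive effectively zero weight.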

make_causal_mask(input_ids_shape, dtype, device, past_key_values_length=0)

Make the causal mask used for uni-directional (autoregressive) self-attention.

Parameters:
  • input_ids_shape (Size)

  • dtype (dtype)

  • device (device)

  • past_key_values_length (int)
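A causal mask with the signature above can be sketched as follows. The function sketch_make_causal_mask is a hypothetical re-implementation, assuming the usual additive convention (0 where attention is allowed, the dtype's minimum value for future positions) and an output shape of [bsz, 1, tgt_len, tgt_len + past_key_values_length]; the library's own function may differ.

```python
import torch


def sketch_make_causal_mask(input_ids_shape, dtype, device, past_key_values_length=0):
    """Hypothetical sketch of an additive causal mask."""
    bsz, tgt_len = input_ids_shape
    # Start fully masked, then keep the large negative only strictly above
    # the diagonal, so position i can attend to positions 0..i.
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, dtype=dtype, device=device)
    mask = torch.triu(mask, diagonal=1)
    if past_key_values_length > 0:
        # Cached (past) positions are always visible: prepend zero columns.
        past = torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device)
        mask = torch.cat([past, mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)


causal = sketch_make_causal_mask(
    torch.Size([2, 3]), torch.float32, torch.device("cpu"), past_key_values_length=2
)
print(causal.shape)
```

In this example each of the 3 new tokens sees both cached positions plus itself and any earlier new tokens; later new tokens stay masked.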