utils
Eagle model utils.
Classes
EagleOfflineDataCollator | Data collator that truncates or pads data for offline training.
OfflineSupervisedDataset | Offline dataset for supervised fine-tuning with pre-dumped hidden states.
Functions
expand_mask | Expands attention_mask from [bsz, seq_len] to [bsz, 1, tgt_seq_len, src_seq_len].
make_causal_mask | Make causal mask used for bi-directional self-attention.
- class EagleOfflineDataCollator
Bases: object
Data collator that truncates or pads data for offline training.
- __init__(train_len)
Initialize with the target sequence length for truncation/padding.
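The truncate-or-pad step can be illustrated with a minimal sketch. The helper name `pad_or_truncate` and the zero padding value are assumptions made for illustration; the collator's actual call signature and padding scheme are not documented here.

```python
import torch


def pad_or_truncate(x: torch.Tensor, train_len: int, pad_value: int = 0) -> torch.Tensor:
    """Hypothetical helper: force the first (sequence) dimension of x to train_len."""
    seq_len = x.shape[0]
    if seq_len >= train_len:
        # Too long (or exact): keep only the first train_len positions.
        return x[:train_len]
    # Too short: append rows filled with pad_value, matching dtype and device of x.
    pad_shape = (train_len - seq_len, *x.shape[1:])
    padding = x.new_full(pad_shape, pad_value)
    return torch.cat([x, padding], dim=0)


short_ids = pad_or_truncate(torch.tensor([5, 6, 7]), train_len=5)
long_ids = pad_or_truncate(torch.arange(8), train_len=5)
hidden = pad_or_truncate(torch.ones(3, 4), train_len=5)
```

The same helper works for 1-D token IDs and 2-D hidden states because only the first dimension is resized.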
- class OfflineSupervisedDataset
Bases: Dataset
Offline dataset for supervised fine-tuning with pre-dumped hidden states.
This dataset loads data on-the-fly from pre-processed .pt data files generated by the
compute_hidden_states_* scripts under examples/speculative_decoding/collect_hidden_states/. Each .pt file contains a dict with the following keys:
input_ids: token IDs of shape (seq_len,)
hidden_states: base model last hidden states of shape (seq_len, hidden_size)
aux_hidden_states: auxiliary hidden states of shape (seq_len, hidden_size)
loss_mask (optional): per-token assistant mask of shape (seq_len,), present when the dump was produced with --answer-only-loss.
- Parameters:
dumped_files (list) – A list of file paths to the dumped .pt files.
answer_only_loss (bool) – If True, use the loss_mask stored in each .pt file so that only assistant-produced tokens contribute to the loss. Raises ValueError on __getitem__ if the file lacks loss_mask. If False (default), a uniform all-ones mask is used regardless of what is stored in the file (backward compatible).
- __init__(dumped_files, answer_only_loss=False)
Initialize with a list of .pt file paths.
- Parameters:
answer_only_loss (bool)
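The file layout and the answer_only_loss behavior described above can be sketched as follows. This is not the library's code: `load_sample`, the file name `sample_0.pt`, and the returned dict layout are hypothetical, and the tensors are random stand-ins for a real dump.

```python
import os
import tempfile

import torch


def load_sample(path: str, answer_only_loss: bool = False) -> dict:
    """Hypothetical reader that mirrors the documented loss-mask rules."""
    data = torch.load(path)
    if answer_only_loss:
        if "loss_mask" not in data:
            raise ValueError(f"{path} has no loss_mask entry")
        loss_mask = data["loss_mask"]
    else:
        # Default: every token contributes, whatever the file contains.
        loss_mask = torch.ones_like(data["input_ids"])
    return {**data, "loss_mask": loss_mask}


# Write a stand-in dump with the documented keys and shapes.
seq_len, hidden_size = 4, 8
sample = {
    "input_ids": torch.arange(seq_len),
    "hidden_states": torch.randn(seq_len, hidden_size),
    "aux_hidden_states": torch.randn(seq_len, hidden_size),
    "loss_mask": torch.tensor([0, 0, 1, 1]),
}
tmp_dir = tempfile.mkdtemp()
path = os.path.join(tmp_dir, "sample_0.pt")
torch.save(sample, path)

masked = load_sample(path, answer_only_loss=True)  # uses the stored mask
uniform = load_sample(path)                         # ignores the stored mask
```

With the real class, the equivalent call would presumably be OfflineSupervisedDataset(dumped_files, answer_only_loss=True), with the check happening lazily in __getitem__ as the parameter description states.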
- expand_mask(mask, dtype, tgt_len=None)
Expands attention_mask from [bsz, seq_len] to [bsz, 1, tgt_seq_len, src_seq_len].
- Parameters:
mask (Tensor)
dtype (dtype)
tgt_len (int | None)
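A minimal sketch of the shape expansion, assuming the common additive-mask convention in which attended positions become 0 and padded positions become the most negative value of dtype. Only the shape change is documented above; the inversion step and the name `expand_mask_sketch` are assumptions.

```python
import torch


def expand_mask_sketch(mask: torch.Tensor, dtype: torch.dtype, tgt_len=None) -> torch.Tensor:
    """Illustrative only: [bsz, src_len] -> [bsz, 1, tgt_len, src_len]."""
    bsz, src_len = mask.shape
    tgt_len = src_len if tgt_len is None else tgt_len
    # Broadcast the per-key mask across every query position.
    expanded = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    # Assumed convention: 1 (attend) -> 0.0, 0 (padding) -> dtype minimum.
    inverted = 1.0 - expanded
    return inverted.masked_fill(inverted.to(torch.bool), torch.finfo(dtype).min)


attention_mask = torch.tensor([[1, 1, 0]])  # last token is padding
square = expand_mask_sketch(attention_mask, torch.float32)
rect = expand_mask_sketch(attention_mask, torch.float32, tgt_len=2)
```

Here src_seq_len is the input's seq_len, and tgt_seq_len defaults to the same value when tgt_len is None.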
- make_causal_mask(input_ids_shape, dtype, device, past_key_values_length=0)
Make causal mask used for bi-directional self-attention.
- Parameters:
input_ids_shape (Size)
dtype (dtype)
device (device)
past_key_values_length (int)
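A minimal sketch of how such a mask can be built, assuming an additive mask of shape [bsz, 1, tgt_len, tgt_len + past_key_values_length] in which future positions hold the most negative value of dtype and cached past positions are fully visible. The output shape, the fill convention, and the name `make_causal_mask_sketch` are assumptions, not the documented implementation.

```python
import torch


def make_causal_mask_sketch(input_ids_shape, dtype, device, past_key_values_length=0):
    """Illustrative only: additive causal mask with optional visible past."""
    bsz, tgt_len = input_ids_shape
    min_value = torch.finfo(dtype).min
    mask = torch.full((tgt_len, tgt_len), min_value, dtype=dtype, device=device)
    # Keep the minimum strictly above the diagonal; zero on and below it.
    mask = torch.triu(mask, diagonal=1)
    if past_key_values_length > 0:
        # Every query may attend to all cached (past) positions.
        past = torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device)
        mask = torch.cat([past, mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)


cpu = torch.device("cpu")
causal = make_causal_mask_sketch(torch.Size([2, 3]), torch.float32, cpu)
with_past = make_causal_mask_sketch(torch.Size([2, 3]), torch.float32, cpu, past_key_values_length=2)
```

Row i of the mask lets query i see keys 0..i and blocks the rest, so adding it to the attention scores before the softmax removes attention to future tokens.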