utils
Eagle model utils.
Classes
EagleOfflineDataCollator | Data collator that truncates or pads data for offline training.
OfflineSupervisedDataset | Offline dataset for supervised fine-tuning with pre-dumped hidden states.
Functions
expand_mask | Expands attention_mask from [bsz, seq_len] to [bsz, 1, tgt_seq_len, src_seq_len].
make_causal_mask | Make causal mask used for bi-directional self-attention.
- class EagleOfflineDataCollator
Bases: object
Data collator that truncates or pads data for offline training.
- __init__(train_len)
Initialize with the target sequence length for truncation/padding.
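The truncate-or-pad step can be illustrated with a minimal sketch on plain Python lists. The helper name `pad_or_truncate`, right-side padding, and the pad value `0` are assumptions made for illustration; this page does not state how EagleOfflineDataCollator pads or which value it uses.

```python
def pad_or_truncate(seq, train_len, pad_value=0):
    """Return a copy of seq with exactly train_len items (hypothetical helper)."""
    # Keep at most train_len items from the front of the sequence.
    clipped = list(seq[:train_len])
    # Assumption: short sequences are padded on the right with pad_value.
    return clipped + [pad_value] * (train_len - len(clipped))


batch = [[5, 6, 7, 8, 9], [1, 2]]
fixed = [pad_or_truncate(ids, train_len=4) for ids in batch]
print(fixed)  # [[5, 6, 7, 8], [1, 2, 0, 0]]
```

After this step every sample has the same length, so the samples in a batch can be stacked into a single tensor.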
- class OfflineSupervisedDataset
Bases: Dataset
Offline dataset for supervised fine-tuning with pre-dumped hidden states.
This dataset loads data on the fly from pre-processed .pt files generated by
examples/speculative_decoding/main.py --mode dump_offline_data. Each .pt file contains a dict with the following keys:
  - input_ids: token IDs of shape (seq_len,)
  - hidden_states: base model last hidden states of shape (seq_len, hidden_size)
  - aux_hidden_states: auxiliary hidden states of shape (seq_len, hidden_size)
  - base_model_input_embeds: input embeddings of shape (seq_len, hidden_size)
- Parameters:
dumped_files (list) – A list of file paths to the dumped .pt files.
- __init__(dumped_files)
Initialize with a list of .pt file paths.
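As a rough illustration of the file layout described above, the sketch below writes one dummy .pt file with the four documented keys and reads it back lazily through a small `torch.utils.data.Dataset`. The class `TinyOfflineDataset`, the file name, and the tensor sizes are hypothetical; the real OfflineSupervisedDataset may apply further processing that this page does not describe.

```python
import os
import tempfile

import torch
from torch.utils.data import Dataset


class TinyOfflineDataset(Dataset):
    """Hypothetical stand-in: loads one dumped .pt dict per index."""

    def __init__(self, dumped_files):
        self.dumped_files = list(dumped_files)

    def __len__(self):
        return len(self.dumped_files)

    def __getitem__(self, idx):
        # Each file is read only when its sample is requested.
        return torch.load(self.dumped_files[idx])


# Dummy sample using the documented keys and shapes (sizes are arbitrary).
seq_len, hidden_size = 6, 8
sample = {
    "input_ids": torch.randint(0, 100, (seq_len,)),
    "hidden_states": torch.randn(seq_len, hidden_size),
    "aux_hidden_states": torch.randn(seq_len, hidden_size),
    "base_model_input_embeds": torch.randn(seq_len, hidden_size),
}

tmp_dir = tempfile.mkdtemp()
path = os.path.join(tmp_dir, "sample_0.pt")
torch.save(sample, path)

dataset = TinyOfflineDataset([path])
item = dataset[0]
print({key: tuple(value.shape) for key, value in item.items()})
```

Loading per index keeps memory use proportional to the batch rather than to the whole dump, which matters because each file stores several full hidden-state tensors.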
- expand_mask(mask, dtype, tgt_len=None)
Expands attention_mask from [bsz, seq_len] to [bsz, 1, tgt_seq_len, src_seq_len].
- Parameters:
mask (Tensor)
dtype (dtype)
tgt_len (int | None)
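The shape change can be sketched as follows. The function `expand_mask_sketch` is a hypothetical re-implementation, not the library's code. It assumes the common additive-mask convention, in which attended positions become 0 and padded positions become the most negative value of `dtype`; this page only documents the change in shape.

```python
import torch


def expand_mask_sketch(mask, dtype, tgt_len=None):
    """Expand a [bsz, src_len] padding mask to [bsz, 1, tgt_len, src_len]."""
    bsz, src_len = mask.size()
    tgt_len = src_len if tgt_len is None else tgt_len
    # Broadcast the per-token mask over a new head axis and the target axis.
    expanded = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    # Assumption: 1 means "attend" and 0 means "padding" in the input mask.
    inverted = 1.0 - expanded
    return inverted.masked_fill(inverted.to(torch.bool), torch.finfo(dtype).min)


attention_mask = torch.tensor([[1, 1, 0], [1, 1, 1]])
additive = expand_mask_sketch(attention_mask, torch.float32)
print(additive.shape)
```

A mask in this form can be added directly to attention scores before the softmax, which drives the weights of padded positions to effectively zero.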
- make_causal_mask(input_ids_shape, dtype, device, past_key_values_length=0)
Make causal mask used for bi-directional self-attention.
- Parameters:
input_ids_shape (Size)
dtype (dtype)
device (device)
past_key_values_length (int)
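A minimal sketch of a causal mask with the same parameters is shown below. The function `causal_mask_sketch` is hypothetical; the additive convention (0 for visible positions, the most negative value of `dtype` for future positions) and the output shape [bsz, 1, tgt_len, tgt_len + past_key_values_length] are assumptions, since this page does not document the return value.

```python
import torch


def causal_mask_sketch(input_ids_shape, dtype, device, past_key_values_length=0):
    """Build an additive causal mask (assumed layout: [bsz, 1, tgt, tgt + past])."""
    bsz, tgt_len = input_ids_shape
    neg = torch.finfo(dtype).min
    # Strict upper triangle holds `neg`: token i cannot see tokens j > i.
    full = torch.full((tgt_len, tgt_len), neg, dtype=dtype, device=device)
    mask = torch.triu(full, diagonal=1)
    if past_key_values_length > 0:
        # Cached positions lie in the past, so every token may attend to them.
        past = torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device)
        mask = torch.cat([past, mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)


causal = causal_mask_sketch(
    torch.Size([2, 3]), torch.float32, torch.device("cpu"), past_key_values_length=2
)
print(causal.shape)
```

Under these assumptions the result has the same 4-D layout as the output of expand_mask, so the two masks can be summed to combine causal and padding constraints.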