dataset

Classes

ConstantLengthDataset

Iterable dataset that returns constant-length chunks of tokens from a stream of text files.

Functions

permute

Takes a sample (a list of tokens) and, with probability fim_rate, applies a FIM (fill-in-the-middle) transformation to it using one of two modes: PSM (prefix-suffix-middle) or SPM (suffix-prefix-middle), with SPM chosen with probability fim_spm_rate.

get_fim_token_ids

class ConstantLengthDataset

Bases: IterableDataset

Iterable dataset that returns constant-length chunks of tokens from a stream of text files.

Parameters:
  • tokenizer (Tokenizer) – The tokenizer used to process the data.

  • dataset (dataset.Dataset) – Dataset with text files.

  • infinite (bool) – If True, the iterator restarts when the dataset is exhausted; otherwise, iteration stops.

  • seq_length (int) – Length of token sequences to return.

  • num_of_sequences (int) – Number of token sequences to keep in buffer.

  • chars_per_token (float) – Average number of characters per token, used to estimate how many tokens the text buffer holds.

  • fim_rate (float) – Probability (0.0 to 1.0) that a sample is permuted with FIM.

  • fim_spm_rate (float) – Fraction (0.0 to 1.0) of FIM permutations that use SPM.

  • seed (int) – Seed for random number generator.

  • label_shift (bool) – Whether to shift labels by one position.
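The reference does not describe the packing procedure beyond the parameters above. As a rough illustration of how seq_length, num_of_sequences, and chars_per_token typically interact in constant-length packing, here is a minimal sketch: texts are buffered until about seq_length * chars_per_token * num_of_sequences characters are collected, then tokenized, joined with an EOS token, and cut into fixed-length chunks. The function pack_constant_length, the tokenize callable, eos_token_id, and the label_shift behavior (taking seq_length + 1 tokens and offsetting labels by one) are all hypothetical, not this module's implementation.

```python
from typing import Callable, Iterable, Iterator, List, Tuple


def pack_constant_length(
    texts: Iterable[str],
    tokenize: Callable[[str], List[int]],
    eos_token_id: int,
    seq_length: int = 1024,
    num_of_sequences: int = 1024,
    chars_per_token: float = 3.6,
    label_shift: bool = True,
) -> Iterator[Tuple[List[int], List[int]]]:
    """Hypothetical helper: yield (input_ids, labels), each seq_length tokens long."""
    # Characters needed to fill roughly num_of_sequences sequences of seq_length tokens.
    max_buffer_chars = seq_length * chars_per_token * num_of_sequences
    # Assumption: with label_shift, take one extra token so labels can be offset by one.
    chunk_len = seq_length + 1 if label_shift else seq_length
    iterator = iter(texts)
    exhausted = False
    while not exhausted:
        buffer: List[str] = []
        buffer_chars = 0
        # Fill the text buffer until the character estimate is reached or input ends.
        while buffer_chars < max_buffer_chars:
            try:
                text = next(iterator)
            except StopIteration:
                exhausted = True
                break
            buffer.append(text)
            buffer_chars += len(text)
        # Tokenize every buffered text and join them into one stream, each ending in EOS.
        token_stream: List[int] = []
        for text in buffer:
            token_stream.extend(tokenize(text))
            token_stream.append(eos_token_id)
        # Emit only full chunks; a leftover tail shorter than chunk_len is dropped.
        for start in range(0, len(token_stream) - chunk_len + 1, chunk_len):
            chunk = token_stream[start : start + chunk_len]
            if label_shift:
                yield chunk[:-1], chunk[1:]
            else:
                yield chunk, list(chunk)
```

For example, with a toy tokenizer mapping "a".."z" to 0..25 and eos_token_id=99, packing ["abc", "de"] with seq_length=3 and label_shift=True yields a single pair, ([0, 1, 2], [1, 2, 99]). The sketch leaves out infinite, FIM, BOS, and field handling; with infinite=True one would presumably restart the text iterator instead of stopping.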

__init__(tokenizer, dataset, infinite=False, seq_length=1024, num_of_sequences=1024, chars_per_token=3.6, content_field='content', fim_rate=0.5, fim_spm_rate=0.5, seed=0, label_shift=True, max_sample_length=200000, tokens_field='token_ids', source_datasets_to_discard=(), bos_rate=1.0, return_cu_seqlens=False, seqlen_cap=None)

Parameters:
  • source_datasets_to_discard (Sequence[str] | None)

  • bos_rate (float)

  • return_cu_seqlens (bool)

  • seqlen_cap (int | None)
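The reference does not document return_cu_seqlens or prepare_cu_seqlens. The names follow a common convention, used for example by FlashAttention's variable-length kernels, where a packed sequence is described by an int32 array of cumulative lengths [0, l1, l1 + l2, ...] marking document boundaries. A minimal sketch under the assumption that documents are delimited by a separator token; the helper name cu_seqlens_from_separators and the sep_token_id argument are hypothetical and are not this module's API.

```python
from typing import Sequence

import numpy as np


def cu_seqlens_from_separators(input_ids: Sequence[int], sep_token_id: int) -> np.ndarray:
    """Hypothetical helper: cumulative document boundaries in a packed sequence."""
    ids = np.asarray(input_ids)
    # The position just past each separator marks the end of a document.
    ends = np.flatnonzero(ids == sep_token_id) + 1
    # Always start at 0 and end at the full length so a trailing partial document counts.
    boundaries = np.concatenate(([0], ends, [len(ids)]))
    # np.unique sorts and removes the duplicate end when the last token is a separator.
    return np.unique(boundaries).astype(np.int32)
```

For instance, [5, 6, 0, 7, 8, 9, 0, 3] with sep_token_id=0 gives [0, 3, 7, 8], i.e. documents of lengths 3, 4, and 1. How the real method treats seqlen_cap or BOS tokens is not stated in the source.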

prepare_cu_seqlens(input_ids)

get_fim_token_ids(tokenizer)

permute(sample, np_rng, fim_token_ids, fim_rate=0.5, fim_spm_rate=0.5, truncate_or_pad=False)

Takes a sample (a list of tokens) and, with probability fim_rate, applies a FIM (fill-in-the-middle) transformation to it using one of two modes: PSM (prefix-suffix-middle) or SPM (suffix-prefix-middle), with SPM chosen with probability fim_spm_rate.
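To illustrate the transformation described above, here is a minimal sketch of a FIM permutation. It is not the signature of permute: it assumes fim_token_ids is a (prefix, middle, suffix) tuple of sentinel token ids, uses a NumPy Generator in place of np_rng, and does not model truncate_or_pad. The SPM sentinel order shown is one layout in common use; layouts differ between implementations, so treat it as an assumption.

```python
from typing import List, Sequence, Tuple

import numpy as np


def fim_permute_sketch(
    sample: Sequence[int],
    rng: np.random.Generator,
    fim_token_ids: Tuple[int, int, int],  # assumed order: (prefix, middle, suffix)
    fim_rate: float = 0.5,
    fim_spm_rate: float = 0.5,
) -> List[int]:
    """Hypothetical FIM permutation in PSM or SPM mode."""
    prefix_id, middle_id, suffix_id = fim_token_ids
    # With probability 1 - fim_rate, leave the sample unchanged.
    if rng.random() >= fim_rate:
        return list(sample)
    # Draw two cut points in [0, len(sample)] and split into prefix/middle/suffix.
    lo, hi = sorted(rng.integers(0, len(sample) + 1, size=2).tolist())
    prefix, middle, suffix = list(sample[:lo]), list(sample[lo:hi]), list(sample[hi:])
    if rng.random() < fim_spm_rate:
        # SPM: the suffix comes before the prefix; the middle is the training target.
        return [prefix_id, suffix_id] + suffix + [middle_id] + prefix + middle
    # PSM: prefix, then suffix, then the middle to be predicted.
    return [prefix_id] + prefix + [suffix_id] + suffix + [middle_id] + middle
```

Either mode adds exactly three sentinel tokens and places the middle last, so a left-to-right model learns to generate the middle conditioned on both surrounding spans.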