dataset_utils
Utility functions for getting dataset samples and building forward loop functions for different datasets.
Functions

- create_forward_loop – Creates and returns a forward loop function configured for a specific model, dataset, and tokenizer.
- get_dataset_dataloader – Get a dataloader with the dataset name and tokenizer of the target model.
- get_max_batch_size – Get the maximum batch size that can be used for the model.
- get_supported_datasets – Retrieves a list of supported datasets.
- create_forward_loop(model=None, dataset_name='cnn_dailymail', tokenizer=None, batch_size=1, num_samples=512, max_sample_length=512, device=None, include_labels=False, dataloader=None)
Creates and returns a forward loop function configured for a specific model, dataset, and tokenizer.
This function initializes a forward loop function tailored to process batches of data from the specified dataset using the given model and tokenizer. The forward loop function, when called, iterates over the dataset, applies the tokenizer to prepare the input data, feeds it into the model, and returns the model’s predictions.
- Parameters:
model (Module) – The PyTorch model for inference.
dataset_name (str) – The name of the dataset to be used. Must be one of the datasets in get_supported_datasets().
tokenizer (PreTrainedTokenizer | PreTrainedTokenizerFast) – The tokenizer used to preprocess text data into a format suitable for the model.
batch_size (int) – Batch size of the returned dataloader. If 0 is provided, the batch size is determined automatically.
num_samples (int) – Number of samples from the dataset.
max_sample_length (int) – Maximum length of a sample.
device (str | None) – Target device for the returned dataloader.
include_labels (bool) – Whether to include labels in the dataloader.
dataloader (DataLoader) – If provided, this dataloader is used instead of creating one from dataset_name.
- Return type:
Callable
Example usage for quantization:

    import modelopt.torch.quantization as mtq
    from modelopt.torch.utils import create_forward_loop

    # Initialize model and tokenizer
    # ...

    # Create forward loop for calibration
    forward_loop = create_forward_loop(
        model=model,
        dataset_name="cnn_dailymail",
        tokenizer=tokenizer,
    )

    # Quantize the model with the calibration dataset
    mtq.quantize(model, quant_cfg, forward_loop=forward_loop)
- Returns:
- A forward loop function that can be called with no arguments. When called, this function iterates over
the dataset specified by dataset_name.
- get_dataset_dataloader(dataset_name='cnn_dailymail', tokenizer=None, batch_size=1, num_samples=512, max_sample_length=512, device=None, include_labels=False)
Get a dataloader with the dataset name and tokenizer of the target model.
- Parameters:
dataset_name (str) – Name of the dataset to load.
tokenizer (PreTrainedTokenizer | PreTrainedTokenizerFast) – Instance of a Hugging Face tokenizer.
batch_size (int) – Batch size of the returned dataloader.
num_samples (int) – Number of samples from the dataset.
max_sample_length (int) – Maximum length of a sample.
device (str | None) – Target device for the returned dataloader.
include_labels (bool) – Whether to include labels in the dataloader.
- Returns:
An instance of DataLoader.
- Return type:
DataLoader
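The batching and truncation behavior of the returned dataloader can be illustrated in plain Python. The helper below is a hypothetical sketch for illustration only, not part of the library; the real get_dataset_dataloader returns a torch DataLoader built from the named dataset.

```python
def make_batches(samples, batch_size=1, max_sample_length=512):
    """Hypothetical sketch: truncate each tokenized sample to
    max_sample_length and group the samples into batches of
    batch_size, mirroring the dataloader's parameters."""
    truncated = [s[:max_sample_length] for s in samples]
    return [
        truncated[i : i + batch_size]
        for i in range(0, len(truncated), batch_size)
    ]

# Five token-ID sequences batched in pairs, truncated to length 3
token_ids = [
    [101, 2023, 102],
    [101, 2003, 102, 999],
    [101],
    [101, 102],
    [101, 7592, 102],
]
batches = make_batches(token_ids, batch_size=2, max_sample_length=3)
```

With num_samples samples and a fixed batch_size, the real dataloader yields ceil(num_samples / batch_size) batches in the same way.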
- get_max_batch_size(model, max_sample_length=512, sample_memory_usage_ratio=1.0)
Get the maximum batch size that can be used for the model.
- Parameters:
model (Module) – The PyTorch model for inference.
max_sample_length (int) – Maximum length of a sample.
sample_memory_usage_ratio (float) –
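The auto-sizing idea behind get_max_batch_size can be sketched as a doubling search: grow the batch size until memory no longer suffices, then keep the last size that fit. The memory_fits predicate below is a stand-in assumption for illustration; the library's actual memory probe and its use of sample_memory_usage_ratio are internal.

```python
def find_max_batch_size(memory_fits, start=1, cap=1024):
    """Sketch of a doubling search for the largest batch size that
    passes the (assumed) memory-fit probe, capped at `cap`."""
    best = 0
    bs = start
    while bs <= cap and memory_fits(bs):
        best = bs      # this size fit; remember it
        bs *= 2        # try twice as large next
    return best

# Stand-in probe: pretend memory holds at most 48 samples per batch
max_bs = find_max_batch_size(lambda bs: bs <= 48)
```

Because the search doubles each step, it returns the largest fitting power of two times `start`, which keeps the number of probe calls logarithmic in the cap.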
- get_supported_datasets()
Retrieves a list of supported datasets.
- Returns:
A list of strings, where each string is the name of a supported dataset.
- Return type:
list[str]
Example usage:

    from modelopt.torch.utils import get_supported_datasets

    print("Supported datasets:", get_supported_datasets())