# Dataset Concepts

## Dataset Design Guidelines
- Datasets should be sharded (i.e. a few samples per "shard", like <1% of the dataset each; ideally there are many more shards than the total number of workers used in training)
  - This allows shards to be loaded in parallel, split over the workers
  - With webdataset, the shards can be streamed (i.e. no random access, but very fast iteration over the shards)
  - The datasets are split across all ranks and workers to parallelize loading
  - TorchNet IndexedDatasets are also supported; they will be partially streamed, similar to shards
  - Virtual shards are created during training, i.e. an offset and a number of samples are resampled every time, and that portion is then streamed for performance
- All datasets are `torch.utils.data.IterableDataset`s, i.e. they do not support random access via `__getitem__` (no `dataset[index]`; you may only iterate over them)
- There is no concept of "epochs": datasets are looped infinitely for training and concatenated for validation / testing (see the sketch after this list)
  - This gives more freedom in how the data is loaded
  - Enables streaming of shards (a requirement for high-performance loading)
  - Enables blending of datasets (i.e. mixing different datasets together with a weighted random sampler)
- For validation / testing, shards are strided across the workers, so some workers may have less data or none at all.
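To make the "no epochs" behaviour concrete, here is a minimal sketch in plain PyTorch (not this library's actual implementation) of an `IterableDataset` that loops its data indefinitely; training code then simply draws as many samples as it needs:

```python
import itertools
from torch.utils.data import IterableDataset

class LoopingDataset(IterableDataset):
    """Minimal sketch: an iterable dataset without epoch boundaries."""

    def __init__(self, samples):
        self.samples = samples

    def __iter__(self):
        while True:  # loop infinitely for training; no notion of an epoch
            yield from self.samples  # a real loader would reshuffle here

dataset = LoopingDataset([0, 1, 2, 3])
# There is no dataset[index]; consumers just take as many samples as needed:
print(list(itertools.islice(dataset, 10)))  # [0, 1, 2, 3, 0, 1, 2, 3, 0, 1]
```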
## Statistical Analysis of Dataset Loading
As each webdataset dataloader worker gets all shards to work on, this is statistically fine. After iterating over the dataset once, the shards are reshuffled, so within each "worker epoch" every sample will be seen once (i.e. the stream is exactly balanced once every (total number of workers × total number of samples) samples have been iterated).
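This balance property can be checked with a toy simulation (illustrative code, not part of the library): give each of the W workers its own endlessly reshuffling stream over the same N samples; after drawing W × N samples round-robin, every sample has been seen exactly W times:

```python
import random

num_workers, samples = 4, list(range(10))

def worker_stream(rng):
    # Each worker endlessly re-iterates its own reshuffled copy of the data.
    while True:
        order = samples[:]
        rng.shuffle(order)
        yield from order

streams = [worker_stream(random.Random(seed)) for seed in range(num_workers)]
counts = {s: 0 for s in samples}
for step in range(num_workers * len(samples)):  # one "worker epoch" per worker
    counts[next(streams[step % num_workers])] += 1

assert all(c == num_workers for c in counts.values())  # exactly balanced
```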
Because webdataset shards are typically iterated linearly, a lot of potential shuffling randomness is unavailable. We therefore slice the shards into smaller parts (configured by `max_samples_per_sequence`), so that the shuffling is more fine-grained. With this applied, the size of the shards effectively no longer matters much, at the performance cost of more seeking.
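To illustrate the idea behind `max_samples_per_sequence`, here is a hypothetical helper (the name and signature are illustrative, not the library's API) that splits shards into virtual slices which can then be shuffled independently:

```python
def slice_shards(shard_sizes, max_samples_per_sequence):
    """Split each shard into virtual (shard, offset, length) slices of at
    most max_samples_per_sequence samples; shuffling these slices gives
    much finer-grained randomness than shuffling whole shards."""
    slices = []
    for shard, size in shard_sizes.items():
        for offset in range(0, size, max_samples_per_sequence):
            length = min(max_samples_per_sequence, size - offset)
            slices.append((shard, offset, length))
    return slices

# Two big shards become 18 small, independently shufflable slices:
print(len(slice_shards({"shard-000.tar": 1000, "shard-001.tar": 800}, 100)))
```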
The `BlendDataset` will always yield full batches from one underlying dataset loader, but across different nodes/ranks (= GPUs), different batches will be yielded according to the blend weights. Typically, gradients are accumulated across ranks, so the distribution should approximately match the given weights for a high total number of ranks (at least 8 or so).
If this behaviour is not desired and mixing should instead happen within batches, the `MixBatchDataset` can be used.
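The difference between the two strategies can be sketched in a few lines (a conceptual illustration of the assumed semantics, not the actual implementations):

```python
import itertools
import random

def blend_batches(batch_iters, weights, rng=random.Random(0)):
    # BlendDataset-style: every yielded batch comes entirely from ONE
    # underlying loader, chosen according to the blend weights.
    while True:
        yield next(rng.choices(batch_iters, weights=weights)[0])

def mix_batches(sample_iters, weights, batch_size, rng=random.Random(0)):
    # MixBatchDataset-style: every yielded batch mixes samples from ALL
    # underlying loaders according to the weights.
    while True:
        picks = rng.choices(sample_iters, weights=weights, k=batch_size)
        yield [next(it) for it in picks]

a, b = itertools.repeat("a"), itertools.repeat("b")
print(next(mix_batches([a, b], weights=[0.75, 0.25], batch_size=8)))
# e.g. ['a', 'a', 'a', 'b', 'a', 'a', 'b', 'a']
```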
The `GroupBatchDataset` will only yield a batch as soon as a full batch of one group has been collected. This can lead to corner cases, such as rare groups filling very slowly (or even containing only a single example). Currently, this dataset is not used, so it is not tested very well. Still, statistically this should be fine over many samples, even if there is one unbalanced group, as that group will eventually yield nevertheless.
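A conceptual sketch of this grouping behaviour (a hypothetical helper, not the library's implementation); note how a rare group's buffer stays unfilled until enough of its samples arrive:

```python
from collections import defaultdict

def group_batches(samples, group_key, batch_size):
    # GroupBatchDataset-style sketch: buffer samples per group and yield
    # a batch only once one group's buffer holds a full batch.
    buffers = defaultdict(list)
    for sample in samples:
        buffer = buffers[group_key(sample)]
        buffer.append(sample)
        if len(buffer) == batch_size:
            yield list(buffer)
            buffer.clear()

# Group numbers by parity; odd numbers are rare here, so that group fills slowly.
stream = [1, 2, 4, 6, 3, 8, 10, 12, 5, 14]
for batch in group_batches(stream, group_key=lambda x: x % 2, batch_size=3):
    print(batch)  # [2, 4, 6], then [8, 10, 12], then [1, 3, 5]
```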
## Types

The following shows the type hierarchy of the Python classes.

### Dataset Types / Flavors

These are the available dataset types for the `dataset.yaml`.
Type hierarchy:
- (`torch.utils.data.IterableDataset`: All datasets implement the torch `IterableDataset` interface)
  - (`BaseCoreDataset`: Base class for all dataset types.)
    - (`BaseWebdataset`: Webdataset-based dataset consisting of sharded `.tar` files; a basic, flexible implementation.)
      - `DefaultGenericWebdataset`: Adds the sample loader / field map and also subflavors.
        - `DefaultDecoderWebdataset`: On top of the `DefaultGenericWebdataset`, loads all known types, such as image, json or pkl types.
          - `CaptioningWebdataset`: Yields `CaptioningSample` from webdataset format
          - `ImageWebdataset`: Yields `ImageSample` from webdataset format
          - `OCRWebdataset`: Yields `OCRSample` from webdataset format
          - `VQAWebdataset`: Yields `VQASample` from webdataset format
From the above, you will want to use the innermost (non-abstract) classes for your `dataset.yaml`. For an OCR dataset stored as a webdataset, you would use `OCRWebdataset`.
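For illustration only, a `dataset.yaml` selecting this type could look roughly like the following sketch; the exact keys (`__module__`, `__class__`, `field_map`) and the module path are assumptions to verify against your version of the library:

```yaml
# Hypothetical sketch of a dataset.yaml; verify the schema for your version.
__module__: megatron.energon  # assumed module providing the dataset classes
__class__: OCRWebdataset      # the innermost, non-abstract type
field_map:
  image: jpg  # map the sample's image field to the .jpg entry of each tar member
  text: txt   # map the sample's text field to the .txt entry
```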
### Sample Types

These are the available sample types yielded by the datasets above, together with their attributes.
Type hierarchy:
- `Sample`: Base class
  - Attributes:
    - `__key__: str`: Unique identifier of the sample within the dataset. Useful for backtracking the source of a single sample.
- `CaptioningSample`: Represents a sample for captioning
  - Attributes:
    - `__key__: str` (inherited)
    - `image: torch.Tensor`: The input image tensor
    - `caption: str`: The target caption string
- `ImageSample`: Represents a sample which only contains an image (e.g. for reconstruction)
  - Attributes:
    - `__key__: str` (inherited)
    - `image: torch.Tensor`: The image tensor
- `OCRSample`: Represents a sample which contains an OCR image and its text
  - Attributes:
    - `__key__: str` (inherited)
    - `image: str`: The input image
    - `text: str`: The text string for the whole image
    - `lines_boxes: Optional[torch.Tensor]`: The bounding boxes of the text lines in the image
    - `lines_text: Optional[torch.Tensor]`: The text content of the text lines in the image
    - `words_boxes: Optional[torch.Tensor]`: The bounding boxes of the words in the image
    - `words_text: Optional[torch.Tensor]`: The text content of the words in the image
    - `chars_boxes: Optional[torch.Tensor]`: The bounding boxes of the characters in the image
    - `chars_text: Optional[torch.Tensor]`: The text content of the characters in the image
- `TextSample`: Represents a sample which only contains a text string (e.g. for text generation)
  - Attributes:
    - `__key__: str` (inherited)
    - `text: str`: The text string
- `VQASample`: Represents a sample which contains an image, a question/context and an answer
  - Attributes:
    - `__key__: str` (inherited)
    - `image: torch.Tensor`: The input image tensor
    - `context: str`: The question / context string
    - `answer: str`: The target answer string
> **Note:** Images are always of shape `(C, H, W)`.
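For orientation, the sample types above can be pictured as plain Python dataclasses; this is a sketch of the data shapes, not the library's exact definitions:

```python
from dataclasses import dataclass

import torch

@dataclass
class Sample:
    __key__: str  # unique identifier of the sample within the dataset

@dataclass
class CaptioningSample(Sample):
    image: torch.Tensor  # always of shape (C, H, W), per the note above
    caption: str         # the target caption string

@dataclass
class VQASample(Sample):
    image: torch.Tensor  # always of shape (C, H, W)
    context: str         # the question / context string
    answer: str          # the target answer string
```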