
Multi-epoch dataset

EpochIndex

Bases: NamedTuple

A tuple that contains both the current epoch and index for multi-epoch training.

Source code in bionemo/core/data/multi_epoch_dataset.py
class EpochIndex(NamedTuple):
    """A tuple that contains both the current epoch and index for multi-epoch training."""

    epoch: int
    """An integer representing the current epoch."""

    idx: int
    """An integer representing the index within the current epoch."""

epoch instance-attribute

An integer representing the current epoch.

idx instance-attribute

An integer representing the index within the current epoch.
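Because EpochIndex is a NamedTuple, it can be built positionally or by keyword, unpacked, and used as a dictionary key. The class is redefined below only so the snippet runs without bionemo installed; in real code you would import it from bionemo.core.data.multi_epoch_dataset. The cache dictionary is purely illustrative.

```python
from typing import NamedTuple


class EpochIndex(NamedTuple):
    """Mirror of the EpochIndex definition above, so this snippet runs standalone."""

    epoch: int
    idx: int


# Construct by position or by keyword; both produce the same tuple.
first = EpochIndex(0, 7)
same = EpochIndex(epoch=0, idx=7)
assert first == same

# NamedTuples unpack like plain tuples and hash like them, so they can key a cache.
epoch, idx = EpochIndex(2, 5)
cache = {EpochIndex(2, 5): "sample"}
print(epoch, idx, cache[(2, 5)])  # prints "2 5 sample"
```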

IdentityMultiEpochDatasetWrapper dataclass

Bases: MultiEpochDatasetWrapper[T, T]

An implementation of the MultiEpochDatasetWrapper that does not apply any transformations.

Source code in bionemo/core/data/multi_epoch_dataset.py
class IdentityMultiEpochDatasetWrapper(MultiEpochDatasetWrapper[T, T]):
    """An implementation of the `MultiEpochDatasetWrapper` that does not apply any transformations."""

    def apply_transform(self, sample: T, index: EpochIndex) -> T:
        """Return the sample as is."""
        del index  # Unused.
        return sample

apply_transform(sample, index)

Return the sample as is.

Source code in bionemo/core/data/multi_epoch_dataset.py
def apply_transform(self, sample: T, index: EpochIndex) -> T:
    """Return the sample as is."""
    del index  # Unused.
    return sample
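A minimal sketch of the identity wrapper in use. To keep it runnable without bionemo or torch, the base class is a simplified copy of MultiEpochDatasetWrapper (the torch Dataset base is omitted and a plain Python sequence stands in for SizedDataset). The point it shows: the epoch is ignored, so every epoch sees exactly the same sample for a given idx.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, NamedTuple, Sequence, TypeVar

T = TypeVar("T")
U = TypeVar("U")


class EpochIndex(NamedTuple):
    epoch: int
    idx: int


# Simplified stand-in for MultiEpochDatasetWrapper (no torch Dataset base).
@dataclass
class MultiEpochDatasetWrapper(Generic[T, U], ABC):
    dataset: Sequence[T]

    @abstractmethod
    def apply_transform(self, sample: T, index: EpochIndex) -> U: ...

    def __getitem__(self, index: EpochIndex) -> U:
        # Look up the raw sample by its local index, then transform it.
        return self.apply_transform(self.dataset[index.idx], index)

    def __len__(self) -> int:
        return len(self.dataset)


class IdentityMultiEpochDatasetWrapper(MultiEpochDatasetWrapper[T, T]):
    def apply_transform(self, sample: T, index: EpochIndex) -> T:
        del index  # The epoch is ignored: every epoch sees identical samples.
        return sample


wrapped = IdentityMultiEpochDatasetWrapper(dataset=["a", "b", "c"])
print(len(wrapped))               # 3
print(wrapped[EpochIndex(0, 1)])  # b
print(wrapped[EpochIndex(5, 1)])  # b (same sample regardless of epoch)
```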

MultiEpochDataset

Bases: Protocol[T_co]

A protocol for datasets for multi-epoch training in Megatron-LM.

Dataset determinism in Megatron-LM

In Megatron training, the sampler and dataset objects together ensure consistent data loading across model-parallel ranks. For a dataset to work with Megatron training, every call to __getitem__ with the same index must return exactly the same data.

Source code in bionemo/core/data/multi_epoch_dataset.py
class MultiEpochDataset(Protocol[T_co]):
    """A protocol for datasets for multi-epoch training in Megatron-LM.

    !!! important "Dataset determinism in Megatron-LM"
        In megatron training, the sampler and dataset objects are used to ensure consistent data loading across
        model-parallel ranks. For datasets to work with megatron training, they must return exactly the same data for
        every call to `__getitem__` with the same index.
    """

    def __getitem__(self, index: EpochIndex) -> T_co:  # noqa: D105
        ...

    def __len__(self) -> int:  # noqa: D105
        ...
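One way to meet the determinism requirement while still varying data across epochs is to derive all randomness from the EpochIndex itself rather than from shared RNG state. The sketch below is a hypothetical dataset (NoisyDataset is not part of bionemo) that structurally satisfies the MultiEpochDataset protocol: it seeds a fresh NumPy generator from (seed, epoch, idx) on every call, so repeated lookups of the same index on any rank return bit-identical values.

```python
from __future__ import annotations

from typing import NamedTuple

import numpy as np


class EpochIndex(NamedTuple):
    epoch: int
    idx: int


class NoisyDataset:
    """Hypothetical dataset satisfying the MultiEpochDataset protocol.

    Noise depends only on (seed, epoch, idx), so the same EpochIndex always yields the
    same value, while different epochs still see different noise.
    """

    def __init__(self, values: list[float], seed: int = 0):
        self.values = values
        self.seed = seed

    def __getitem__(self, index: EpochIndex) -> float:
        # A fresh generator seeded from the full index; no shared mutable RNG state.
        rng = np.random.default_rng([self.seed, index.epoch, index.idx])
        return self.values[index.idx] + float(rng.normal(scale=0.1))

    def __len__(self) -> int:
        return len(self.values)


ds = NoisyDataset([1.0, 2.0, 3.0])
assert ds[EpochIndex(0, 1)] == ds[EpochIndex(0, 1)]  # repeatable for the same index
```

By contrast, calling a module-level generator such as np.random.normal() inside __getitem__ would make the result depend on call order, which breaks the guarantee above.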

MultiEpochDatasetResampler dataclass

Bases: Dataset[T_co]

A dataset wrapper class that converts the sequential sampling from Megatron-LM to epoch-based sampling.

Either num_epochs or num_samples should be provided, but not both (providing both raises a ValueError). If neither is provided, the dataset uses a single epoch. If num_epochs is given, the resampled dataset has len(dataset) * num_epochs samples. If num_samples is given, the resampled dataset has num_samples samples: the dataset is repeated for as many epochs as needed to reach num_samples, and the final epoch is truncated.
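The length bookkeeping can be checked with a small standalone helper. resolve_lengths is a hypothetical function written for this page; it mirrors the branches in __post_init__ but is not part of the bionemo API.

```python
from __future__ import annotations

import math


def resolve_lengths(
    dataset_len: int, num_epochs: int | None = None, num_samples: int | None = None
) -> tuple[int, int]:
    """Hypothetical helper mirroring the epoch/sample bookkeeping in __post_init__."""
    if num_epochs is None and num_samples is None:
        num_epochs = 1  # default: one pass over the data
    elif num_epochs is not None and num_samples is not None:
        raise ValueError("Only one of num_epochs and num_samples should be provided.")
    if num_epochs is None:
        # Round up so the last, partial epoch is still allocated a shuffle seed.
        num_epochs = math.ceil(num_samples / dataset_len)
    else:
        num_samples = dataset_len * num_epochs
    return num_epochs, num_samples


print(resolve_lengths(10))                  # (1, 10)
print(resolve_lengths(10, num_epochs=4))    # (4, 40)
print(resolve_lengths(10, num_samples=25))  # (3, 25)
```

In the last case, epochs 0 and 1 are read in full, and only the first 5 positions of epoch 2's order are ever used.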

Source code in bionemo/core/data/multi_epoch_dataset.py
@dataclass
class MultiEpochDatasetResampler(Dataset[T_co]):
    """A dataset wrapper class that converts the sequential sampling from Megatron-LM to epoch-based sampling.

    Either `num_epochs` or `num_samples` should be provided. If neither is provided, the dataset will use a single
    epoch. If `num_epochs` is given, the resampled dataset will have `len(dataset) * num_epochs` samples. If
    `num_samples` is given, the resampled dataset will have `num_samples` samples. For `num_samples`, the dataset will
    be repeated for multiple epochs until the desired number of samples is reached (with the final epoch being
    truncated).
    """

    dataset: MultiEpochDataset[T_co]
    """The dataset to resample. Must support indexing with an `EpochIndex`."""

    num_epochs: int | None = None
    """The total number of epochs. The length of the resampled dataset will be len(dataset) * num_epochs."""

    num_samples: int | None = None
    """The total number of samples to draw.

    The number of epochs will be determined by the number of samples and the length of the dataset.
    """

    shuffle: bool = True
    """Whether to shuffle the samples in the dataset each epoch."""

    seed: int = 42  # type: ignore
    """A random seed for reproducibility."""

    def __post_init__(self):
        """Pre-shuffle each epoch's samples."""
        if self.num_epochs is None and self.num_samples is None:
            self.num_epochs = 1
        elif self.num_epochs is not None and self.num_samples is not None:
            raise ValueError("Only one of num_epochs and num_samples should be provided.")

        if self.num_epochs is None and self.num_samples is not None:
            self.num_epochs = math.ceil(self.num_samples / len(self.dataset))

        elif self.num_samples is None and self.num_epochs is not None:
            self.num_samples = len(self.dataset) * self.num_epochs

        # Type guard statements, the above if/elif block should ensure these are not None.
        assert self.num_epochs is not None
        assert self.num_samples is not None

        if self.num_epochs < 1:
            raise ValueError("num_epochs must be at least 1.")

        rng = np.random.default_rng(self.seed)

        # Initialize a vector of random seeds so that each epoch is shuffled differently.
        self.epoch_seeds = rng.integers(0, np.iinfo(np.int32).max, size=self.num_epochs)

    def __getitem__(self, index: int) -> T_co:
        """Get the sample at the given index."""
        if index not in range(len(self)):
            raise IndexError(f"Index {index} out of bounds for dataset of length {len(self)}.")
        return self.dataset[self._global_index_to_permuted_local_index(index)]

    def __len__(self) -> int:
        """Return the length of the resampled dataset."""
        return self.num_samples  # type: ignore

    def _global_index_to_permuted_local_index(self, index: int) -> EpochIndex:
        """Convert a global index to an epoch index."""
        epoch = index // len(self.dataset)
        idx = index % len(self.dataset)
        if self.shuffle:
            idx = permute(idx, len(self.dataset), self.epoch_seeds[epoch])
        return EpochIndex(epoch, idx)

dataset instance-attribute

The dataset to resample. Must support indexing with an EpochIndex.

num_epochs = None class-attribute instance-attribute

The total number of epochs. The length of the resampled dataset will be len(dataset) * num_epochs.

num_samples = None class-attribute instance-attribute

The total number of samples to draw.

The number of epochs will be determined by the number of samples and the length of the dataset.

seed = 42 class-attribute instance-attribute

A random seed for reproducibility.

shuffle = True class-attribute instance-attribute

Whether to shuffle the samples in the dataset each epoch.

__getitem__(index)

Get the sample at the given index.

Source code in bionemo/core/data/multi_epoch_dataset.py
def __getitem__(self, index: int) -> T_co:
    """Get the sample at the given index."""
    if index not in range(len(self)):
        raise IndexError(f"Index {index} out of bounds for dataset of length {len(self)}.")
    return self.dataset[self._global_index_to_permuted_local_index(index)]
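The global-to-epoch mapping behind __getitem__ can be sketched as follows. The source calls a permute(idx, n, seed) helper that is not shown on this page, so here a full NumPy permutation per epoch stands in for it (an assumption; it materializes an array of length n on every call and is meant only for illustration). The property it demonstrates holds for any per-epoch permutation: each epoch visits every local index exactly once.

```python
from typing import NamedTuple

import numpy as np


class EpochIndex(NamedTuple):
    epoch: int
    idx: int


def global_to_epoch_index(index: int, dataset_len: int, epoch_seeds, shuffle: bool = True) -> EpochIndex:
    """Sketch of _global_index_to_permuted_local_index using a NumPy stand-in for permute."""
    epoch, idx = divmod(index, dataset_len)
    if shuffle:
        # Stand-in for the unshown permute helper: a seeded permutation of range(dataset_len).
        perm = np.random.default_rng(int(epoch_seeds[epoch])).permutation(dataset_len)
        idx = int(perm[idx])
    return EpochIndex(epoch, idx)


n = 10
seeds = np.random.default_rng(42).integers(0, np.iinfo(np.int32).max, size=3)

# Without shuffling, the mapping is plain integer division.
print(global_to_epoch_index(23, n, seeds, shuffle=False))  # EpochIndex(epoch=2, idx=3)

# With shuffling, the global indices of epoch 1 still cover every local index once.
epoch1 = [global_to_epoch_index(i, n, seeds).idx for i in range(n, 2 * n)]
assert sorted(epoch1) == list(range(n))
```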

__len__()

Return the length of the resampled dataset.

Source code in bionemo/core/data/multi_epoch_dataset.py
def __len__(self) -> int:
    """Return the length of the resampled dataset."""
    return self.num_samples  # type: ignore

__post_init__()

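Pre-shuffle each epoch's samples. In practice, this validates num_epochs and num_samples, fills in whichever was not given, and draws one shuffle seed per epoch; the permutation itself is computed lazily in __getitem__.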

Source code in bionemo/core/data/multi_epoch_dataset.py
def __post_init__(self):
    """Pre-shuffle each epoch's samples."""
    if self.num_epochs is None and self.num_samples is None:
        self.num_epochs = 1
    elif self.num_epochs is not None and self.num_samples is not None:
        raise ValueError("Only one of num_epochs and num_samples should be provided.")

    if self.num_epochs is None and self.num_samples is not None:
        self.num_epochs = math.ceil(self.num_samples / len(self.dataset))

    elif self.num_samples is None and self.num_epochs is not None:
        self.num_samples = len(self.dataset) * self.num_epochs

    # Type guard statements, the above if/elif block should ensure these are not None.
    assert self.num_epochs is not None
    assert self.num_samples is not None

    if self.num_epochs < 1:
        raise ValueError("num_epochs must be at least 1.")

    rng = np.random.default_rng(self.seed)

    # Initialize a vector of random seeds so that each epoch is shuffled differently.
    self.epoch_seeds = rng.integers(0, np.iinfo(np.int32).max, size=self.num_epochs)
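The per-epoch seeds are what make the shuffle schedule reproducible: a single seed determines a vector of epoch seeds, and each epoch seed determines that epoch's order. make_epoch_seeds below is a hypothetical helper that repeats the same NumPy recipe as __post_init__ so the property can be checked in isolation.

```python
import numpy as np


def make_epoch_seeds(seed: int, num_epochs: int) -> np.ndarray:
    """Same recipe as __post_init__: one int32-range seed per epoch from a seeded generator."""
    rng = np.random.default_rng(seed)
    return rng.integers(0, np.iinfo(np.int32).max, size=num_epochs)


a = make_epoch_seeds(42, num_epochs=5)
b = make_epoch_seeds(42, num_epochs=5)

# Every rank that constructs the resampler with the same seed gets identical epoch seeds.
assert np.array_equal(a, b)
# All seeds fall in [0, int32 max), as requested from rng.integers (upper bound exclusive).
assert ((a >= 0) & (a < np.iinfo(np.int32).max)).all()
```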

MultiEpochDatasetWrapper dataclass

Bases: Dataset[U_co], Generic[T, U_co], ABC

A wrapper to convert a standard PyTorch dataset into one that supports multi-epoch Megatron training.

The underlying dataset's __getitem__ method must be deterministic, i.e., it must return the same data for the same index every time it is called. Any non-deterministic operations should be moved to the apply_transform method. apply_transform must itself be deterministic for every (epoch, index) pair, but it can use the epoch to apply different data augmentation each epoch.

Source code in bionemo/core/data/multi_epoch_dataset.py
@dataclass
class MultiEpochDatasetWrapper(Dataset[U_co], Generic[T, U_co], ABC):
    """A wrapper to convert a standard pytorch dataset into one that supports multi-epoch megatron training.

    The underlying dataset's __getitem__ method must be deterministic, i.e. it must return the same data for the same
    index every time it is called. If there are any non-deterministic operations, they should be moved to the
    `apply_transform` method. This method must also be deterministic for every (epoch, index) pair, but it can use
    the epoch to implement data augmentation each epoch.
    """

    dataset: SizedDataset[T]
    """A deterministic dataset that supports indexing with an integer index."""

    @abstractmethod
    def apply_transform(self, sample: T, index: EpochIndex) -> U_co:
        """Apply any transformations to the sample for the given epoch."""
        raise NotImplementedError

    def __getitem__(self, index: EpochIndex) -> U_co:
        """Get the sample at the given epoch and index."""
        return self.apply_transform(self.dataset[index.idx], index)

    def __len__(self) -> int:
        """Return the length of the dataset."""
        return len(self.dataset)
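A sketch of a concrete subclass that uses the epoch for augmentation. RandomCropWrapper is hypothetical, and the base class is a simplified copy of MultiEpochDatasetWrapper without the torch Dataset base so the snippet runs standalone. The crop offset is drawn from a generator seeded by (seed, epoch, idx), which keeps apply_transform deterministic per (epoch, index) pair while still producing a different crop each epoch.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, NamedTuple, Sequence, TypeVar

import numpy as np

T = TypeVar("T")
U = TypeVar("U")


class EpochIndex(NamedTuple):
    epoch: int
    idx: int


@dataclass
class MultiEpochDatasetWrapper(Generic[T, U], ABC):
    """Simplified copy of the wrapper above (torch Dataset base omitted)."""

    dataset: Sequence[T]

    @abstractmethod
    def apply_transform(self, sample: T, index: EpochIndex) -> U: ...

    def __getitem__(self, index: EpochIndex) -> U:
        return self.apply_transform(self.dataset[index.idx], index)

    def __len__(self) -> int:
        return len(self.dataset)


@dataclass
class RandomCropWrapper(MultiEpochDatasetWrapper[str, str]):
    """Hypothetical augmentation: a random crop that varies by epoch but is fixed per (epoch, idx)."""

    crop_len: int = 3
    seed: int = 0

    def apply_transform(self, sample: str, index: EpochIndex) -> str:
        # Seed from (seed, epoch, idx) so the crop is reproducible on every rank.
        rng = np.random.default_rng([self.seed, index.epoch, index.idx])
        start = int(rng.integers(0, len(sample) - self.crop_len + 1))
        return sample[start : start + self.crop_len]


ds = RandomCropWrapper(dataset=["ACGTACGT", "TTGACCAA"], crop_len=3)
assert ds[EpochIndex(0, 0)] == ds[EpochIndex(0, 0)]  # deterministic for a given (epoch, idx)
print([ds[EpochIndex(e, 0)] for e in range(4)])      # crops of sample 0 across 4 epochs
```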

dataset instance-attribute

A deterministic dataset that supports indexing with an integer index.

__getitem__(index)

Get the sample at the given epoch and index.

Source code in bionemo/core/data/multi_epoch_dataset.py
def __getitem__(self, index: EpochIndex) -> U_co:
    """Get the sample at the given epoch and index."""
    return self.apply_transform(self.dataset[index.idx], index)

__len__()

Return the length of the dataset.

Source code in bionemo/core/data/multi_epoch_dataset.py
def __len__(self) -> int:
    """Return the length of the dataset."""
    return len(self.dataset)

apply_transform(sample, index) abstractmethod

Apply any transformations to the sample for the given epoch.

Source code in bionemo/core/data/multi_epoch_dataset.py
@abstractmethod
def apply_transform(self, sample: T, index: EpochIndex) -> U_co:
    """Apply any transformations to the sample for the given epoch."""
    raise NotImplementedError

SizedDataset

Bases: Protocol[T_co]

A protocol for integer-indexed datasets that have a fixed length.

Source code in bionemo/core/data/multi_epoch_dataset.py
class SizedDataset(Protocol[T_co]):
    """A protocol for integer-indexed datasets that have a fixed length."""

    def __getitem__(self, index: int) -> T_co:  # noqa: D105
        ...

    def __len__(self) -> int:  # noqa: D105
        ...
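Because SizedDataset is a Protocol, conformance is structural: any object with an integer __getitem__ and a __len__ qualifies, with no subclassing required. The definition shown above carries no @runtime_checkable decorator, so this conformance is checked by static type checkers such as mypy or pyright rather than by isinstance. last_item and Squares are hypothetical names used only for this illustration.

```python
from typing import Protocol, TypeVar

T_co = TypeVar("T_co", covariant=True)


class SizedDataset(Protocol[T_co]):
    """Copy of the protocol above so the snippet runs standalone."""

    def __getitem__(self, index: int) -> T_co: ...

    def __len__(self) -> int: ...


def last_item(dataset: SizedDataset[T_co]) -> T_co:
    """Accepts anything with integer __getitem__ and __len__; no subclassing needed."""
    return dataset[len(dataset) - 1]


class Squares:
    """Hypothetical custom dataset; it never mentions SizedDataset."""

    def __init__(self, n: int):
        self.n = n

    def __getitem__(self, index: int) -> int:
        if not 0 <= index < self.n:
            raise IndexError(index)
        return index * index

    def __len__(self) -> int:
        return self.n


print(last_item([10, 20, 30]))  # 30
print(last_item(Squares(5)))    # 16
```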