Resamplers

PRNGResampleDataset

Bases: Dataset[T_co]

A thread-safe dataset shuffler that uses a pseudo-random number generator (PRNG) to shuffle the dataset.

PRNGResampleDataset shuffles a given dataset using a pseudo-random number generator (PRNG). Controlling the random seed makes the shuffling reproducible, and the list of indices is never stored in memory. The class generates random indices on the assumption that the caller requests them sequentially. Random lookups are supported but slow, because they require recomputing the PRNG state: if the requested index is less than the last requested index, the generator is reset to the initial seed and advanced linearly from 0. Requesting the same index twice in a row reuses the cached result. This should work well with the Megatron sampler, which is sequential. Lookups that are skipped, as happens with multiple workers, are handled by advancing the generator past the skipped positions.
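
The replay behaviour described above can be illustrated with a small standalone sketch. It does not use the class itself: `shuffled_index` is a hypothetical helper that assumes the random index served at position `i` is the `(i + 1)`-th draw from a `random.Random` seeded with the initial seed, which is how the source listing further down appears to behave.

```python
import random


def shuffled_index(position: int, dataset_len: int, seed: int = 42) -> int:
    """Hypothetical helper: recompute the random index served at `position`."""
    rng = random.Random(seed)
    # Discard the draws that belong to earlier positions.
    for _ in range(position):
        rng.randint(0, dataset_len - 1)
    return rng.randint(0, dataset_len - 1)


# A sequential reader makes exactly one draw per position from one generator.
rng = random.Random(42)
sequential = [rng.randint(0, 99) for _ in range(5)]

# Replaying from the seed yields the same indices, at O(position) cost per lookup.
replayed = [shuffled_index(i, dataset_len=100, seed=42) for i in range(5)]

print(sequential == replayed)
```

No index list is ever materialized in either path; the only state is the generator and the position it has reached.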

Prefer bionemo.core.data.multi_epoch_dataset.MultiEpochDatasetResampler

This class performs sampling with replacement of an underlying dataset. It is recommended to use the epoch-based sampling provided by bionemo.core.data.multi_epoch_dataset.MultiEpochDatasetResampler instead, which ensures that each sample is seen exactly once per epoch. This dataset is useful for cases where the dataset is too large for the shuffled list of indices to fit in memory and exhaustive sampling is not required.
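
To see why sampling with replacement is not exhaustive, the following sketch draws as many uniform indices as there are items and counts how many distinct items were hit. It uses only `random.Random` and assumes the same `randint(0, dataset_len - 1)` draw as the source listing; the dataset size of 1000 is arbitrary. For uniform draws, the expected fraction of distinct items after N draws from N items is 1 - (1 - 1/N)^N, which approaches 1 - 1/e (about 63%) for large N.

```python
import random

dataset_len = 1000
rng = random.Random(42)

# One pass of `dataset_len` draws, matching the default `num_samples`.
draws = [rng.randint(0, dataset_len - 1) for _ in range(dataset_len)]

unique = len(set(draws))
expected_fraction = 1 - (1 - 1 / dataset_len) ** dataset_len
print(f"{unique} of {dataset_len} items drawn at least once")
print(f"expected fraction of distinct items: {expected_fraction:.3f}")
```

The remaining items are simply never returned in that pass, while others are returned more than once.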

Source code in bionemo/core/data/resamplers.py
class PRNGResampleDataset(Dataset[T_co]):
    """A thread-safe dataset shuffler that uses a pseudo-random number generator (PRNG) to shuffle the dataset.

    PRNGResampleDataset shuffles a given dataset using a pseudo-random number generator (PRNG). Controlling the random
    seed makes the shuffling reproducible, and the list of indices is never stored in memory. The class generates random
    indices on the assumption that the caller requests them sequentially. Random lookups are supported but slow, because
    they require recomputing the PRNG state: if the requested index is less than the last requested index, the generator
    is reset to the initial seed and advanced linearly from 0. This should work well with the Megatron sampler, which is
    sequential. Lookups that are skipped, as happens with multiple workers, are handled by advancing the generator past
    the skipped positions.

    !!! warning "Prefer bionemo.core.data.multi_epoch_dataset.MultiEpochDatasetResampler"

        This class performs sampling with replacement of an underlying dataset. It is recommended to use the epoch-based
        sampling provided by `bionemo.core.data.multi_epoch_dataset.MultiEpochDatasetResampler` instead, which ensures
        that each sample is seen exactly once per epoch. This dataset is useful for cases where the dataset is too large
        for the shuffled list of indices to fit in memory and exhaustive sampling is not required.
    """

    def __init__(self, dataset: Dataset[T_co], seed: int = 42, num_samples: Optional[int] = None):
        """Initializes the PRNGResampleDataset.

        Args:
            dataset: The dataset to be shuffled.
            seed: The seed value for the PRNG. Default is 42.
            num_samples: The number of samples to draw from the dataset.
                If None, the length of the dataset is used. Default is None.
        """
        self.initial_seed = seed
        self.rng = random.Random(seed)
        self.dataset_len = len(dataset)  # type: ignore
        self.num_samples = num_samples if num_samples is not None else len(dataset)
        self.dataset = dataset
        # Store the last accessed index. It is initialized to infinity so that the first lookup triggers a reset, since
        #  index - inf < 0 for every index. The reset calls `self.advance_state(index)`, which advances the state to the
        #  correct starting position. `last_index` is then replaced by `index` and the algorithm proceeds normally.
        self.last_index: Union[int, float] = math.inf  # math.inf is a float value, not a type
        self.last_rand_index: Optional[int] = None

    def rand_idx(self) -> int:
        """Generates a random index within the range of the dataset size."""
        return self.rng.randint(0, self.dataset_len - 1)

    def advance_state(self, num_to_advance: int):
        """Advances the PRNG state by generating `num_to_advance` random indices and discarding them.

        Args:
            num_to_advance: The number of random state steps to advance.
        """
        for _ in range(num_to_advance):
            self.rand_idx()

    def __getitem__(self, index: int) -> T_co:
        """Returns the item from the dataset at the specified index.

        Args:
            index: The index of the item to retrieve.

        Returns:
            The item from the dataset at the specified index.

        Note:
            If the requested index is before the last accessed index, the PRNG state is reset to the initial seed
            and advanced to the correct state. This is less efficient than advancing forward.
        """
        idx_diff = index - self.last_index
        if idx_diff < 0:
            # We need to go backwards (or it is the first call), which involves resetting to the initial seed and
            #   then advancing to just before the correct index, which is accomplished with `range(index)`.
            self.rng = random.Random(self.initial_seed)
            self.advance_state(index)
        elif idx_diff == 0:
            # If the index is the same as the last index, reuse the last random index that was generated.
            #  No state needs to be updated in this case, so just return.
            return self.dataset[self.last_rand_index]
        else:
            # We need to advance past however many positions were skipped since the last call. Consecutive indices
            #  differ by 1, so we advance by `idx_diff - 1` to account for the skipped indices.
            self.advance_state(idx_diff - 1)
        self.last_index = index
        # Store the generated index in case the caller requests this position again. This draw advances the state by 1.
        self.last_rand_index = self.rand_idx()
        return self.dataset[self.last_rand_index]

    def __len__(self) -> int:
        """Returns the total number of samples in the dataset."""
        return self.num_samples

__getitem__(index)

Returns the item from the dataset at the specified index.

Parameters:

Name | Type | Description | Default
index | int | The index of the item to retrieve. | required

Returns:

Type | Description
T_co | The item from the dataset at the specified index.

Note

If the requested index is before the last accessed index, the PRNG state is reset to the initial seed and advanced to the correct state. This is less efficient than advancing forward.
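
The cost difference mentioned in the note can be made concrete by counting PRNG draws. The sketch below is a hypothetical stand-in, not the class itself: `CountingReplay` and `index_for` are names invented here, and the class mirrors only the reset, reuse, and advance branches of the listing while adding a draw counter.

```python
import random
from typing import Optional


class CountingReplay:
    """Hypothetical stand-in that mirrors the reset/advance logic and counts draws."""

    def __init__(self, dataset_len: int, seed: int = 42):
        self.dataset_len = dataset_len
        self.seed = seed
        self.rng = random.Random(seed)
        self.last_index: Optional[int] = None
        self.last_rand_index: Optional[int] = None
        self.draws = 0

    def _draw(self) -> int:
        self.draws += 1
        return self.rng.randint(0, self.dataset_len - 1)

    def index_for(self, index: int) -> int:
        if self.last_index is not None and index == self.last_index:
            return self.last_rand_index  # repeated lookup: no draw needed
        if self.last_index is None or index < self.last_index:
            # First call or backward lookup: reset and replay from position 0.
            self.rng = random.Random(self.seed)
            skip = index
        else:
            # Forward lookup: discard only the skipped positions.
            skip = index - self.last_index - 1
        for _ in range(skip):
            self._draw()
        self.last_index = index
        self.last_rand_index = self._draw()
        return self.last_rand_index


n = 100
forward = CountingReplay(dataset_len=1000)
forward_indices = [forward.index_for(i) for i in range(n)]

backward = CountingReplay(dataset_len=1000)
backward_indices = [backward.index_for(i) for i in reversed(range(n))]

print(forward.draws, backward.draws)
```

Under these assumptions, reading `n` positions in order costs `n` draws, while reading them in reverse costs `n * (n + 1) / 2` draws, because every backward step replays from the seed. Both orders still map each position to the same random index.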

Source code in bionemo/core/data/resamplers.py
def __getitem__(self, index: int) -> T_co:
    """Returns the item from the dataset at the specified index.

    Args:
        index: The index of the item to retrieve.

    Returns:
        The item from the dataset at the specified index.

    Note:
        If the requested index is before the last accessed index, the PRNG state is reset to the initial seed
        and advanced to the correct state. This is less efficient than advancing forward.
    """
    idx_diff = index - self.last_index
    if idx_diff < 0:
        # We need to go backwards (or it is the first call), which involves resetting to the initial seed and
        #   then advancing to just before the correct index, which is accomplished with `range(index)`.
        self.rng = random.Random(self.initial_seed)
        self.advance_state(index)
    elif idx_diff == 0:
        # If the index is the same as the last index, reuse the last random index that was generated.
        #  No state needs to be updated in this case, so just return.
        return self.dataset[self.last_rand_index]
    else:
        # We need to advance past however many positions were skipped since the last call. Consecutive indices
        #  differ by 1, so we advance by `idx_diff - 1` to account for the skipped indices.
        self.advance_state(idx_diff - 1)
    self.last_index = index
    # Store the generated index in case the caller requests this position again. This draw advances the state by 1.
    self.last_rand_index = self.rand_idx()
    return self.dataset[self.last_rand_index]

__init__(dataset, seed=42, num_samples=None)

Initializes the PRNGResampleDataset.

Parameters:

Name | Type | Description | Default
dataset | Dataset[T_co] | The dataset to be shuffled. | required
seed | int | The seed value for the PRNG. Default is 42. | 42
num_samples | Optional[int] | The number of samples to draw from the dataset. If None, the length of the dataset is used. Default is None. | None

Source code in bionemo/core/data/resamplers.py
def __init__(self, dataset: Dataset[T_co], seed: int = 42, num_samples: Optional[int] = None):
    """Initializes the PRNGResampleDataset.

    Args:
        dataset: The dataset to be shuffled.
        seed: The seed value for the PRNG. Default is 42.
        num_samples: The number of samples to draw from the dataset.
            If None, the length of the dataset is used. Default is None.
    """
    self.initial_seed = seed
    self.rng = random.Random(seed)
    self.dataset_len = len(dataset)  # type: ignore
    self.num_samples = num_samples if num_samples is not None else len(dataset)
    self.dataset = dataset
    # Store the last accessed index. It is initialized to infinity so that the first lookup triggers a reset, since
    #  index - inf < 0 for every index. The reset calls `self.advance_state(index)`, which advances the state to the
    #  correct starting position. `last_index` is then replaced by `index` and the algorithm proceeds normally.
    self.last_index: Union[int, float] = math.inf  # math.inf is a float value, not a type
    self.last_rand_index: Optional[int] = None

__len__()

Returns the total number of samples in the dataset.

Source code in bionemo/core/data/resamplers.py
def __len__(self) -> int:
    """Returns the total number of samples in the dataset."""
    return self.num_samples

advance_state(num_to_advance)

Advances the PRNG state by generating num_to_advance random indices and discarding them.

Parameters:

Name | Type | Description | Default
num_to_advance | int | The number of random state steps to advance. | required
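
Advancing the state is what lets several workers share one logical sequence. The sketch below is an illustration under stated assumptions, not the library's worker logic: it assumes each worker holds its own generator with the same seed and is handed positions in a strided pattern (`rank`, `rank + num_workers`, ...). `worker_indices` is a hypothetical helper.

```python
import random
from typing import Dict


def worker_indices(
    rank: int, num_workers: int, num_positions: int, dataset_len: int, seed: int = 42
) -> Dict[int, int]:
    """Hypothetical helper: random indices one worker serves for its strided positions."""
    rng = random.Random(seed)
    served: Dict[int, int] = {}
    last = -1
    for position in range(rank, num_positions, num_workers):
        # Discard the draws for positions that other workers handle.
        for _ in range(position - last - 1):
            rng.randint(0, dataset_len - 1)
        served[position] = rng.randint(0, dataset_len - 1)
        last = position
    return served


# Merge what four independent workers serve for 20 positions.
merged: Dict[int, int] = {}
for rank in range(4):
    merged.update(worker_indices(rank, num_workers=4, num_positions=20, dataset_len=50))

# A single sequential reader with the same seed.
rng = random.Random(42)
single = [rng.randint(0, 49) for _ in range(20)]

print([merged[i] for i in range(20)] == single)
```

Because every worker discards the draws it skips, position `i` maps to the same random index no matter which worker serves it.
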
Source code in bionemo/core/data/resamplers.py
def advance_state(self, num_to_advance: int):
    """Advances the PRNG state by generating `num_to_advance` random indices and discarding them.

    Args:
        num_to_advance: The number of random state steps to advance.
    """
    for _ in range(num_to_advance):
        self.rand_idx()

rand_idx()

Generates a random index within the range of the dataset size.
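
The `- 1` in the listing below matters because `random.Random.randint(a, b)` includes both endpoints. A quick check, using an arbitrary three-item dataset length as the assumption:

```python
import random

dataset_len = 3
rng = random.Random(42)

# randint(0, dataset_len - 1) can return 0, 1, or 2, all of which are valid positions.
draws = [rng.randint(0, dataset_len - 1) for _ in range(1000)]
seen = sorted(set(draws))
print(seen)
```

Passing `dataset_len` as the upper bound instead would occasionally produce an index one past the end of the dataset.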

Source code in bionemo/core/data/resamplers.py
def rand_idx(self) -> int:
    """Generates a random index within the range of the dataset size."""
    return self.rng.randint(0, self.dataset_len - 1)