Callbacks

PredictionWriter

Bases: BasePredictionWriter, Callback

A callback that writes predictions to disk at specified intervals during a prediction run.

Logits, embeddings, hiddens, input IDs, and labels may all be saved to disk, depending on the trainer configuration. A batch index is provided for each prediction in the same dictionary; use it to restore input order, since multi-device predictions are not guaranteed to match single-device order.
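A minimal usage sketch (not from the source; model and datamodule are hypothetical placeholders, and a plain Lightning Trainer stands in for whatever trainer the surrounding pipeline uses). With return_predictions=False, Lightning does not accumulate predictions in memory, so the files this callback writes are the only record of the outputs:

import lightning.pytorch as pl

from bionemo.llm.utils.callbacks import PredictionWriter

writer = PredictionWriter(
    output_dir="./predictions",
    write_interval="batch",  # one .pt file per batch, per writing rank
)
trainer = pl.Trainer(accelerator="gpu", devices=1, callbacks=[writer])
trainer.predict(model, datamodule=datamodule, return_predictions=False)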

Source code in bionemo/llm/utils/callbacks.py
class PredictionWriter(BasePredictionWriter, pl.Callback):
    """A callback that writes predictions to disk at specified intervals during training.

    Logits, embeddings, hiddens, input IDs, and labels may all be saved to disk, depending on the trainer
    configuration. A batch index is provided for each prediction in the same dictionary; use it to restore input
    order, since multi-device predictions are not guaranteed to match single-device order.
    """

    def __init__(
        self,
        output_dir: str | os.PathLike,
        write_interval: IntervalT,
        batch_dim_key_defaults: dict[str, int] | None = None,
        seq_dim_key_defaults: dict[str, int] | None = None,
        save_all_model_parallel_ranks: bool = False,
        files_per_subdir: int | None = None,
    ):
        """Initializes the callback.

        Args:
            output_dir: The directory where predictions will be written.
            write_interval: The interval at which predictions will be written (batch, epoch). Epoch may not be used with
                multi-device trainers.
            batch_dim_key_defaults: The default batch dimension for each key, if different from the standard 0.
            seq_dim_key_defaults: The default sequence dimension for each key, if different from the standard 1.
            save_all_model_parallel_ranks: Whether to save predictions for all model parallel ranks. Generally these
                will be redundant.
            files_per_subdir: Number of files to write to each subdirectory. If provided, subdirectories with N files
                each will be created. Ignored unless write_interval is 'batch'.
        """
        super().__init__(write_interval)
        self.write_interval = write_interval
        self.output_dir = str(output_dir)
        self.base_dir = self.output_dir  # output_dir is redirected to a new subdirectory when files_per_subdir is set
        self.batch_dim_key_defaults = batch_dim_key_defaults
        self.seq_dim_key_defaults = seq_dim_key_defaults
        self.save_all_model_parallel_ranks = save_all_model_parallel_ranks
        self.files_per_subdir = files_per_subdir
        # Initialize to infinity if files_per_subdir is provided so that we create a new subdirectory before writing
        #   any files.
        self.num_files_written = float("inf") if files_per_subdir else 0
        self.num_subdirs_written = 0

    def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, *args, **kwargs) -> None:  # noqa: D417
        """Invoked with Trainer.fit, validate, test, and predict are called. Will immediately fail when 'write_interval' is 'epoch' and 'trainer.num_devices' > 1.

        Args:
            trainer: The Trainer instance.
            pl_module: The LightningModule instance.
        """
        if trainer.num_devices > 1 and self.write_interval == "epoch":
            logger.warning(
                "Multi-GPU predictions could result in shuffled inputs. Verify that the original indices are included "
                "in the model's predictions as outputs are not ordered and batch indices do not track input order."
            )

    @staticmethod
    def _assert_initialized():
        """Asserts that the environment is initialized."""
        if not (
            torch.distributed.is_available() and torch.distributed.is_initialized() and parallel_state.is_initialized()
        ):
            raise RuntimeError("This function is only defined within an initialized megatron parallel environment.")

    @property
    def data_parallel_world_size(self) -> int:
        """Returns the data parallel world size."""
        self._assert_initialized()
        return torch.distributed.get_world_size(parallel_state.get_data_parallel_group(with_context_parallel=False))

    @property
    def data_parallel_rank(self) -> int:
        """Returns the data parallel rank."""
        self._assert_initialized()
        return torch.distributed.get_rank(parallel_state.get_data_parallel_group(with_context_parallel=False))

    @property
    def should_write_predictions(self) -> bool:
        """Ensures that predictions are only written on TP/CP rank 0 and that it is the last stage of the pipeline."""
        self._assert_initialized()
        if not parallel_state.is_pipeline_last_stage():
            return False
        if self.save_all_model_parallel_ranks:
            return True
        # TODO: handle expert parallelism and other kinds of parallelism
        return parallel_state.get_tensor_model_parallel_rank() == 0 and parallel_state.get_context_parallel_rank() == 0

    @override
    def write_on_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        prediction: Any,
        batch_indices: Sequence[int] | None,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        """Writes predictions to disk at the end of each batch.

        Prediction files follow the naming pattern below, where rank is the global rank of the GPU on which the
        predictions were made:
        predictions__rank_{rank}__dp_rank_{dp_rank}__batch_{batch_idx}.pt

        Args:
            trainer: The Trainer instance.
            pl_module: The LightningModule instance.
            prediction: The prediction made by the model.
            batch_indices: The indices of the batch.
            batch: The batch data.
            batch_idx: The index of the batch.
            dataloader_idx: The index of the dataloader.
        """
        # this will create N (num processes) files in `output_dir`, each containing
        # the predictions of its respective rank
        if self.should_write_predictions:
            if (
                self.files_per_subdir is not None
                and (self.num_files_written * self.data_parallel_world_size) >= self.files_per_subdir
            ):
                self.num_subdirs_written += 1
                self.output_dir = os.path.join(self.base_dir, f"subdir_{self.num_subdirs_written}")
                os.makedirs(self.output_dir, exist_ok=True)
                self.num_files_written = 0
            result_path = os.path.join(
                self.output_dir,
                f"predictions__rank_{trainer.global_rank}__dp_rank_{self.data_parallel_rank}__batch_{batch_idx}.pt",
            )

            # batch_indices is not captured due to a lightning bug when return_predictions = False
            # we use input IDs in the prediction to map the result to input.

            # NOTE: store the batch_idx so we do not need to rely on filenames to reconstruct input order. It is
            # wrapped in a one-element tensor for compatibility with batch_collator.
            prediction["batch_idx"] = torch.tensor([batch_idx], dtype=torch.int64)

            torch.save(prediction, result_path)
            logger.info(f"Inference predictions are stored in {result_path}\n{prediction.keys()}")
            self.num_files_written += 1

    @override
    def write_on_epoch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        predictions: Any,
        batch_indices: Sequence[int],
    ) -> None:
        """Writes predictions to disk at the end of each epoch.

        Writing all predictions on epoch end is memory intensive. It is recommended to use the batch writer instead for
        large predictions.

        Multi-device runs will likely yield predictions in an order that is inconsistent with single-device runs and with the input data.

        Args:
            trainer: The Trainer instance.
            pl_module: The LightningModule instance.
            predictions: The predictions made by the model.
            batch_indices: The indices of the batch.
        """
        # this will create N (num processes) files in `output_dir`, each containing
        # the predictions of its respective rank
        if self.should_write_predictions:
            result_path = os.path.join(
                self.output_dir,
                f"predictions__rank_{trainer.global_rank}__dp_rank_{self.data_parallel_rank}.pt",
            )

            # collate multiple batches / ignore empty ones
            collate_kwargs = {}
            if self.batch_dim_key_defaults is not None:
                collate_kwargs["batch_dim_key_defaults"] = self.batch_dim_key_defaults
            if self.seq_dim_key_defaults is not None:
                collate_kwargs["seq_dim_key_defaults"] = self.seq_dim_key_defaults

            prediction = batch_collator([item for item in predictions if item is not None], **collate_kwargs)

            # batch_indices is not captured due to a lightning bug when return_predictions = False
            # we use input IDs in the prediction to map the result to input
            if isinstance(prediction, dict):
                keys = prediction.keys()
            else:
                keys = "tensor"
            torch.save(prediction, result_path)
            logger.info(f"Inference predictions are stored in {result_path}\n{keys}")

data_parallel_rank property

Returns the data parallel rank.

data_parallel_world_size property

Returns the data parallel world size.

should_write_predictions property

Whether this rank should write predictions: only the last pipeline stage and, unless save_all_model_parallel_ranks is set, only TP/CP rank 0.

__init__(output_dir, write_interval, batch_dim_key_defaults=None, seq_dim_key_defaults=None, save_all_model_parallel_ranks=False, files_per_subdir=None)

Initializes the callback.

Parameters:

Name Type Description Default
output_dir str | PathLike

The directory where predictions will be written.

required
write_interval IntervalT

The interval at which predictions will be written (batch, epoch). Epoch may not be used with multi-device trainers.

required
batch_dim_key_defaults dict[str, int] | None

The default batch dimension for each key, if different from the standard 0.

None
seq_dim_key_defaults dict[str, int] | None

The default sequence dimension for each key, if different from the standard 1.

None
save_all_model_parallel_ranks bool

Whether to save predictions for all model parallel ranks. Generally these will be redundant.

False
files_per_subdir int | None

Number of files to write to each subdirectory. If provided, subdirectories with N files each will be created. Ignored unless write_interval is 'batch'.

None
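The two dimension-override dicts matter when results are collated at epoch end: they tell the collator which axis to concatenate along for a given key. A hedged sketch, assuming a hypothetical "hiddens" output laid out [seq, batch, hidden] rather than the standard [batch, seq, hidden]:

from bionemo.llm.utils.callbacks import PredictionWriter

writer = PredictionWriter(
    output_dir="./predictions",
    write_interval="epoch",
    # hypothetical layout: "hiddens" has its batch axis at dim 1 and its
    # sequence axis at dim 0, so both defaults are overridden for that key
    batch_dim_key_defaults={"hiddens": 1},
    seq_dim_key_defaults={"hiddens": 0},
)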
Source code in bionemo/llm/utils/callbacks.py
def __init__(
    self,
    output_dir: str | os.PathLike,
    write_interval: IntervalT,
    batch_dim_key_defaults: dict[str, int] | None = None,
    seq_dim_key_defaults: dict[str, int] | None = None,
    save_all_model_parallel_ranks: bool = False,
    files_per_subdir: int | None = None,
):
    """Initializes the callback.

    Args:
        output_dir: The directory where predictions will be written.
        write_interval: The interval at which predictions will be written (batch, epoch). Epoch may not be used with
            multi-device trainers.
        batch_dim_key_defaults: The default batch dimension for each key, if different from the standard 0.
        seq_dim_key_defaults: The default sequence dimension for each key, if different from the standard 1.
        save_all_model_parallel_ranks: Whether to save predictions for all model parallel ranks. Generally these
            will be redundant.
        files_per_subdir: Number of files to write to each subdirectory. If provided, subdirectories with N files
            each will be created. Ignored unless write_interval is 'batch'.
    """
    super().__init__(write_interval)
    self.write_interval = write_interval
    self.output_dir = str(output_dir)
    self.base_dir = self.output_dir  # output_dir is redirected to a new subdirectory when files_per_subdir is set
    self.batch_dim_key_defaults = batch_dim_key_defaults
    self.seq_dim_key_defaults = seq_dim_key_defaults
    self.save_all_model_parallel_ranks = save_all_model_parallel_ranks
    self.files_per_subdir = files_per_subdir
    # Initialize to infinity if files_per_subdir is provided so that we create a new subdirectory before writing
    #   any files.
    self.num_files_written = float("inf") if files_per_subdir else 0
    self.num_subdirs_written = 0

setup(trainer, pl_module, *args, **kwargs)

Invoked when Trainer.fit, validate, test, or predict is called. Logs a warning when 'write_interval' is 'epoch' and 'trainer.num_devices' > 1, since multi-device outputs are not ordered.

Parameters:

Name Type Description Default
trainer Trainer

The Trainer instance.

required
pl_module LightningModule

The LightningModule instance.

required
Source code in bionemo/llm/utils/callbacks.py
def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, *args, **kwargs) -> None:  # noqa: D417
    """Invoked with Trainer.fit, validate, test, and predict are called. Will immediately fail when 'write_interval' is 'epoch' and 'trainer.num_devices' > 1.

    Args:
        trainer: The Trainer instance.
        pl_module: The LightningModule instance.
    """
    if trainer.num_devices > 1 and self.write_interval == "epoch":
        logger.warning(
            "Multi-GPU predictions could result in shuffled inputs. Verify that the original indices are included "
            "in the model's predictions as outputs are not ordered and batch indices do not track input order."
        )

write_on_batch_end(trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx)

Writes predictions to disk at the end of each batch.

Prediction files follow the naming pattern predictions__rank_{rank}__dp_rank_{dp_rank}__batch_{batch_idx}.pt, where rank is the global rank of the GPU on which the predictions were made.
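Downstream code can reload the per-batch files and restore input order from the stored batch index rather than by parsing filenames. A sketch, assuming write_interval="batch" and no files_per_subdir (with files_per_subdir set, the glob would also need to descend into the subdir_* directories):

import glob
import os

import torch

output_dir = "./predictions"  # hypothetical; must match the writer's output_dir
pattern = os.path.join(output_dir, "predictions__rank_*__dp_rank_*__batch_*.pt")
preds = [torch.load(f, map_location="cpu") for f in glob.glob(pattern)]
# Each dict stores its own batch index, so sort on that rather than on filenames.
preds.sort(key=lambda p: int(p["batch_idx"].item()))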

Parameters:

Name Type Description Default
trainer Trainer

The Trainer instance.

required
pl_module LightningModule

The LightningModule instance.

required
prediction Any

The prediction made by the model.

required
batch_indices Sequence[int] | None

The indices of the batch.

required
batch Any

The batch data.

required
batch_idx int

The index of the batch.

required
dataloader_idx int

The index of the dataloader.

required
Source code in bionemo/llm/utils/callbacks.py
@override
def write_on_batch_end(
    self,
    trainer: pl.Trainer,
    pl_module: pl.LightningModule,
    prediction: Any,
    batch_indices: Sequence[int] | None,
    batch: Any,
    batch_idx: int,
    dataloader_idx: int,
) -> None:
    """Writes predictions to disk at the end of each batch.

    Prediction files follow the naming pattern below, where rank is the global rank of the GPU on which the
    predictions were made:
    predictions__rank_{rank}__dp_rank_{dp_rank}__batch_{batch_idx}.pt

    Args:
        trainer: The Trainer instance.
        pl_module: The LightningModule instance.
        prediction: The prediction made by the model.
        batch_indices: The indices of the batch.
        batch: The batch data.
        batch_idx: The index of the batch.
        dataloader_idx: The index of the dataloader.
    """
    # this will create N (num processes) files in `output_dir`, each containing
    # the predictions of its respective rank
    if self.should_write_predictions:
        if (
            self.files_per_subdir is not None
            and (self.num_files_written * self.data_parallel_world_size) >= self.files_per_subdir
        ):
            self.num_subdirs_written += 1
            self.output_dir = os.path.join(self.base_dir, f"subdir_{self.num_subdirs_written}")
            os.makedirs(self.output_dir, exist_ok=True)
            self.num_files_written = 0
        result_path = os.path.join(
            self.output_dir,
            f"predictions__rank_{trainer.global_rank}__dp_rank_{self.data_parallel_rank}__batch_{batch_idx}.pt",
        )

        # batch_indices is not captured due to a lightning bug when return_predictions = False
        # we use input IDs in the prediction to map the result to input.

        # NOTE: store the batch_idx so we do not need to rely on filenames to reconstruct input order. It is
        # wrapped in a one-element tensor for compatibility with batch_collator.
        prediction["batch_idx"] = torch.tensor([batch_idx], dtype=torch.int64)

        torch.save(prediction, result_path)
        logger.info(f"Inference predictions are stored in {result_path}\n{prediction.keys()}")
        self.num_files_written += 1

write_on_epoch_end(trainer, pl_module, predictions, batch_indices)

Writes predictions to disk at the end of each epoch.

Writing all predictions on epoch end is memory intensive. It is recommended to use the batch writer instead for large predictions.

Multi-device runs will likely yield predictions in an order that is inconsistent with single-device runs and with the input data.
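At epoch end, each writing rank produces a single collated file, so reloading is a matter of globbing the per-rank files; merging across ranks is left to the caller, since cross-rank ordering is not guaranteed. A sketch, assuming write_interval="epoch" and a hypothetical "./predictions" output directory:

import glob

import torch

# one file per writing rank when write_interval="epoch"
rank_files = sorted(glob.glob("./predictions/predictions__rank_*__dp_rank_*.pt"))
per_rank = [torch.load(f, map_location="cpu") for f in rank_files]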

Parameters:

Name Type Description Default
trainer Trainer

The Trainer instance.

required
pl_module LightningModule

The LightningModule instance.

required
predictions Any

The predictions made by the model.

required
batch_indices Sequence[int]

The indices of the batch.

required
Source code in bionemo/llm/utils/callbacks.py
@override
def write_on_epoch_end(
    self,
    trainer: pl.Trainer,
    pl_module: pl.LightningModule,
    predictions: Any,
    batch_indices: Sequence[int],
) -> None:
    """Writes predictions to disk at the end of each epoch.

    Writing all predictions on epoch end is memory intensive. It is recommended to use the batch writer instead for
    large predictions.

    Multi-device runs will likely yield predictions in an order that is inconsistent with single-device runs and with the input data.

    Args:
        trainer: The Trainer instance.
        pl_module: The LightningModule instance.
        predictions: The predictions made by the model.
        batch_indices: The indices of the batch.
    """
    # this will create N (num processes) files in `output_dir`, each containing
    # the predictions of its respective rank
    if self.should_write_predictions:
        result_path = os.path.join(
            self.output_dir,
            f"predictions__rank_{trainer.global_rank}__dp_rank_{self.data_parallel_rank}.pt",
        )

        # collate multiple batches / ignore empty ones
        collate_kwargs = {}
        if self.batch_dim_key_defaults is not None:
            collate_kwargs["batch_dim_key_defaults"] = self.batch_dim_key_defaults
        if self.seq_dim_key_defaults is not None:
            collate_kwargs["seq_dim_key_defaults"] = self.seq_dim_key_defaults

        prediction = batch_collator([item for item in predictions if item is not None], **collate_kwargs)

        # batch_indices is not captured due to a lightning bug when return_predictions = False
        # we use input IDs in the prediction to map the result to input
        if isinstance(prediction, dict):
            keys = prediction.keys()
        else:
            keys = "tensor"
        torch.save(prediction, result_path)
        logger.info(f"Inference predictions are stored in {result_path}\n{keys}")