Skip to content

Callbacks

PredictionWriter

Bases: BasePredictionWriter, Callback

A callback that writes inference predictions to disk at specified intervals (per batch or per epoch).

Source code in bionemo/llm/utils/callbacks.py
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
class PredictionWriter(BasePredictionWriter, pl.Callback):
    """A callback that writes inference predictions to disk at the configured interval.

    Each rank writes its own ``.pt`` file(s) into ``output_dir`` (one per batch when
    ``write_interval="batch"``, one per rank when ``write_interval="epoch"``), so
    downstream consumers must gather the per-rank files themselves.
    """

    def __init__(self, output_dir: str | os.PathLike, write_interval: IntervalT):
        """Initializes the callback.

        Args:
            output_dir: The directory where predictions will be written.
            write_interval: The interval at which predictions will be written. (batch, epoch)

        """
        super().__init__(write_interval)
        self.output_dir = str(output_dir)

    def write_on_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        prediction: Any,
        batch_indices: Sequence[int],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        """Writes predictions to disk at the end of each batch.

        Args:
            trainer: The Trainer instance.
            pl_module: The LightningModule instance.
            prediction: The prediction made by the model.
            batch_indices: The indices of the batch.
            batch: The batch data.
            batch_idx: The index of the batch.
            dataloader_idx: The index of the dataloader.
        """
        # A rank may produce no prediction for a batch (write_on_epoch_end already
        # filters out None items, showing this happens); without this guard the
        # `prediction.keys()` call below raises AttributeError on None.
        if prediction is None:
            logging.info(f"No prediction to write for batch {batch_idx}; skipping.")
            return

        # this will create N (num processes) files in `output_dir` each containing
        # the predictions of its respective rank
        result_path = os.path.join(self.output_dir, f"predictions__rank_{trainer.global_rank}__batch_{batch_idx}.pt")

        # batch_indices is not captured due to a lightning bug when return_predictions = False
        # we use input IDs in the prediction to map the result to input
        torch.save(prediction, result_path)
        logging.info(f"Inference predictions are stored in {result_path}\n{prediction.keys()}")

    def write_on_epoch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        predictions: Any,
        batch_indices: Sequence[int],
    ) -> None:
        """Writes predictions to disk at the end of each epoch.

        Args:
            trainer: The Trainer instance.
            pl_module: The LightningModule instance.
            predictions: The predictions made by the model.
            batch_indices: The indices of the batch.
        """
        # this will create N (num processes) files in `output_dir` each containing
        # the predictions of its respective rank
        result_path = os.path.join(self.output_dir, f"predictions__rank_{trainer.global_rank}.pt")

        # collate multiple batches / ignore empty ones
        non_empty = [item for item in predictions if item is not None]
        if not non_empty:
            # Avoid calling batch_collator on an empty list (and `prediction.keys()`
            # on whatever it would return) when this rank produced no predictions.
            logging.info(f"No predictions to write for rank {trainer.global_rank}; skipping.")
            return
        prediction = batch_collator(non_empty)

        # batch_indices is not captured due to a lightning bug when return_predictions = False
        # we use input IDs in the prediction to map the result to input
        torch.save(prediction, result_path)
        logging.info(f"Inference predictions are stored in {result_path}\n{prediction.keys()}")

__init__(output_dir, write_interval)

Initializes the callback.

Parameters:

Name Type Description Default
output_dir str | PathLike

The directory where predictions will be written.

required
write_interval IntervalT

The interval at which predictions will be written. (batch, epoch)

required
Source code in bionemo/llm/utils/callbacks.py
34
35
36
37
38
39
40
41
42
43
def __init__(self, output_dir: str | os.PathLike, write_interval: IntervalT):
    """Set up the prediction writer.

    Args:
        output_dir: Directory that prediction files are written into.
        write_interval: When predictions get flushed to disk ("batch" or "epoch").
    """
    super().__init__(write_interval)
    # Normalize PathLike inputs so downstream os.path.join calls get a plain str.
    self.output_dir: str = str(output_dir)

write_on_batch_end(trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx)

Writes predictions to disk at the end of each batch.

Parameters:

Name Type Description Default
trainer Trainer

The Trainer instance.

required
pl_module LightningModule

The LightningModule instance.

required
prediction Any

The prediction made by the model.

required
batch_indices Sequence[int]

The indices of the batch.

required
batch Any

The batch data.

required
batch_idx int

The index of the batch.

required
dataloader_idx int

The index of the dataloader.

required
Source code in bionemo/llm/utils/callbacks.py
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
def write_on_batch_end(
    self,
    trainer: pl.Trainer,
    pl_module: pl.LightningModule,
    prediction: Any,
    batch_indices: Sequence[int],
    batch: Any,
    batch_idx: int,
    dataloader_idx: int,
) -> None:
    """Persist one batch worth of predictions to disk.

    Args:
        trainer: The Trainer instance.
        pl_module: The LightningModule instance.
        prediction: The prediction made by the model.
        batch_indices: The indices of the batch.
        batch: The batch data.
        batch_idx: The index of the batch.
        dataloader_idx: The index of the dataloader.
    """
    # Every process writes its own file, so `output_dir` ends up holding one
    # file per (rank, batch) pair.
    file_name = f"predictions__rank_{trainer.global_rank}__batch_{batch_idx}.pt"
    result_path = os.path.join(self.output_dir, file_name)

    # Lightning does not deliver `batch_indices` when return_predictions=False,
    # so consumers rely on the input IDs carried inside the prediction to map
    # each result back to its input.
    torch.save(prediction, result_path)
    logging.info(f"Inference predictions are stored in {result_path}\n{prediction.keys()}")

write_on_epoch_end(trainer, pl_module, predictions, batch_indices)

Writes predictions to disk at the end of each epoch.

Parameters:

Name Type Description Default
trainer Trainer

The Trainer instance.

required
pl_module LightningModule

The LightningModule instance.

required
predictions Any

The predictions made by the model.

required
batch_indices Sequence[int]

The indices of the batch.

required
Source code in bionemo/llm/utils/callbacks.py
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
def write_on_epoch_end(
    self,
    trainer: pl.Trainer,
    pl_module: pl.LightningModule,
    predictions: Any,
    batch_indices: Sequence[int],
) -> None:
    """Collate all batch predictions from the epoch and persist them to disk.

    Args:
        trainer: The Trainer instance.
        pl_module: The LightningModule instance.
        predictions: The predictions made by the model.
        batch_indices: The indices of the batch.
    """
    # One output file per process: `output_dir` ends up with N files, one per rank.
    result_path = os.path.join(self.output_dir, f"predictions__rank_{trainer.global_rank}.pt")

    # Drop empty batches, then merge the remainder into a single collated structure.
    non_empty = [batch_prediction for batch_prediction in predictions if batch_prediction is not None]
    prediction = batch_collator(non_empty)

    # Lightning does not deliver `batch_indices` when return_predictions=False;
    # the input IDs inside the prediction map each result back to its input.
    torch.save(prediction, result_path)
    logging.info(f"Inference predictions are stored in {result_path}\n{prediction.keys()}")