Callbacks

GeneformerPredictionWriter

Bases: BasePredictionWriter, Callback

A callback that writes predictions to disk at specified intervals during prediction.

Source code in bionemo/geneformer/utils/callbacks.py
class GeneformerPredictionWriter(BasePredictionWriter, pl.Callback):
    """A callback that writes predictions to disk at specified intervals during training."""

    def __init__(
        self,
        output_dir: str | os.PathLike,
        write_interval: IntervalT,
        tokenizer: GeneTokenizer,
        batch_dim_key_defaults: dict[str, int] | None = None,
        seq_dim_key_defaults: dict[str, int] | None = None,
        include_gene_embeddings: bool = False,
    ):
        """Initializes the callback.

        Args:
            output_dir: The directory where predictions will be written.
            write_interval: The interval at which predictions will be written ("batch" or "epoch").
            tokenizer: The GeneTokenizer instance for mapping input_ids to gene names, and filtering out special tokens.
            batch_dim_key_defaults: The default batch dimension for each key, if different from the standard 0.
            seq_dim_key_defaults: The default sequence dimension for each key, if different from the standard 1.
            include_gene_embeddings: Whether to include gene embeddings in the output predictions.
        """
        super().__init__(write_interval)
        self.output_dir = str(output_dir)
        self.include_gene_embeddings = include_gene_embeddings
        self.batch_dim_key_defaults = batch_dim_key_defaults
        self.seq_dim_key_defaults = seq_dim_key_defaults
        self.tokenizer = tokenizer

    def write_on_epoch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        predictions: Any,
        batch_indices: Sequence[int],
    ) -> None:
        """Writes predictions to disk at the end of each epoch.

        Writing all predictions on epoch end is memory intensive. It is recommended to use the batch writer instead for
        large predictions.

        Multi-device predictions will likely yield predictions in an order that is inconsistent with single device predictions and the input data.

        Args:
            trainer: The Trainer instance.
            pl_module: The LightningModule instance, required by PyTorch Lightning.
            predictions: The predictions made by the model.
            batch_indices: The indices of the batch, required by PyTorch Lightning.

        """
        # this will create N (num processes) files in `output_dir`, each containing
        # the predictions of its respective rank

        result_path = os.path.join(self.output_dir, f"predictions__rank_{trainer.global_rank}.pt")

        # collate multiple batches / ignore empty ones
        collate_kwargs = {}
        if self.batch_dim_key_defaults is not None:
            collate_kwargs["batch_dim_key_defaults"] = self.batch_dim_key_defaults
        if self.seq_dim_key_defaults is not None:
            collate_kwargs["seq_dim_key_defaults"] = self.seq_dim_key_defaults

        prediction = batch_collator([item for item in predictions if item is not None], **collate_kwargs)

        # batch_indices is not captured due to a lightning bug when return_predictions = False
        # we use input IDs in the prediction to map the result to input

        if self.include_gene_embeddings and "input_ids" in prediction and "hidden_states" in prediction:
            hidden_states = prediction["hidden_states"]
            input_ids = prediction["input_ids"]

            logging.info("Calculating gene embeddings.")
            logging.info(f"hidden_states: {hidden_states.shape[:2]}; input_ids: {input_ids.shape[:2]}")
            assert hidden_states.shape[:2] == input_ids.shape[:2]

            # accumulators for calculating mean embedding for each input_id
            gene_embedding_accumulator = {}
            input_id_count = {}
            ensembl_IDs = {}
            gene_symbols = {}

            # iterate over all cells
            cell_count = len(input_ids)
            for i in range(cell_count):
                cell_state = hidden_states[i]
                cell_input_ids = input_ids[i].cpu().numpy()

                # iterate over each gene in the cell
                for idx, embedding in zip(cell_input_ids, cell_state):
                    # skip calculation for special tokens like [CLS], [SEP], [PAD], [MASK], [UKW]
                    if idx in self.tokenizer.all_special_ids:
                        continue

                    # accumulate embedding sum and count
                    if idx not in gene_embedding_accumulator:
                        # initialize embedding sum with first found embedding
                        gene_embedding_accumulator[idx] = embedding

                        # increment input_id count
                        input_id_count[idx] = 1
                    else:
                        # accumulate embedding sum
                        gene_embedding_accumulator[idx] += embedding

                        # increment input_id count
                        input_id_count[idx] += 1

            # divide each embedding sum by the total occurrences of each gene to get an average
            for input_id in gene_embedding_accumulator.keys():
                gene_embedding_accumulator[input_id] /= input_id_count[input_id]
                # map input_ids to gene symbols and ensembl IDs
                ensembl_IDs[input_id] = self.tokenizer.decode_vocab.get(input_id, f"UNKNOWN_{input_id}")
                gene_symbols[input_id] = self.tokenizer.ens_to_gene.get(ensembl_IDs[input_id], f"UNKNOWN_{input_id}")

            logging.info(f"Number of unique gene embeddings: {len(gene_embedding_accumulator)}")
            logging.info("Finished calculating gene embeddings.")

            prediction["gene_embeddings"] = gene_embedding_accumulator
            prediction["gene_counts"] = input_id_count
            prediction["ensembl_IDs"] = ensembl_IDs
            prediction["gene_symbols"] = gene_symbols

        torch.save(prediction, result_path)
        if isinstance(prediction, dict):
            keys = prediction.keys()
        else:
            keys = "tensor"
        logging.info(f"Inference predictions are stored in {result_path}\n{keys}")

__init__(output_dir, write_interval, tokenizer, batch_dim_key_defaults=None, seq_dim_key_defaults=None, include_gene_embeddings=False)

Initializes the callback.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `output_dir` | `str \| PathLike` | The directory where predictions will be written. | required |
| `write_interval` | `IntervalT` | The interval at which predictions will be written ("batch" or "epoch"). | required |
| `tokenizer` | `GeneTokenizer` | The GeneTokenizer instance for mapping input_ids to gene names and filtering out special tokens. | required |
| `batch_dim_key_defaults` | `dict[str, int] \| None` | The default batch dimension for each key, if different from the standard 0. | `None` |
| `seq_dim_key_defaults` | `dict[str, int] \| None` | The default sequence dimension for each key, if different from the standard 1. | `None` |
| `include_gene_embeddings` | `bool` | Whether to include gene embeddings in the output predictions. | `False` |
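
The two dimension-default mappings are forwarded to `batch_collator` when batches are stitched together, and only matter for prediction keys whose tensors are not laid out as `[batch, seq, ...]`. A hedged sketch (the `token_logits` key and its layout are hypothetical, for illustration only):

```python
# Hypothetical: tell the collator that the "token_logits" tensor is
# sequence-first, i.e. [seq, batch, vocab] rather than [batch, seq, vocab],
# so its batch dimension is index 1 and its sequence dimension is index 0.
writer = GeneformerPredictionWriter(
    output_dir="results/",
    write_interval="epoch",
    tokenizer=tokenizer,  # placeholder for a loaded GeneTokenizer
    batch_dim_key_defaults={"token_logits": 1},
    seq_dim_key_defaults={"token_logits": 0},
)
```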

write_on_epoch_end(trainer, pl_module, predictions, batch_indices)

Writes predictions to disk at the end of each epoch.

Writing all predictions on epoch end is memory intensive. It is recommended to use the batch writer instead for large predictions.

Multi-device predictions will likely yield predictions in an order that is inconsistent with single device predictions and the input data.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `trainer` | `Trainer` | The Trainer instance. | required |
| `pl_module` | `LightningModule` | The LightningModule instance, required by PyTorch Lightning. | required |
| `predictions` | `Any` | The predictions made by the model. | required |
| `batch_indices` | `Sequence[int]` | The indices of the batch, required by PyTorch Lightning. | required |
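
Each rank writes its own `predictions__rank_<rank>.pt` file, so downstream code usually globs for all of them. A minimal loading sketch, assuming the `results/` output directory used above (the key names match what this callback writes):

```python
import glob

import torch

for path in sorted(glob.glob("results/predictions__rank_*.pt")):
    result = torch.load(path)  # on newer torch you may need weights_only=False

    # The callback saves either a collated dict of tensors or a bare tensor.
    if not isinstance(result, dict):
        print(path, tuple(result.shape))
        continue
    print(path, list(result.keys()))

    if "gene_embeddings" in result:
        # Present when include_gene_embeddings=True: per-gene mean embeddings
        # plus the lookup tables built alongside them.
        embeddings = result["gene_embeddings"]  # {input_id: mean embedding tensor}
        symbols = result["gene_symbols"]        # {input_id: gene symbol or UNKNOWN_<id>}
        counts = result["gene_counts"]          # {input_id: number of occurrences}
        for input_id in list(embeddings)[:5]:
            print(symbols[input_id], counts[input_id], tuple(embeddings[input_id].shape))
```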