Skip to content

Infer geneformer

geneformer_infer_entrypoint()

Entrypoint for running inference on a geneformer checkpoint and data.

Source code in bionemo/geneformer/scripts/infer_geneformer.py
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
def geneformer_infer_entrypoint():
    """Parse command-line options and run Geneformer inference with them.

    Every option comes from ``get_parser()``. Embeddings are requested unless the
    ``--no-embeddings`` flag is given. Tensor/pipeline model-parallel sizes are not
    exposed on the CLI, so ``infer_model``'s defaults are used for those.
    """
    cli_args = get_parser().parse_args()

    # Translate CLI attribute names into infer_model keyword arguments. Note the
    # inverted flag: embeddings are on by default and switched off via --no-embeddings.
    infer_kwargs = dict(
        data_path=cli_args.data_dir,
        checkpoint_path=cli_args.checkpoint_path,
        results_path=cli_args.result_path,
        include_hiddens=cli_args.include_hiddens,
        micro_batch_size=cli_args.micro_batch_size,
        include_embeddings=not cli_args.no_embeddings,
        include_logits=cli_args.include_logits,
        seq_length=cli_args.seq_length,
        precision=cli_args.precision,
        devices=cli_args.num_gpus,
        num_nodes=cli_args.num_nodes,
        num_dataset_workers=cli_args.num_dataset_workers,
        config_class=cli_args.config_class,
    )
    infer_model(**infer_kwargs)

get_parser()

Return the CLI argument parser for this tool.

Source code in bionemo/geneformer/scripts/infer_geneformer.py
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
def get_parser():
    """Return the CLI argument parser for this tool.

    The parser defines the options consumed by ``geneformer_infer_entrypoint``: input
    data and checkpoint locations, which outputs to produce, precision, device/node and
    dataset-worker counts, sequence length, micro-batch size, and the model config class.

    Returns:
        argparse.ArgumentParser: The configured parser. Its ``--config-class`` option is
        converted from a name to the corresponding config class (the string default is
        converted too, since argparse applies ``type`` to string defaults).
    """
    parser = argparse.ArgumentParser(
        description="Infer sc_memmap processed single cell data with Geneformer from a checkpoint."
    )
    parser.add_argument(
        "--data-dir",
        type=Path,
        required=True,
        help="Path to the data directory, for example this might be "
        "/workspace/bionemo2/data/cellxgene_2023-12-15_small/processed_train",
    )
    parser.add_argument(
        "--checkpoint-path",
        type=Path,
        required=False,
        default=None,
        help="Path to the checkpoint directory to restore from.",
    )
    parser.add_argument(
        "--precision",
        type=str,
        choices=get_args(PrecisionTypes),
        required=False,
        default="bf16-mixed",
        help="Precision type to use for inference.",
    )
    parser.add_argument("--include-hiddens", action="store_true", default=False, help="Include hiddens in output.")
    parser.add_argument("--no-embeddings", action="store_true", default=False, help="Do not output embeddings.")
    parser.add_argument(
        "--include-logits", action="store_true", default=False, help="Include per-token logits in output."
    )

    parser.add_argument(
        "--result-path", type=Path, required=False, default=Path("./results.pt"), help="Path to the result file."
    )
    parser.add_argument(
        "--num-gpus",
        type=int,
        required=False,
        default=1,
        help="Number of GPUs to use for inference. Default is 1.",
    )
    parser.add_argument(
        "--num-nodes",
        type=int,
        required=False,
        default=1,
        help="Number of nodes to use for inference. Default is 1.",
    )
    parser.add_argument(
        "--num-dataset-workers",
        type=int,
        required=False,
        default=0,
        # Previously mis-documented as "Number of steps to use for training".
        help="Number of workers used to load the dataset. Default is 0.",
    )
    parser.add_argument(
        "--seq-length",
        type=int,
        required=False,
        default=2048,
        help="Sequence length of cell. Default is 2048.",
    )
    parser.add_argument(
        "--micro-batch-size",
        type=int,
        required=False,
        default=32,
        help="Micro-batch size. Global batch size is inferred from this.",
    )

    # TODO consider whether nemo.run or some other method can simplify this config class lookup.
    config_class_options: Dict[str, Type[BioBertConfig]] = {
        "GeneformerConfig": GeneformerConfig,
        "FineTuneSeqLenBioBertConfig": FineTuneSeqLenBioBertConfig,
    }
    # Human-readable list of valid names (avoids printing a raw ``dict_keys([...])`` repr).
    valid_config_names = ", ".join(config_class_options)

    def config_class_type(desc: str) -> Type[BioBertConfig]:
        """Map a config-class name given on the CLI to the config class itself."""
        try:
            return config_class_options[desc]
        except KeyError:
            # argparse turns ArgumentTypeError into a usage error; the KeyError context adds nothing.
            raise argparse.ArgumentTypeError(
                f"Do not recognize key {desc}, valid options are: {valid_config_names}"
            ) from None

    parser.add_argument(
        "--config-class",
        type=config_class_type,
        default="GeneformerConfig",
        help="Model configs link model classes with losses, and handle model initialization (including from a prior "
        "checkpoint). This is how you can fine-tune a model. First train with one config class that points to one model "
        "class and loss, then implement and provide an alternative config class that points to a variant of that model "
        "and alternative loss. In the future this script should also provide similar support for picking different data "
        f"modules for fine-tuning with different data types. Choices: {valid_config_names}",
    )
    return parser

infer_model(data_path, checkpoint_path, results_path, include_hiddens=False, include_embeddings=False, include_logits=False, seq_length=2048, micro_batch_size=64, precision='bf16-mixed', tensor_model_parallel_size=1, pipeline_model_parallel_size=1, devices=1, num_nodes=1, num_dataset_workers=0, config_class=GeneformerConfig)

Inference function (requires DDP and only training data that fits in memory).

Source code in bionemo/geneformer/scripts/infer_geneformer.py
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
def infer_model(
    data_path: Path,
    checkpoint_path: Path,
    results_path: Path,
    include_hiddens: bool = False,
    include_embeddings: bool = False,
    include_logits: bool = False,
    seq_length: int = 2048,
    micro_batch_size: int = 64,
    precision: PrecisionTypes = "bf16-mixed",
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    devices: int = 1,
    num_nodes: int = 1,
    num_dataset_workers: int = 0,
    config_class: Type[BioBertConfig] = GeneformerConfig,
) -> None:
    """Inference function (requires DDP and only training data that fits in memory)."""
    # This is just used to get the tokenizer :(
    train_data_path: Path = (
        load("single_cell/testdata-20240506") / "cellxgene_2023-12-15_small" / "processed_data" / "train"
    )

    # Setup the strategy and trainer
    pipeline_model_parallel_size = 1
    tensor_model_parallel_size = 1
    accumulate_grad_batches = 1
    global_batch_size = infer_global_batch_size(
        micro_batch_size=micro_batch_size,
        num_nodes=num_nodes,
        devices=devices,
        accumulate_grad_batches=accumulate_grad_batches,
        tensor_model_parallel_size=tensor_model_parallel_size,
        pipeline_model_parallel_size=pipeline_model_parallel_size,
    )

    preprocessor = GeneformerPreprocess(
        download_directory=train_data_path,
        medians_file_path=train_data_path / "medians.json",
        tokenizer_vocab_path=train_data_path / "geneformer.vocab",
    )
    match preprocessor.preprocess():
        case {"tokenizer": tokenizer, "median_dict": median_dict}:
            logging.info("*************** Preprocessing Finished ************")
        case _:
            logging.error("Preprocessing failed.")

    strategy = nl.MegatronStrategy(
        tensor_model_parallel_size=tensor_model_parallel_size,
        pipeline_model_parallel_size=pipeline_model_parallel_size,
        ddp="megatron",
        find_unused_parameters=True,
        ckpt_include_optimizer=True,
        progress_interval=1,
    )
    trainer = nl.Trainer(
        devices=devices,
        accelerator="gpu",
        strategy=strategy,
        num_nodes=num_nodes,
        callbacks=[],
        plugins=nl.MegatronMixedPrecision(precision=precision),
    )
    # Configure the data module and model
    data = SingleCellDataModule(
        seq_length=seq_length,
        tokenizer=tokenizer,
        train_dataset_path=None,
        val_dataset_path=None,
        test_dataset_path=None,
        predict_dataset_path=data_path,
        mask_prob=0,
        mask_token_prob=0,
        random_token_prob=0,  # changed to represent the incorrect setting we originally used.
        median_dict=median_dict,
        micro_batch_size=micro_batch_size,
        global_batch_size=global_batch_size,
        # persistent workers is supported when num_dataset_workers > 0
        persistent_workers=num_dataset_workers > 0,
        pin_memory=False,
        num_workers=num_dataset_workers,
    )
    geneformer_config = config_class(
        seq_length=seq_length,
        params_dtype=get_autocast_dtype(precision),
        pipeline_dtype=get_autocast_dtype(precision),
        autocast_dtype=get_autocast_dtype(precision),  # setting this speeds things up a lot
        # handle checkpoint resumption here rather than auto-resume so this supports fine-tuning capabilities
        initial_ckpt_path=str(checkpoint_path) if checkpoint_path is not None else None,
        include_embeddings=include_embeddings,
        include_hiddens=include_hiddens,
        skip_logits=not include_logits,
        initial_ckpt_skip_keys_with_these_prefixes=[],  # load everything from the checkpoint.
    )
    # The lightning class owns a copy of the actual model, and a loss function, both of which are configured
    #  and lazily returned by the `geneformer_config` object defined above.
    model = biobert_lightning_module(
        geneformer_config,
        tokenizer=tokenizer,
    )

    results_dict = batch_collator(trainer.predict(model, datamodule=data, return_predictions=True))
    non_none_keys = [key for key, val in results_dict.items() if val is not None]
    print(f"Writing output {str(non_none_keys)} into {results_path}")
    torch.save(results_dict, results_path)