Skip to content

Finetune MNIST

run_finetune(checkpoint_dir, name, directory_name)

Run the finetuning step.

Parameters:

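| Name | Type | Description | Default |
|------|------|-------------|---------|
| `checkpoint_dir` | `str` | The directory with the previous model. | *required* |
| `name` | `str` | The experiment name. | *required* |
| `directory_name` | `str` | The directory to write the output to. | *required* |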

Returns: `pathlib.Path`: the path of the trained model.

Source code in `bionemo/example_model/training_scripts/finetune_mnist.py` (lines 33–97):
def run_finetune(checkpoint_dir: str, name: str, directory_name: str) -> Path:
    """Run the finetuning step.

    Builds a ``BionemoLightningModule`` initialised from ``checkpoint_dir``. When the
    initial checkpoint is loaded, keys with the prefix ``digit_classifier`` are skipped.
    The module is trained with ``llm.train`` on the module-level ``data_module``.
    Checkpoints and TensorBoard/CSV logs are written under
    ``<directory_name>/classifier``.

    Args:
        checkpoint_dir: The directory with the previous model.
        name: The experiment name.
        directory_name: The directory to write the output to.

    Returns:
        Path: the path of the trained model. This is the checkpoint callback's
        ``last_model_path`` with a trailing ``.ckpt`` suffix (if any) removed.
    """
    save_dir = Path(directory_name) / "classifier"
    checkpoint_callback = nl_callbacks.ModelCheckpoint(
        save_last=True,
        save_on_train_epoch_end=True,
        monitor="val_loss",
        always_save_context=True,  # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
    )

    nemo_logger = NeMoLogger(
        log_dir=str(save_dir),
        name=name,
        tensorboard=TensorBoardLogger(save_dir=save_dir, name=name),
        ckpt=checkpoint_callback,
        extra_loggers=[CSVLogger(save_dir / "logs", name=name)],
    )

    lightning_module = BionemoLightningModule(
        config=ExampleFineTuneConfig(
            initial_ckpt_path=checkpoint_dir,
            # Weights under this prefix are not taken from the initial checkpoint.
            initial_ckpt_skip_keys_with_these_prefixes={"digit_classifier"},
        )
    )

    strategy = nl.MegatronStrategy(
        tensor_model_parallel_size=1,
        pipeline_model_parallel_size=1,
        ddp="megatron",
        # NOTE(review): presumably needed because some parameters get no gradient
        # during finetuning — confirm.
        find_unused_parameters=True,
        always_save_context=True,
    )

    trainer = nl.Trainer(
        accelerator="gpu",
        devices=1,
        strategy=strategy,
        limit_val_batches=5,
        val_check_interval=5,
        max_steps=100,
        max_epochs=10,
        num_nodes=1,
        log_every_n_steps=5,
        plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
    )
    llm.train(
        model=lightning_module,
        data=data_module,  # module-level data module defined elsewhere in this file
        trainer=trainer,
        log=nemo_logger,
        resume=resume.AutoResume(
            resume_if_exists=True,
            resume_ignore_no_checkpoint=True,
        ),
    )

    last_ckpt = checkpoint_callback.last_model_path
    # Strip only a trailing ".ckpt". The previous str.replace(".ckpt", "") also
    # removed any ".ckpt" occurring earlier in the path (e.g. inside a parent
    # directory name), which yields a path that does not exist.
    suffix = ".ckpt"
    if last_ckpt.endswith(suffix):
        last_ckpt = last_ckpt[: -len(suffix)]
    return Path(last_ckpt)