Skip to content

Pretrain MNIST

run_pretrain(name, directory_name)

Run the pretraining step.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `name` | `str` | The experiment name. | required |
| `directory_name` | `str` | The directory to write the output to. | required |

Returns: `Path`: the path of the trained model.

Source code in `bionemo/example_model/training_scripts/pretrain_mnist.py` (lines 32–86):
def run_pretrain(name: str, directory_name: str) -> Path:
    """Run the pretraining step.

    Args:
        name: The experiment name.
        directory_name: The directory to write the output to. Logs and
            checkpoints are placed under ``<directory_name>/pretrain``.

    Returns:
        Path: the path of the trained model, i.e. the last checkpoint path
        reported by ``checkpoint_callback`` with a trailing ``.ckpt`` suffix
        removed.
    """
    # Set up the logger, then train the model.
    save_dir = Path(directory_name) / "pretrain"

    nemo_logger = NeMoLogger(
        log_dir=str(save_dir),
        name=name,
        tensorboard=TensorBoardLogger(save_dir=save_dir, name=name),
        ckpt=checkpoint_callback,
        extra_loggers=[CSVLogger(save_dir / "logs", name=name)],
    )

    # Set up the training module
    lightning_module = BionemoLightningModule(config=PretrainConfig())
    strategy = nl.MegatronStrategy(
        tensor_model_parallel_size=1,
        pipeline_model_parallel_size=1,
        ddp="megatron",
        find_unused_parameters=True,
        always_save_context=True,
    )

    trainer = nl.Trainer(
        accelerator="gpu",
        devices=1,
        strategy=strategy,
        limit_val_batches=5,
        val_check_interval=5,
        max_steps=100,
        max_epochs=10,
        num_nodes=1,
        log_every_n_steps=5,
        plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
    )

    # This trains the model
    llm.train(
        model=lightning_module,
        data=data_module,
        trainer=trainer,
        log=nemo_logger,
        resume=resume.AutoResume(
            resume_if_exists=True,  # Looks for the -last checkpoint to continue training.
            resume_ignore_no_checkpoint=True,  # When false this will throw an error with no existing checkpoint.
        ),
    )

    # Strip only a trailing ".ckpt". The previous str.replace(".ckpt", "")
    # also removed ".ckpt" occurring anywhere else in the path (e.g. in a
    # parent directory name), yielding a path that does not exist.
    last_model_path = checkpoint_callback.last_model_path
    ckpt_suffix = ".ckpt"
    if last_model_path.endswith(ckpt_suffix):
        last_model_path = last_model_path[: -len(ckpt_suffix)]
    # NOTE(review): if no checkpoint was written, last_model_path may be empty
    # and this returns Path("") — confirm whether callers should get an error.
    return Path(last_model_path)