
Train amplify

main(num_nodes=1, devices=1, min_seq_length=512, max_seq_length=512, result_dir=Path('./results'), num_steps=1000000, warmup_steps=1000, decay_steps=900000, limit_val_batches=1.0, val_check_interval=10000, log_every_n_steps=100, num_dataset_workers=27, biobert_spec_option=BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec, lr=0.001, micro_batch_size=64, accumulate_grad_batches=1, experiment_name='amplify', resume_if_exists=False, precision='bf16-mixed', wandb_entity=None, wandb_project=None, wandb_offline=False, wandb_tags=None, wandb_group=None, wandb_id=None, wandb_anonymous=False, wandb_log_model=False, pipeline_model_parallel_size=1, tensor_model_parallel_size=1, create_tensorboard_logger=False, nemo1_init_path=None, restore_from_checkpoint_path=None, save_last_checkpoint=True, metric_to_monitor_for_checkpoints='val_loss', save_top_k=2, nsys_profiling=False, nsys_start_step=0, nsys_end_step=None, nsys_ranks=[0], random_mask_strategy=RandomMaskStrategy.ALL_TOKENS, num_layers=24, hidden_size=640, num_attention_heads=10, ffn_hidden_size=2560, no_overlap_grad_reduce=False, overlap_param_gather=False, no_average_in_collective=False, grad_reduce_in_fp32=False)

Train an AMPLIFY model on UR100P data.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `num_nodes` | `int` | Number of nodes to run on | `1` |
| `devices` | `int` | Number of devices | `1` |
| `min_seq_length` | `Optional[int]` | Minimum length to pad sequences to; if `None`, no extra padding is added | `512` |
| `max_seq_length` | `int` | The maximum sequence length for the AMPLIFY transformer | `512` |
| `result_dir` | `Path` | Directory to store results, logs, and checkpoints | `Path('./results')` |
| `num_steps` | `int` | Number of steps to train the model for | `1000000` |
| `warmup_steps` | `int` | Number of steps for the learning rate warmup phase | `1000` |
| `decay_steps` | `int` | Number of steps for the learning rate decay phase | `900000` |
| `limit_val_batches` | `float` | Limit the number of validation global batches to this many | `1.0` |
| `val_check_interval` | `int` | Number of steps between periodic validation-loss checks and checkpoint saves | `10000` |
| `log_every_n_steps` | `Optional[int]` | Logging frequency, in steps | `100` |
| `num_dataset_workers` | `int` | Number of dataset workers | `27` |
| `biobert_spec_option` | `BiobertSpecOption` | The biobert spec option (architecture) to use for this run | `esm2_bert_layer_with_transformer_engine_spec` |
| `lr` | `float` | Learning rate | `0.001` |
| `micro_batch_size` | `int` | Micro batch size; the global batch size is inferred from this and the parallelism settings | `64` |
| `accumulate_grad_batches` | `int` | Number of batches to accumulate before performing a gradient update | `1` |
| `experiment_name` | `str` | Experiment name; used for the wandb run name and as the sub-directory of `result_dir` that stores logs and checkpoints | `'amplify'` |
| `resume_if_exists` | `bool` | Attempt to resume if the checkpoint exists [FIXME @skothenhill this doesn't work yet] | `False` |
| `precision` | `PrecisionTypes` | Precision to use for training (bf16-mixed, 16-mixed, 32) | `'bf16-mixed'` |
| `wandb_entity` | `str` | The team posting this run (default: your username or your default team) | `None` |
| `wandb_project` | `str` | The name of the project to which this run will belong | `None` |
| `wandb_tags` | `List[str]` | Tags associated with this run | `None` |
| `wandb_group` | `str` | A unique string shared by all runs in a given group | `None` |
| `wandb_offline` | `bool` | Run offline (data can be streamed later to wandb servers) | `False` |
| `wandb_id` | `str` | Sets the version, mainly used to resume a previous run | `None` |
| `wandb_anonymous` | `bool` | Enables or explicitly disables anonymous logging | `False` |
| `wandb_log_model` | `bool` | Save checkpoints in the wandb dir to upload to W&B servers | `False` |
| `pipeline_model_parallel_size` | `int` | Degree of pipeline model parallelism | `1` |
| `tensor_model_parallel_size` | `int` | Degree of tensor model parallelism | `1` |
| `create_tensorboard_logger` | `bool` | Create the tensorboard logger | `False` |
| `nemo1_init_path` | `Optional[Path]` | Path to a NeMo v1 checkpoint to initialize from | `None` |
| `restore_from_checkpoint_path` | `Optional[str]` | If set, restores the model from the given directory; expects a checkpoint created with the `ModelCheckpoint` class and `always_save_context=True` | `None` |
| `save_last_checkpoint` | `bool` | Whether to save the last checkpoint | `True` |
| `metric_to_monitor_for_checkpoints` | `str` | Metric to monitor for checkpoints | `'val_loss'` |
| `save_top_k` | `int` | Number of top checkpoints to save | `2` |
| `nsys_profiling` | `bool` | Whether to enable nsys profiling | `False` |
| `nsys_start_step` | `int` | Start step for nsys profiling | `0` |
| `nsys_end_step` | `Optional[int]` | End step for nsys profiling | `None` |
| `nsys_ranks` | `List[int]` | Ranks for nsys profiling | `[0]` |
| `random_mask_strategy` | `RandomMaskStrategy` | Random masking strategy for masked-token pretraining | `ALL_TOKENS` |
| `num_layers` | `int` | Number of layers | `24` |
| `hidden_size` | `int` | Hidden size | `640` |
| `num_attention_heads` | `int` | Number of attention heads | `10` |
| `ffn_hidden_size` | `int` | Feed-forward hidden size | `2560` |
| `no_overlap_grad_reduce` | `bool` | Disable overlapped gradient reduction | `False` |
| `overlap_param_gather` | `bool` | Overlap parameter gather | `False` |
| `no_average_in_collective` | `bool` | Disable averaging in the collective | `False` |
| `grad_reduce_in_fp32` | `bool` | Perform gradient reduction in fp32 | `False` |
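
Because `main` is exposed both as a CLI command (via `@app.command()`) and as a regular Python function, it can also be invoked programmatically. A minimal sketch, assuming the module path `bionemo.amplify.train_amplify` matches the source file below; the small step counts and experiment name are illustrative values for a short smoke-test run, not recommendations:

from pathlib import Path

from bionemo.amplify.train_amplify import main

# Hypothetical smoke-test run: every keyword below comes from the signature
# above; only the values are changed to keep the run short.
trainer = main(
    num_steps=100,
    warmup_steps=10,
    decay_steps=90,
    val_check_interval=50,
    micro_batch_size=8,
    result_dir=Path("./results"),
    experiment_name="amplify_smoke_test",
)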
Source code in bionemo/amplify/train_amplify.py
@app.command()
def main(
    num_nodes: int = 1,
    devices: int = 1,
    min_seq_length: Optional[int] = 512,
    max_seq_length: int = 512,
    result_dir: Path = Path("./results"),
    num_steps: int = 1_000_000,
    warmup_steps: int = 1000,
    decay_steps: int = 900_000,
    limit_val_batches: float = 1.0,
    val_check_interval: int = 10000,
    log_every_n_steps: Optional[int] = 100,
    num_dataset_workers: int = 27,
    biobert_spec_option: BiobertSpecOption = BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec,
    lr: float = 1e-3,
    micro_batch_size: int = 64,
    accumulate_grad_batches: int = 1,
    experiment_name: str = "amplify",
    resume_if_exists: bool = False,
    precision: str = "bf16-mixed",
    wandb_entity: Optional[str] = None,
    wandb_project: Optional[str] = None,
    wandb_offline: bool = False,
    wandb_tags: Optional[List[str]] = None,
    wandb_group: Optional[str] = None,
    wandb_id: Optional[str] = None,
    wandb_anonymous: bool = False,
    wandb_log_model: bool = False,
    pipeline_model_parallel_size: int = 1,
    tensor_model_parallel_size: int = 1,
    create_tensorboard_logger: bool = False,
    nemo1_init_path: Optional[Path] = None,
    restore_from_checkpoint_path: Optional[str] = None,
    save_last_checkpoint: bool = True,
    metric_to_monitor_for_checkpoints: str = "val_loss",
    save_top_k: int = 2,
    nsys_profiling: bool = False,
    nsys_start_step: int = 0,
    nsys_end_step: Optional[int] = None,
    nsys_ranks: List[int] = [0],
    random_mask_strategy: RandomMaskStrategy = RandomMaskStrategy.ALL_TOKENS,
    num_layers: int = 24,
    hidden_size: int = 640,
    num_attention_heads: int = 10,
    ffn_hidden_size: int = 2560,
    no_overlap_grad_reduce: bool = False,
    overlap_param_gather: bool = False,
    no_average_in_collective: bool = False,
    grad_reduce_in_fp32: bool = False,
) -> nl.Trainer:
    """Train an AMPLIFY model on UR100P data.

    Args:
        num_nodes (int): Number of nodes to run on
        devices (int): number of devices
        min_seq_length (Optional[int]): Minimum length to pad sequences to. If None, no extra padding is added
        max_seq_length (int): The maximum sequence length for the AMPLIFY transformer
        result_dir (Path): directory to store results, logs and checkpoints
        num_steps (int): number of steps to train the model for
        warmup_steps (int): number of steps for the learning rate warmup phase
        decay_steps (int): number of steps for the learning rate decay phase
        limit_val_batches (float): limit the number of validation global batches to this many
        val_check_interval (int): number of steps to periodically check the validation loss and save
        log_every_n_steps (Optional[int]): frequency for logging (steps)
        num_dataset_workers (int): number of dataset workers
        biobert_spec_option (BiobertSpecOption): the biobert spec option (architecture) to use for this run
        lr (float): learning rate
        micro_batch_size (int): micro batch size, from this and parallelism settings we infer the global batch size
        accumulate_grad_batches (int): number of batches to accumulate before performing a gradient update
        experiment_name (str): experiment name, this is the name used for the wandb run, and the sub-directory of the
            result_dir that stores the logs and checkpoints.
        resume_if_exists (bool): attempt to resume if the checkpoint exists [FIXME @skothenhill this doesn't work yet]
        precision (PrecisionTypes): precision to use for training (bf16-mixed, 16-mixed, 32)
        wandb_entity (str): The team posting this run (default: your username or your default team)
        wandb_project (str): The name of the project to which this run will belong.
        wandb_tags (List[str]): Tags associated with this run.
        wandb_group (str): A unique string shared by all runs in a given group
        wandb_offline (bool): Run offline (data can be streamed later to wandb servers).
        wandb_id (str): Sets the version, mainly used to resume a previous run.
        wandb_anonymous (bool): Enables or explicitly disables anonymous logging.
        wandb_log_model (bool): Save checkpoints in wandb dir to upload on W&B servers.
        pipeline_model_parallel_size (int): degree of pipeline model parallelism
        tensor_model_parallel_size (int): degree of tensor model parallelism
        create_tensorboard_logger (bool): create the tensorboard logger
        nemo1_init_path (Optional[Path]): path to a NeMo v1 checkpoint to initialize from
        restore_from_checkpoint_path (Optional[str]): If set, restores the model from the directory passed in. Expects the
            checkpoint to be created by using the ModelCheckpoint class and always_save_context=True.
        save_last_checkpoint (bool): whether to save the last checkpoint
        metric_to_monitor_for_checkpoints (str): metric to monitor for checkpoints
        save_top_k (int): number of top checkpoints to save
        nsys_profiling (bool): whether to enable nsys profiling
        nsys_start_step (int): start step for nsys profiling
        nsys_end_step (Optional[int]): end step for nsys profiling
        nsys_ranks (List[int]): ranks for nsys profiling
        random_mask_strategy (RandomMaskStrategy): random mask strategy
        num_layers (int): number of layers
        hidden_size (int): hidden size
        num_attention_heads (int): number of attention heads
        ffn_hidden_size (int): feed forward hidden size
        no_overlap_grad_reduce (bool): disable overlap gradient reduction
        overlap_param_gather (bool): overlap parameter gather
        no_average_in_collective (bool): disable average in collective
        grad_reduce_in_fp32 (bool): gradient reduction in fp32

    Returns:
        nl.Trainer: the trainer used for the run, returned after llm.train completes.
    """
    # Create the result directory if it does not exist.
    result_dir.mkdir(parents=True, exist_ok=True)

    # Setup the strategy and trainer
    global_batch_size = infer_global_batch_size(
        micro_batch_size=micro_batch_size,
        num_nodes=num_nodes,
        devices=devices,
        accumulate_grad_batches=accumulate_grad_batches,
        tensor_model_parallel_size=tensor_model_parallel_size,
        pipeline_model_parallel_size=pipeline_model_parallel_size,
    )

    strategy = nl.MegatronStrategy(
        tensor_model_parallel_size=tensor_model_parallel_size,
        pipeline_model_parallel_size=pipeline_model_parallel_size,
        pipeline_dtype=get_autocast_dtype(precision),
        ddp=DistributedDataParallelConfig(
            check_for_nan_in_grad=True,
            overlap_grad_reduce=not no_overlap_grad_reduce,
            overlap_param_gather=overlap_param_gather,
            average_in_collective=not no_average_in_collective,
            grad_reduce_in_fp32=grad_reduce_in_fp32,
            use_distributed_optimizer=True,
        ),
        find_unused_parameters=True,
        gradient_as_bucket_view=True,
        ckpt_include_optimizer=True,
        ckpt_async_save=True,
        ckpt_parallel_load=True,
    )

    # for wandb integration
    # Please refer to https://pytorch-lightning.readthedocs.io/en/0.7.6/api/pytorch_lightning.loggers.html
    wandb_config: Optional[WandbConfig] = (
        None
        if wandb_project is None
        else WandbConfig(
            offline=wandb_offline,
            project=wandb_project,
            entity=wandb_entity,
            tags=wandb_tags,
            group=wandb_group,
            id=wandb_id,
            anonymous=wandb_anonymous,
            log_model=wandb_log_model,
        )
    )

    callbacks = [
        RichModelSummary(max_depth=4),
        LearningRateMonitor(),
        nl_callbacks.PreemptionCallback(),
        TimingCallback(),
    ]
    if nsys_profiling:
        if nsys_end_step is None:
            nsys_end_step = num_steps
        callbacks.append(
            nl_callbacks.NsysCallback(
                start_step=nsys_start_step, end_step=nsys_end_step, ranks=nsys_ranks, gen_shape=True
            )
        )

    trainer = nl.Trainer(
        devices=devices,
        max_steps=num_steps,
        accelerator="gpu",
        strategy=strategy,
        limit_val_batches=limit_val_batches,  # This controls upsampling and downsampling
        val_check_interval=val_check_interval,
        log_every_n_steps=log_every_n_steps,
        num_nodes=num_nodes,
        callbacks=callbacks,
        plugins=nl.MegatronMixedPrecision(
            precision=precision,
            params_dtype=get_autocast_dtype(precision),
            pipeline_dtype=get_autocast_dtype(precision),
            grad_reduce_in_fp32=grad_reduce_in_fp32,
            autocast_enabled=False,
        ),
    )

    tokenizer = BioNeMoAMPLIFYTokenizer()

    # Initialize the data module.
    data = AMPLIFYDataModule(
        train_hf_dataset=hf_load_dataset("chandar-lab/UR100P", data_dir="UniProt", split="train"),  # type: ignore
        valid_hf_dataset=hf_load_dataset("chandar-lab/UR100P", data_dir="UniProt", split="test"),  # type: ignore
        global_batch_size=global_batch_size,
        micro_batch_size=micro_batch_size,
        min_seq_length=min_seq_length,
        max_seq_length=max_seq_length,
        num_workers=num_dataset_workers,
        random_mask_strategy=random_mask_strategy,
        tokenizer=tokenizer,
    )

    # Configure the model
    train_metric = None
    is_model_parallel = tensor_model_parallel_size * pipeline_model_parallel_size > 1
    if is_model_parallel:
        valid_metric = None  # metric logging under model parallelism is not supported yet
    else:
        valid_metric = TorchmetricsConfig(
            class_path="text.Perplexity",
            task="pretraining",
            kwargs={"ignore_index": MLM_LOSS_IGNORE_INDEX},
            metric_name="val_ppl",
        )

    amplify_config = AMPLIFYConfig(
        seq_length=max_seq_length,
        num_layers=num_layers,
        hidden_size=hidden_size,
        num_attention_heads=num_attention_heads,
        ffn_hidden_size=ffn_hidden_size,
        params_dtype=get_autocast_dtype(precision),
        pipeline_dtype=get_autocast_dtype(precision),
        autocast_dtype=get_autocast_dtype(precision),  # setting this speeds things up a lot
        biobert_spec_option=biobert_spec_option,
        nemo1_ckpt_path=str(nemo1_init_path) if nemo1_init_path is not None else None,
        # handle checkpoint resumption here rather than auto-resume so this supports fine-tuning capabilities
        initial_ckpt_path=str(restore_from_checkpoint_path) if restore_from_checkpoint_path is not None else None,
        variable_seq_lengths=min_seq_length != max_seq_length,
        train_metric=train_metric,
        valid_metric=valid_metric,
    )

    model = biobert_lightning_module(
        amplify_config,
        tokenizer=tokenizer,
        optimizer=MegatronOptimizerModule(
            config=OptimizerConfig(
                lr=lr,
                optimizer="adam",  # fused_adam not supported
                use_distributed_optimizer=True,
                weight_decay=0.01,
                adam_beta1=0.9,
                adam_beta2=0.95,
                clip_grad=1.0,
            ),
            lr_scheduler=nl.lr_scheduler.CosineAnnealingScheduler(
                min_lr=0.1 * lr,
                max_steps=decay_steps,
                warmup_steps=warmup_steps,
                constant_steps=0,
            ),
        ),
    )

    # Configure our custom Checkpointer
    checkpoint_callback = nl_callbacks.ModelCheckpoint(
        save_last=save_last_checkpoint,
        monitor=metric_to_monitor_for_checkpoints,  # "val_loss",
        save_top_k=save_top_k,
        every_n_train_steps=val_check_interval,
        always_save_context=True,  # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
        filename="{epoch}-{val_loss:.2f}-{step}-{consumed_samples}",  # Including step and consumed_samples in the checkpoint filename prevents duplicate filenames and bugs related to this.
    )

    # Setup the logger and train the model
    nemo_logger = setup_nemo_lightning_logger(
        root_dir=result_dir,
        name=experiment_name,
        initialize_tensorboard_logger=create_tensorboard_logger,
        wandb_config=wandb_config,
        ckpt_callback=checkpoint_callback,
    )

    llm.train(
        model=model,
        data=data,
        trainer=trainer,
        log=nemo_logger,
        resume=resume.AutoResume(
            resume_if_exists=resume_if_exists,  # Looks for the -last checkpoint to continue training.
            resume_ignore_no_checkpoint=True,  # When false this will throw an error with no existing checkpoint.
        ),
    )

    return trainer
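
For reference, the `infer_global_batch_size` call near the top of `main` derives the global batch size from the micro batch size and the parallelism settings. A minimal sketch of the standard Megatron-style arithmetic it implies; this formula is an assumption based on the argument names, not a copy of the BioNeMo implementation:

def sketch_global_batch_size(
    micro_batch_size: int,
    num_nodes: int,
    devices: int,
    accumulate_grad_batches: int,
    tensor_model_parallel_size: int,
    pipeline_model_parallel_size: int,
) -> int:
    # Each model replica spans tp * pp devices, so the data-parallel size
    # is the world size divided by the model-parallel size.
    world_size = num_nodes * devices
    model_parallel_size = tensor_model_parallel_size * pipeline_model_parallel_size
    data_parallel_size = world_size // model_parallel_size
    return micro_batch_size * data_parallel_size * accumulate_grad_batches

# With the defaults above: 64 * 1 * 1 == 64 samples per optimizer step.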