Training an LLM¶
Info
This pipeline starting script is nemo_skills/pipeline/train.py
All extra parameters are passed to either nemo_skills/training/start_sft.py or nemo_skills/training/start_dpo.py
Preparing the data¶
Before running the training we need to prepare the data in the right format. Here is an example command
python -m nemo_skills.training.prepare_sft_data \
++input_files="<path to the generated synthetic data>/output-rs*.jsonl"> \
++output_path=sft-data.jsonl \
++prompt_config=generic/math \
++prompt_template=llama3-instruct
Tip
Many scripts access ++input_files
argument. You can use any glob patterns there and also
reference multiple files/patterns separated by space or comma.
If you want to run that command inside container or on cluster, add ns run_cmd --cluster=...
in the beginning.
You need to pass in the config/template files so that we can format the data accordingly. There are many more parameters that data preparation script supports which you can see here. We are using SDP library for preparing the data, so it's a good idea to check their documentation to understand how this config is structured.
Note
Even though we support both SFT and DPO training, the data preparation is currently only implemented for SFT jobs. For DPO, you'd need to manually prepare the data according to the NeMo-Aligner documentation. We will add a proper support for DPO data preparation in the near future.
Running training¶
We use NeMo-Aligner to run LLM training, so you can check their documentation to learn about all supported parameters.
Here is an example of how to run a training job.
ns train \
--cluster=slurm \
--expname=my-training-job \
--output_dir=/workspace/my-training-job/checkpoints \
--nemo_model=/nemo_models/llama3.1-8b-base \
--num_nodes=8 \
--num_gpus=8 \
--num_training_jobs=4 \
--training_data=/data/sft-data.jsonl
This will run training on 8 nodes of 8 GPUs, using 4 dependent slurm jobs. By default we are training for 2 epochs, saving checkpoints every 1000 steps, but you can adjust these values. It's also recommended to tune micro batch size and tensor parallel parameters for optimal performance. E.g. these are good defaults for an 8B model size
You can customize any of the SFT parameters by directly providing them, e.g. to disable wandb logging and add dropout use
--disable_wandb \
++model.ffn_dropout=0.1 \
++model.attention_dropout=0.1 \
++model.hidden_dropout=0.1
The training script will average all of your generated checkpoints upon completion
(we found this to consistently increase the downstream accuracy). If you want to
only average a subset of checkpoint, add --average_steps
parameter (e.g. if you
want to disable averaging, set it to the last training step). If you only want
to average the checkpoints of the finished job, set --num_training_jobs=0
.
Typically after training we want to follow up with evaluation. You can schedule
an evaluation job right away by providing a --run_after=my-training-job
argument
which will appropriately set slurm dependencies.
ns eval \
--cluster=slurm \
--model=/workspace/my-training-job/checkpoints/model-averaged-nemo \
--server_type=nemo \
--output_dir=/workspace/my-training-job/results/ \
--benchmarks gsm8k:0,math:0 \
--server_gpus=8 \
--run_after=my-training-job \
++prompt_template=llama3-instruct \
++batch_size=512
Chaining pipelines with Python¶
In general we don't recommend to run inference using NeMo checkpoints as it is much slower than other server formats. Here is how you can chain the commands to schedule checkpoint conversion and evaluation after training (whenever you need to run multiple commands, it's more convenient to use python interface)
from nemo_skills.pipeline import wrap_arguments
from nemo_skills.pipeline.cli import train, convert, eval
expname = "my-training-job"
cluster = "slurm"
output_dir = f"/workspace/{expname}/checkpoints"
train(
ctx=wrap_arguments(""),
cluster=cluster,
expname=expname,
output_dir=output_dir,
nemo_model="/nemo_models/llama3.1-8b-base",
num_nodes=8,
num_gpus=8,
num_training_jobs=4,
training_data="/data/sft-data.jsonl",
)
convert(
ctx=wrap_arguments(""),
cluster=cluster,
input_model=f"{output_dir}/model-averaged-nemo",
output_model=f"{output_dir}/model-averaged-hf",
expname=f"{expname}-to-hf",
run_after=expname,
convert_from="nemo",
convert_to="hf",
model_type="llama",
num_gpus=8,
hf_model_name="meta-llama/Meta-Llama-3.1-8B",
)
convert(
ctx=wrap_arguments(""),
cluster=cluster,
input_model=f"{output_dir}/model-averaged-hf",
output_model=f"{output_dir}/model-averaged-trtllm",
expname=f"{expname}-to-trtllm",
run_after=f"{expname}-to-hf",
convert_from="hf",
convert_to="trtllm",
model_type="llama",
num_gpus=8,
)
eval(
ctx=wrap_arguments("++prompt_template=llama3-instruct ++batch_size=512"),
cluster=cluster,
model=f"{output_dir}/model-averaged-trtllm",
server_type="trtllm",
output_dir=f"{output_dir}/results/",
benchmarks="gsm8k:0,math:0",
server_gpus=8,
run_after=f"{expname}-to-trtllm",
)