Remove optimizer state from an MBridge DCP checkpoint.
Reads a Megatron Bridge checkpoint (which may contain model weights, optimizer
state, LR scheduler state, and RNG state), strips everything except model
weights, and writes a new checkpoint. The result is a smaller checkpoint
suitable for release or fine-tuning.
This module depends only on PyTorch and the standard library -- it must NOT
import megatron, nemo, or mbridge.
main()
CLI entry point for removing optimizer state from an MBridge checkpoint.
Source code in bionemo/recipeutils/checkpoint/remove_optimizer.py
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
def main():
    """CLI entry point for removing optimizer state from an MBridge checkpoint."""
    arg_parser = argparse.ArgumentParser(
        description="Remove optimizer state from a Megatron Bridge DCP checkpoint, "
        "producing a smaller weights-only checkpoint."
    )
    arg_parser.add_argument(
        "--src-ckpt-dir",
        type=Path,
        required=True,
        help="Source checkpoint directory (containing iter_NNNNNNN/)",
    )
    arg_parser.add_argument(
        "--dst-ckpt-dir",
        type=Path,
        required=True,
        help="Destination directory for the weights-only checkpoint",
    )
    arg_parser.add_argument("--verbose", "-v", action="store_true")
    ns = arg_parser.parse_args()

    # The verbose flag only controls the root logger threshold.
    log_level = logging.DEBUG if ns.verbose else logging.INFO
    logging.basicConfig(level=log_level)

    remove_optimizer(ns.src_ckpt_dir, ns.dst_ckpt_dir)
|
remove_optimizer(src_ckpt_dir, dst_ckpt_dir)
Strip optimizer / scheduler / RNG state from an MBridge checkpoint.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `src_ckpt_dir` | `Path` | Source checkpoint root (contains `iter_NNNNNNN/`). | *required* |
| `dst_ckpt_dir` | `Path` | Destination directory. Must not already exist. | *required* |

Returns:

| Type | Description |
|---|---|
| `Path` | Path to the destination checkpoint directory. |
Source code in bionemo/recipeutils/checkpoint/remove_optimizer.py
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
def remove_optimizer(
    src_ckpt_dir: Path,
    dst_ckpt_dir: Path,
) -> Path:
    """Strip optimizer / scheduler / RNG state from an MBridge checkpoint.

    Only non-optimizer tensor entries are loaded from the source DCP
    checkpoint and re-saved; byte (pickled-object) entries and optimizer
    keys are skipped. Auxiliary files (run config, tokenizer, train state,
    metadata descriptor, common.pt) are copied or rewritten so the result
    is a loadable weights-only checkpoint.

    Args:
        src_ckpt_dir: Source checkpoint root (contains ``iter_NNNNNNN/``).
        dst_ckpt_dir: Destination directory. Must not already exist.

    Returns:
        Path to the destination checkpoint directory.

    Raises:
        FileExistsError: If ``dst_ckpt_dir`` already exists.
    """
    if dst_ckpt_dir.exists():
        raise FileExistsError(f"Destination already exists: {dst_ckpt_dir}")
    src_iter_dir = _resolve_iter_dir(src_ckpt_dir)
    logger.info(f"Source iter directory: {src_iter_dir}")

    # Determine the iteration name so the destination mirrors the structure.
    iter_name = src_iter_dir.name
    dst_iter_dir = dst_ckpt_dir / iter_name
    dst_iter_dir.mkdir(parents=True, exist_ok=True)

    # --- 1. Load only model-weight tensors from DCP ---
    reader = FileSystemReader(str(src_iter_dir))
    metadata = reader.read_metadata()
    state_dict: dict[str, torch.Tensor] = {}
    skipped_keys: list[str] = []
    for key, item_meta in metadata.state_dict_metadata.items():
        # Bytes entries are pickled objects with no tensor shape/dtype, so
        # they cannot be preallocated below; skip them here.
        if isinstance(item_meta, BytesStorageMetadata):
            skipped_keys.append(key)
            continue
        if _is_optimizer_key(key):
            skipped_keys.append(key)
        else:
            # Preallocate an empty CPU tensor for dcp.load to fill in place.
            state_dict[key] = torch.empty(item_meta.size, dtype=item_meta.properties.dtype, device="cpu")
    logger.info(f"Loading {len(state_dict)} model-weight keys (skipping {len(skipped_keys)} optimizer/other keys)")
    if skipped_keys:
        logger.debug(f"Skipped keys (first 20): {skipped_keys[:20]}")
    dcp.load(state_dict=state_dict, storage_reader=reader, no_dist=True)

    # --- 2. Save model-only state dict to destination ---
    # os.cpu_count() may return None (per its documentation); fall back to a
    # single writer thread rather than passing None as thread_count.
    writer = FileSystemWriter(str(dst_iter_dir), single_file_per_rank=False, thread_count=os.cpu_count() or 1)
    dcp.save(state_dict=state_dict, storage_writer=writer, no_dist=True)
    del state_dict  # release the (potentially large) tensors before file copies

    # --- 3. Copy non-DCP artefacts (run_config, tokenizer, train_state, etc.) ---
    for name in ("run_config.yaml", "train_state.pt"):
        src_file = src_iter_dir / name
        if src_file.exists():
            shutil.copy2(src_file, dst_iter_dir / name)
    tokenizer_src = src_iter_dir / "tokenizer"
    if tokenizer_src.is_dir():
        shutil.copytree(tokenizer_src, dst_iter_dir / "tokenizer")

    # --- 4. Write metadata.json (same format descriptor) ---
    src_meta_json = src_iter_dir / "metadata.json"
    if src_meta_json.exists():
        shutil.copy2(src_meta_json, dst_iter_dir / "metadata.json")
    else:
        # No descriptor in the source: emit the default torch_dist descriptor.
        with open(dst_iter_dir / "metadata.json", "w") as f:
            json.dump(
                {
                    "sharded_backend": "torch_dist",
                    "sharded_backend_version": 1,
                    "common_backend": "torch",
                    "common_backend_version": 1,
                },
                f,
            )

    # --- 5. Write common.pt without optimizer metadata ---
    src_common = src_iter_dir / "common.pt"
    if src_common.exists():
        # NOTE(review): weights_only=False unpickles arbitrary objects; the
        # source checkpoint is assumed trusted (it is a local training output).
        common = torch.load(src_common, map_location="cpu", weights_only=False)
        common.pop("optimizer", None)
        common.pop("opt_param_scheduler", None)
        if "content_metadata" in common:
            common["content_metadata"].pop("distrib_optim_sharding_type", None)
        torch.save(common, dst_iter_dir / "common.pt")

    # --- 6. Write top-level files ---
    src_latest = src_ckpt_dir / "latest_checkpointed_iteration.txt"
    if src_latest.exists():
        shutil.copy2(src_latest, dst_ckpt_dir / "latest_checkpointed_iteration.txt")
    src_train_state = src_ckpt_dir / "latest_train_state.pt"
    if src_train_state.exists():
        shutil.copy2(src_train_state, dst_ckpt_dir / "latest_train_state.pt")

    logger.info(f"Wrote optimizer-free checkpoint to {dst_ckpt_dir}")
    return dst_ckpt_dir
|