Zero3 conversion lib

Helper utilities for converting DeepSpeed ZeRO3 and ZeRO2 checkpoints to PyTorch checkpoints.

ZeroModelState dataclass

A dataclass representing the state of a ZeRO model.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `buffers` | `Dict` | Buffers in the model state. |
| `extra_states` | `Dict` | Extra states in the model state. |
| `param_shapes` | `List` | Shapes of the parameters. |
| `shared_params` | `List` | Shared parameters in the model state. |
| `ds_version` | `int` | Version of the DeepSpeed checkpoint. |
| `frozen_param_shapes` | `Dict` | Shapes of the frozen parameters. |
| `frozen_param_fragments` | `Dict` | Fragments of the frozen parameters. |

Source code in bionemo/evo2/utils/checkpoint/zero3_conversion_lib.py
@dataclass
class ZeroModelState:
    """A dataclass representing the state of a ZeRO model.

    Attributes:
        buffers (Dict): Buffers in the model state.
        extra_states (Dict): Extra states in the model state.
        param_shapes (List): Shapes of the parameters.
        shared_params (List): Shared parameters in the model state.
        ds_version (int): Version of the DeepSpeed checkpoint.
        frozen_param_shapes (Dict): Shapes of the frozen parameters.
        frozen_param_fragments (Dict): Fragments of the frozen parameters.
    """

    buffers: Dict
    extra_states: Dict
    param_shapes: List
    shared_params: List
    ds_version: int
    frozen_param_shapes: Dict
    frozen_param_fragments: Dict

atoi(text)

Converts a string to an integer if it is a digit, otherwise returns the string.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `text` | `str` | The text to be converted. | required |

Returns:

| Type | Description |
| --- | --- |
| `int` or `str` | The converted integer or the original string. |

Source code in bionemo/evo2/utils/checkpoint/zero3_conversion_lib.py
def atoi(text: str):
    """Converts a string to an integer if it is a digit, otherwise returns the string.

    Args:
        text (str): The text to be converted.

    Returns:
        int or str: The converted integer or the original string.
    """
    return int(text) if text.isdigit() else text
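
For example, digit-only strings become integers while everything else passes through unchanged:

atoi("42")    # -> 42
atoi("rank")  # -> "rank"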

create_ds_output_path(rank)

Creates the output path for a DeepSpeed checkpoint.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `rank` | `int` | The rank to create the output path for. | required |

Returns:

| Type | Description |
| --- | --- |
| `str` | The output path for the DeepSpeed checkpoint. |

Source code in bionemo/evo2/utils/checkpoint/zero3_conversion_lib.py
def create_ds_output_path(rank: int):
    """Creates the output path for a DeepSpeed checkpoint.

    Args:
        rank (int): The rank to create the output path for.

    Returns:
        str: The output path for the DeepSpeed checkpoint.
    """
    return f"mp_rank_{rank:02}_model_states.pt"

create_zero3_model_state_path(dp_rank, mp_rank)

Creates the path for a ZeRO3 model state file.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `dp_rank` | `int` | The data parallel rank. | required |
| `mp_rank` | `int` | The model parallel rank. | required |

Returns:

| Type | Description |
| --- | --- |
| `str` | The path for the ZeRO3 model state file. |

Source code in bionemo/evo2/utils/checkpoint/zero3_conversion_lib.py
def create_zero3_model_state_path(dp_rank: int, mp_rank: int):
    """Creates the path for a ZeRO3 model state file.

    Args:
        dp_rank (int): The data parallel rank.
        mp_rank (int): The model parallel rank.

    Returns:
        str: The path for the ZeRO3 model state file.
    """
    return f"zero_pp_rank_{dp_rank}_mp_rank_{mp_rank:02}_model_states.pt"

create_zero3_optim_state_path(dp_rank, mp_rank)

Creates the path for a ZeRO3 optimizer state file.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `dp_rank` | `int` | The data parallel rank. | required |
| `mp_rank` | `int` | The model parallel rank. | required |

Returns:

| Type | Description |
| --- | --- |
| `str` | The path for the ZeRO3 optimizer state file. |

Source code in bionemo/evo2/utils/checkpoint/zero3_conversion_lib.py
def create_zero3_optim_state_path(dp_rank: int, mp_rank: int):
    """Creates the path for a ZeRO3 optimizer state file.

    Args:
        dp_rank (int): The data parallel rank.
        mp_rank (int): The model parallel rank.

    Returns:
        str: The path for the ZeRO3 optimizer state file.
    """
    return f"bf16_zero_pp_rank_{dp_rank}_mp_rank_{mp_rank:02}_optim_states.pt"

get_checkpoint_files(checkpoint_dir, glob_pattern)

Retrieves checkpoint files from a directory based on a glob pattern.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `checkpoint_dir` | `str` | The directory to search for checkpoint files. | required |
| `glob_pattern` | `str` | The glob pattern to match files. | required |

Returns:

| Type | Description |
| --- | --- |
| `list` | A sorted list of checkpoint files. |

Raises:

| Type | Description |
| --- | --- |
| `FileNotFoundError` | If no files matching the glob pattern are found. |

Source code in bionemo/evo2/utils/checkpoint/zero3_conversion_lib.py
def get_checkpoint_files(checkpoint_dir: str, glob_pattern: str):
    """Retrieves checkpoint files from a directory based on a glob pattern.

    Args:
        checkpoint_dir (str): The directory to search for checkpoint files.
        glob_pattern (str): The glob pattern to match files.

    Returns:
        list: A sorted list of checkpoint files.

    Raises:
        FileNotFoundError: If no files matching the glob pattern are found.
    """
    # XXX: need to test that this simple glob rule works for multi-node setup too
    ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)

    if len(ckpt_files) == 0:
        raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")

    return ckpt_files
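
A usage sketch; the checkpoint directory is hypothetical:

ckpt_files = get_checkpoint_files("/path/to/global_step1000", "*_optim_states.pt")
# Files come back in natural sort order, so rank 2 sorts before rank 10.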

get_elapsed(t)

Converts elapsed time in seconds to a formatted string.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `t` | `float` | The elapsed time in seconds. | required |

Returns:

| Type | Description |
| --- | --- |
| `str` | The formatted elapsed time as a string. |

Source code in bionemo/evo2/utils/checkpoint/zero3_conversion_lib.py
def get_elapsed(t: float):
    """Converts elapsed time in seconds to a formatted string.

    Args:
        t (float): The elapsed time in seconds.

    Returns:
        str: The formatted elapsed time as a string.
    """
    minutes = t // 60
    seconds = t % 60
    if minutes > 0:
        total_time = f"{minutes:.0f}min{seconds:.0f}s"
    else:
        total_time = f"{seconds:.1f}s"
    return total_time
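
Two worked examples:

get_elapsed(75.0)  # -> "1min15s" (75 // 60 = 1 minute, 75 % 60 = 15 seconds)
get_elapsed(4.27)  # -> "4.3s"   (under a minute, so seconds only, one decimal)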

get_model_files_by_rank(checkpoint_dir, rank)

Retrieves model files for a specific rank from a checkpoint directory.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `checkpoint_dir` | `str` | The directory to search for model files. | required |
| `rank` | `int` | The rank to search for. | required |

Returns:

| Type | Description |
| --- | --- |
| `list` | A list of model files for the specified rank. |

Source code in bionemo/evo2/utils/checkpoint/zero3_conversion_lib.py
def get_model_files_by_rank(checkpoint_dir: str, rank: int):
    """Retrieves model files for a specific rank from a checkpoint directory.

    Args:
        checkpoint_dir (str): The directory to search for model files.
        rank (int): The rank to search for.

    Returns:
        list: A list of model files for the specified rank.
    """
    return get_checkpoint_files(checkpoint_dir, f"*mp_rank_{rank:02}_model_states.pt")

get_model_state_file(checkpoint_dir, zero_stage)

Retrieves the model state file from a checkpoint directory based on the ZeRO stage.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `checkpoint_dir` | `str` | The directory to search for the model state file. | required |
| `zero_stage` | `int` | The ZeRO stage to search for. | required |

Returns:

| Type | Description |
| --- | --- |
| `str` | The path to the model state file. |

Raises:

| Type | Description |
| --- | --- |
| `FileNotFoundError` | If the directory or model state file is not found. |
| `ValueError` | If the ZeRO stage is not supported. |

Source code in bionemo/evo2/utils/checkpoint/zero3_conversion_lib.py
def get_model_state_file(checkpoint_dir: str, zero_stage: int):
    """Retrieves the model state file from a checkpoint directory based on the ZeRO stage.

    Args:
        checkpoint_dir (str): The directory to search for the model state file.
        zero_stage (int): The ZeRO stage to search for.

    Returns:
        str: The path to the model state file.

    Raises:
        FileNotFoundError: If the directory or model state file is not found.
        ValueError: If the ZeRO stage is not supported.
    """
    if not os.path.isdir(checkpoint_dir):
        raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")

    # there should be only one file
    if zero_stage <= 2:
        file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
    elif zero_stage == 3:
        file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
    else:
        raise ValueError(f"Unsupported zero stage {zero_stage}. Expected 1, 2, or 3")

    if not os.path.exists(file):
        raise FileNotFoundError(f"can't find model states file at '{file}'")

    return file
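
A usage sketch for a ZeRO-3 checkpoint; the directory is hypothetical:

model_file = get_model_state_file("/path/to/global_step1000", zero_stage=3)
# -> "/path/to/global_step1000/zero_pp_rank_0_mp_rank_00_model_states.pt"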

get_optim_files_by_rank(checkpoint_dir, rank)

Retrieves optimizer files for a specific rank from a checkpoint directory.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `checkpoint_dir` | `str` | The directory to search for optimizer files. | required |
| `rank` | `int` | The rank to search for. | required |

Returns:

| Type | Description |
| --- | --- |
| `list` | A list of optimizer files for the specified rank. |

Source code in bionemo/evo2/utils/checkpoint/zero3_conversion_lib.py
def get_optim_files_by_rank(checkpoint_dir: str, rank: int):
    """Retrieves optimizer files for a specific rank from a checkpoint directory.

    Args:
        checkpoint_dir (str): The directory to search for optimizer files.
        rank (int): The rank to search for.

    Returns:
        list: A list of optimizer files for the specified rank.
    """
    return get_checkpoint_files(checkpoint_dir, f"*mp_rank_{rank:02}_optim_states.pt")
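
For a ZeRO-3 run with a data parallel world size of 2 and model parallel rank 0, this helper and get_model_files_by_rank above would together locate file pairs such as (directory hypothetical):

ckpt_dir = "/path/to/global_step1000"  # hypothetical
get_model_files_by_rank(ckpt_dir, 0)  # zero_pp_rank_{0,1}_mp_rank_00_model_states.pt
get_optim_files_by_rank(ckpt_dir, 0)  # bf16_zero_pp_rank_{0,1}_mp_rank_00_optim_states.pt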

natural_keys(text)

Key function for sorting strings in human (natural) order.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `text` | `str` | The text to split into sortable chunks. | required |

Returns:

| Type | Description |
| --- | --- |
| `list` | A list of string and integer chunks, usable as a sort key. |

Note

alist.sort(key=natural_keys) sorts in human order. http://nedbatchelder.com/blog/200712/human_sorting.html (See Toothy's implementation in the comments)

Source code in bionemo/evo2/utils/checkpoint/zero3_conversion_lib.py
def natural_keys(text: str):
    """Sorts a list in human order.

    Args:
        text (str): The text to be sorted.

    Returns:
        list: The sorted list.

    Note:
        alist.sort(key=natural_keys) sorts in human order.
        http://nedbatchelder.com/blog/200712/human_sorting.html
        (See Toothy's implementation in the comments)
    """
    return [atoi(c) for c in re.split(r"(\d+)", text)]
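
Used as a sort key, numeric chunks compare numerically rather than lexicographically:

sorted(["rank_10", "rank_2"], key=natural_keys)
# -> ["rank_2", "rank_10"]  (a plain lexicographic sort would give ["rank_10", "rank_2"])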

parse_model_states(files)

Parses model state files and returns a list of ZeroModelState objects.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `files` | `List[str]` | A list of file paths to parse. | required |

Returns:

| Type | Description |
| --- | --- |
| `List[ZeroModelState]` | A list of parsed ZeroModelState objects. |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If a file is not a model state checkpoint. |

Source code in bionemo/evo2/utils/checkpoint/zero3_conversion_lib.py
def parse_model_states(files: List[str]):
    """Parses model state files and returns a list of ZeroModelState objects.

    Args:
        files (List[str]): A list of file paths to parse.

    Returns:
        List[ZeroModelState]: A list of parsed ZeroModelState objects.

    Raises:
        ValueError: If a file is not a model state checkpoint.
    """
    zero_model_states = []
    for file in files:
        state_dict = torch.load(file, map_location=device)

        if BUFFER_NAMES not in state_dict:
            raise ValueError(f"{file} is not a model state checkpoint")
        buffer_names = state_dict[BUFFER_NAMES]
        if debug:
            print_pid("Found buffers:", buffer_names)

        # recover just the buffers while restoring them to fp32 if they were saved in fp16
        buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}

        extra_states = {k: v for k, v in state_dict["module"].items() if k.endswith(EXTRA_STATE)}

        # collect parameters that are included in param_shapes
        param_shapes = state_dict[PARAM_SHAPES]
        param_names = []
        for s in param_shapes:
            for name in s.keys():
                param_names.append(name)

        # update with frozen parameters
        frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
        if frozen_param_shapes is not None:
            if debug:
                print_pid(f"Found frozen_param_shapes: {frozen_param_shapes}")
            param_names += list(frozen_param_shapes.keys())

        # handle shared params
        shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]

        ds_version = state_dict.get(DS_VERSION, None)

        frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)

        z_model_state = ZeroModelState(
            buffers=buffers,
            extra_states=extra_states,
            param_shapes=param_shapes,
            shared_params=shared_params,
            ds_version=ds_version,
            frozen_param_shapes=frozen_param_shapes,
            frozen_param_fragments=frozen_param_fragments,
        )
        zero_model_states.append(z_model_state)

    return zero_model_states
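
A sketch of how this pairs with the file helpers above; the directory is hypothetical:

model_files = get_model_files_by_rank("/path/to/global_step1000", rank=0)
zero_model_states = parse_model_states(model_files)
print(zero_model_states[0].ds_version)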

parse_optim_states(files, ds_checkpoint_dir)

Parses optimizer state files and returns the ZeRO stage, world size, and fp32 flat groups.

Parameters:

Name Type Description Default
files Set[str]

A set of file paths to parse.

required
ds_checkpoint_dir str

The directory containing the DeepSpeed checkpoint.

required

Returns:

Name Type Description
tuple

A tuple containing the ZeRO stage, world size, and fp32 flat groups.

Raises:

Type Description
ValueError

If a file is not a ZeRO checkpoint or if the number of files does not match the expected world size.

Source code in bionemo/evo2/utils/checkpoint/zero3_conversion_lib.py
def parse_optim_states(files: List[str], ds_checkpoint_dir: str):
    """Parses optimizer state files and returns the ZeRO stage, world size, and fp32 flat groups.

    Args:
        files (List[str]): A list of file paths to parse.
        ds_checkpoint_dir (str): The directory containing the DeepSpeed checkpoint.

    Returns:
        tuple: A tuple containing the ZeRO stage, world size, and fp32 flat groups.

    Raises:
        ValueError: If a file is not a ZeRO checkpoint or if the number of files does not match the expected world size.
    """
    total_files = len(files)
    state_dicts = []
    for f in files:
        state_dict = torch.load(f, map_location=device)
        # immediately discard the two potentially huge optimizer states (e.g., the Adam moments),
        # since we only care about the fp32 master weights; also handle the case where they were
        # already removed by another helper script
        state_dict[OPTIMIZER_STATE_DICT].pop("optimizer_state_dict", None)
        state_dict[OPTIMIZER_STATE_DICT] = {
            FP32_FLAT_GROUPS: state_dict[OPTIMIZER_STATE_DICT][FP32_FLAT_GROUPS],
            ZERO_STAGE: state_dict[OPTIMIZER_STATE_DICT][ZERO_STAGE],
            PARTITION_COUNT: state_dict[OPTIMIZER_STATE_DICT][PARTITION_COUNT],
        }
        state_dicts.append(state_dict)

    if ZERO_STAGE not in state_dicts[0][OPTIMIZER_STATE_DICT]:
        raise ValueError(f"{files[0]} is not a zero checkpoint")
    zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
    world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]

    # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
    # parameters can be different from data parallelism for non-expert parameters. So we can just
    # use the max of the partition_count to get the dp world_size.

    if type(world_size) is list:
        world_size = max(world_size)

    if world_size != total_files:
        raise ValueError(
            f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
            "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
        )

    # the groups are named differently in each stage
    if zero_stage <= 2:
        fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
    elif zero_stage == 3:
        fp32_groups_key = FP32_FLAT_GROUPS
    else:
        raise ValueError(f"unknown zero stage {zero_stage}")

    if zero_stage <= 2:
        fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
    elif zero_stage == 3:
        # if there is more than one param group, there will be multiple flattened tensors - one
        # flattened tensor per group - for simplicity merge them into a single tensor
        #
        # XXX: could make the script more memory efficient for when there are multiple groups - it
        # will require matching the sub-lists of param_shapes for each param group flattened tensor

        fp32_flat_groups = [
            torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts))
        ]

    return zero_stage, world_size, fp32_flat_groups
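
A sketch of consuming the returned tuple; the directory is hypothetical:

ds_checkpoint_dir = "/path/to/global_step1000"  # hypothetical
optim_files = get_optim_files_by_rank(ds_checkpoint_dir, rank=0)
zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
# For ZeRO-3, fp32_flat_groups holds one flattened fp32 tensor per data parallel rank.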

print_pid(msg)

Prints the process ID along with a message.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `msg` | `str` | The message to be printed. | required |
Source code in bionemo/evo2/utils/checkpoint/zero3_conversion_lib.py
def print_pid(msg: str):
    """Prints the process ID along with a message.

    Args:
        msg (str): The message to be printed.
    """
    pid = os.getpid()
    print(f"{pid=}:{msg}")

process_single_rank(rank, ds_checkpoint_dir, output_dir, overwrite=False, exclude_frozen_parameters=False)

Processes a single rank to gather and save the state dictionary.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `rank` | `int` | The rank to process. | required |
| `ds_checkpoint_dir` | `str` | Path to the DeepSpeed checkpoint folder. | required |
| `output_dir` | `str` | Directory to save the output. | required |
| `overwrite` | `bool` | Whether to overwrite existing files. | `False` |
| `exclude_frozen_parameters` | `bool` | Whether to exclude frozen parameters. | `False` |
Source code in bionemo/evo2/utils/checkpoint/zero3_conversion_lib.py
def process_single_rank(
    rank: int,
    ds_checkpoint_dir: str,
    output_dir: str,
    overwrite: bool = False,
    exclude_frozen_parameters: bool = False,
):
    """Processes a single rank to gather and save the state dictionary.

    Args:
        rank (int): The rank to process.
        ds_checkpoint_dir (str): Path to the DeepSpeed checkpoint folder.
        output_dir (str): Directory to save the output.
        overwrite (bool): Whether to overwrite existing files. Default is False.
        exclude_frozen_parameters (bool): Whether to exclude frozen parameters. Default is False.
    """
    print_pid(f"Gathering rank {rank} state_dict...")

    start = time.time()
    output_path = os.path.join(output_dir, create_ds_output_path(rank))
    if os.path.exists(output_path) and not overwrite:
        print_pid(f"Output path {output_path} exists, skipping")
        return

    print_pid(f" -> Gathering data parallel partitions for mp rank {rank}...")

    if os.environ.get("ZERO3_CONVERSION_DEBUG", "0") == "1":
        breakpoint()

    state_dict = _get_fp32_state_dict_from_zero_checkpoint(
        ds_checkpoint_dir=ds_checkpoint_dir, rank=rank, exclude_frozen_parameters=exclude_frozen_parameters
    )
    print_pid(f" -> Done processing rank {rank} state_dict, gathered {len(state_dict)} params")

    checkpoint = {
        "module": state_dict,
        "param_shapes": OrderedDict(),
        "dp_world_size": 1,
    }

    for param, value in state_dict.items():
        if isinstance(value, torch.Tensor):
            checkpoint["param_shapes"][param] = value.shape

    print_pid(f" -> Saving mp rank {rank} checkpoint to {output_path}")
    torch.save(checkpoint, f"{output_path}")

    total_time = get_elapsed(time.time() - start)
    print_pid(f" -> rank {rank} took {total_time}")

profile_memory_decorator(func)

A decorator to profile memory usage of a function.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `func` | `Callable` | The function to be decorated. | required |

Returns:

| Type | Description |
| --- | --- |
| `Callable` | The wrapped function, which reports RSS memory before and after each call. |

Source code in bionemo/evo2/utils/checkpoint/zero3_conversion_lib.py
def profile_memory_decorator(func: Callable):
    """A decorator to profile memory usage of a function.

    Args:
        func (Callable): The function to be decorated.

    Returns:
        wrapper: The wrapped function, which reports RSS memory before and after each call.
    """

    def profile_memory():
        # print_pid already prefixes the pid, so only the RSS figure is formatted here
        process = psutil.Process(os.getpid())
        memory_info = process.memory_info()
        print_pid(f"RSS = {memory_info.rss / 1024**2:.2f} MB")

    def wrapper(*args, **kwargs):
        profile_memory()
        result = func(*args, **kwargs)
        profile_memory()
        return result

    return wrapper
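
Typical usage; the decorated function here is a hypothetical example:

@profile_memory_decorator
def convert_checkpoint():
    ...  # memory-heavy work

convert_checkpoint()
# pid=12345:RSS = 512.00 MB   (before the call; values illustrative)
# pid=12345:RSS = 2048.00 MB  (after the call)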

zero3_partitioned_param_info(unpartitioned_numel, world_size)

Returns the partitioned and padding number of elements for a parameter.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `unpartitioned_numel` | `int` | The number of elements in the unpartitioned parameter. | required |
| `world_size` | `int` | The world size. | required |

Returns:

| Type | Description |
| --- | --- |
| `tuple` | A tuple containing the partitioned number of elements and the padding number of elements. |

Source code in bionemo/evo2/utils/checkpoint/zero3_conversion_lib.py
def zero3_partitioned_param_info(unpartitioned_numel: int, world_size: int):
    """Returns the partitioned and padding number of elements for a parameter.

    Args:
        unpartitioned_numel (int): The number of elements in the unpartitioned parameter.
        world_size (int): The world size.

    Returns:
        tuple: A tuple containing the partitioned number of elements and the padding number of elements.
    """
    remainder = unpartitioned_numel % world_size
    padding_numel = (world_size - remainder) if remainder else 0
    partitioned_numel = math.ceil(unpartitioned_numel / world_size)
    return partitioned_numel, padding_numel
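
A worked example: a 10-element parameter sharded across 4 ranks needs 2 elements of padding so that every rank holds an equal 3-element shard:

zero3_partitioned_param_info(10, 4)  # -> (3, 2); 4 ranks * 3 elements = 10 + 2 padding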