Convert zero3 to zero1

convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_dir, tag=None, exclude_frozen_parameters=False, mp_size=8, overwrite=False, num_workers=1, ranks_to_process=None)

Converts a DeepSpeed ZeRO-3 checkpoint to a PyTorch FP32 state_dict.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `checkpoint_dir` | `str` | Path to the desired checkpoint folder. | *required* |
| `output_dir` | `str` | Directory to save the PyTorch FP32 state_dict output files. | *required* |
| `tag` | `Optional[str]` | Checkpoint tag used as a unique identifier or sub-directory that contains the checkpoint. | `None` |
| `exclude_frozen_parameters` | `bool` | Whether to exclude frozen parameters. | `False` |
| `mp_size` | `int` | Model-parallel size of the source checkpoint. | `8` |
| `overwrite` | `bool` | Whether to overwrite existing MP shards. | `False` |
| `num_workers` | `int` | Number of worker processes; capped at `mp_size`. | `1` |
| `ranks_to_process` | `Optional[List[int]]` | Model-parallel ranks to convert; defaults to all ranks in `range(mp_size)`. | `None` |

Raises:

| Type | Description |
| --- | --- |
| `FileNotFoundError` | If the checkpoint directory does not exist. |
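
For reference, a minimal usage sketch. The paths, tag, and worker count below are hypothetical placeholders, and the import assumes the module path mirrors the file path shown under "Source code" below:

from bionemo.evo2.utils.checkpoint.convert_zero3_to_zero1 import (
    convert_zero_checkpoint_to_fp32_state_dict,
)

# Hypothetical paths and settings; adapt to your checkpoint layout.
convert_zero_checkpoint_to_fp32_state_dict(
    checkpoint_dir="/checkpoints/evo2",    # folder containing the ZeRO-3 checkpoint
    output_dir="/checkpoints/evo2_fp32",   # where the FP32 state_dict files are written
    tag="global_step1000",                 # checkpoint sub-directory, if one is used
    mp_size=8,                             # model-parallel size of the source checkpoint
    num_workers=4,                         # convert up to 4 MP ranks in parallel
)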

Source code in bionemo/evo2/utils/checkpoint/convert_zero3_to_zero1.py
import os
import time
from multiprocessing import Pool
from typing import List, Optional

# process_single_rank and get_elapsed are resolved elsewhere in the module
# (not shown in this excerpt).


def convert_zero_checkpoint_to_fp32_state_dict(
    checkpoint_dir: str,
    output_dir: str,
    tag: Optional[str] = None,
    exclude_frozen_parameters: bool = False,
    mp_size: int = 8,
    overwrite: bool = False,
    num_workers: int = 1,
    ranks_to_process: Optional[List[int]] = None,
):
    """Converts a DeepSpeed Zero-3 checkpoint to a PyTorch FP32 state_dict.

    Args:
        checkpoint_dir (str): Path to the desired checkpoint folder.
        output_dir (str): Directory to save the PyTorch FP32 state_dict output files.
        tag (Optional[str]): Checkpoint tag used as a unique identifier or sub-directory that contains the checkpoint.
        exclude_frozen_parameters (bool): Whether to exclude frozen parameters.
        mp_size (int): Model parallel size of the source checkpoint.
        overwrite (bool): Whether to overwrite existing MP shards.
        num_workers (int): Number of workers to use for processing.
        ranks_to_process (Optional[List[int]]): List of ranks to process.

    Raises:
        FileNotFoundError: If the checkpoint directory does not exist.
    """
    ds_checkpoint_dir = os.path.join(checkpoint_dir, tag) if tag is not None else checkpoint_dir

    if not os.path.isdir(ds_checkpoint_dir):
        raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")

    output_dir = os.path.join(output_dir, tag) if tag is not None else output_dir
    if not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)

    num_workers = min(num_workers, mp_size)

    if ranks_to_process is not None:
        ranks_to_process = list(ranks_to_process)
        assert len(ranks_to_process) <= mp_size, (
            f"Expected at most {mp_size} ranks to process, got {len(ranks_to_process)}"
        )
        assert all(0 <= r < mp_size for r in ranks_to_process), (
            f"Expected ranks to be in range [0, {mp_size}), got {ranks_to_process}"
        )
    else:
        ranks_to_process = list(range(mp_size))

    print(f"Processing ranks: {ranks_to_process}", flush=True)

    start = time.time()
    if num_workers > 1:
        with Pool(num_workers) as p:
            p.starmap(
                process_single_rank,
                [(i, ds_checkpoint_dir, output_dir, overwrite, exclude_frozen_parameters) for i in ranks_to_process],
            )
    else:
        for i in ranks_to_process:
            process_single_rank(i, ds_checkpoint_dir, output_dir, overwrite, exclude_frozen_parameters)

    total_time = get_elapsed(time.time() - start)
    print(f"All done!\n-> Total time: {total_time}\n-> All outputs written to {os.path.abspath(output_dir)}")