Skip to content

Parallel test utils

clean_up_distributed()

Cleans up the distributed environment.

Destroys the process group and empties the CUDA cache.

Returns:

| Type | Description |
| --- | --- |
| `None` | None |

Source code in `bionemo/moco/testing/parallel_test_utils.py` (lines 63–76)
def clean_up_distributed() -> None:
    """Tear down any active ``torch.distributed`` state.

    When a default process group is currently initialized it is destroyed.
    Afterwards ``torch.cuda.empty_cache()`` is always called, whether or not
    a process group was active.

    Returns:
        None
    """
    group_is_active = dist.is_initialized()
    if group_is_active:
        dist.destroy_process_group()
    # Unconditional: release cached CUDA allocator memory even if no group existed.
    torch.cuda.empty_cache()

parallel_context(rank=0, world_size=1)

Context manager for torch distributed testing.

Sets up and cleans up the distributed environment (the default `torch.distributed` process group).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `rank` | `int` | The rank of the process. Defaults to 0. | `0` |
| `world_size` | `int` | The world size of the distributed environment. Defaults to 1. | `1` |

Yields:

| Type | Description |
| --- | --- |
| | None |

Source code in `bionemo/moco/testing/parallel_test_utils.py` (lines 30–60)
@contextmanager
def parallel_context(
    rank: int = 0,
    world_size: int = 1,
    backend: str = "nccl",
):
    """Context manager for torch distributed testing.

    Initializes a default ``torch.distributed`` process group for the duration
    of the ``with`` block and tears it down afterwards. Any process group left
    over from an earlier test is destroyed first.

    ``MASTER_ADDR`` / ``MASTER_PORT`` are set from ``DEFAULT_MASTER_ADDR`` /
    ``DEFAULT_MASTER_PORT`` only when they are not already set in the
    environment; ``RANK`` is always set to ``rank``. These environment changes
    are made through a pytest ``MonkeyPatch`` context, so they are undone when
    the context exits.

    Args:
        rank (int): The rank of the process. Defaults to 0.
        world_size (int): The world size of the distributed environment. Defaults to 1.
        backend (str): The ``torch.distributed`` backend passed to
            ``init_process_group``. Defaults to ``"nccl"`` (the previously
            hard-coded value); e.g. ``"gloo"`` can be used for CPU-only tests.

    Yields:
        None
    """
    with MonkeyPatch.context() as context:
        # Start from a clean slate in case a previous test leaked a process group.
        clean_up_distributed()

        # distributed and parallel state set up
        if not os.environ.get("MASTER_ADDR", None):
            context.setenv("MASTER_ADDR", DEFAULT_MASTER_ADDR)
        if not os.environ.get("MASTER_PORT", None):
            context.setenv("MASTER_PORT", DEFAULT_MASTER_PORT)
        context.setenv("RANK", str(rank))

        dist.init_process_group(backend=backend, world_size=world_size)
        try:
            yield
        finally:
            # Runs even when the test body raises, so the process group (and
            # cached CUDA memory) is never leaked into subsequent tests.
            clean_up_distributed()