Skip to content

Parallel test utils

clean_up_distributed()

Cleans up the distributed environment.

Destroys the process group and empties the CUDA cache.

Returns:

    None

Source code in bionemo/moco/testing/parallel_test_utils.py
77
78
79
80
81
82
83
84
85
86
87
88
89
90
def clean_up_distributed() -> None:
    """Tear down any active torch distributed state.

    If a process group is currently initialized, destroy it, then release
    cached CUDA memory back to the driver.

    Returns:
        None
    """
    # Destroying an uninitialized group raises, so guard the call.
    if dist.is_initialized():
        dist.destroy_process_group()
    # No-op when CUDA has never been initialized, so safe on CPU-only hosts.
    torch.cuda.empty_cache()

find_free_network_port(address='localhost')

Finds a free port for the specified address. Defaults to localhost.

Source code in bionemo/moco/testing/parallel_test_utils.py
31
32
33
34
35
36
37
38
39
40
def find_free_network_port(address: str = "localhost") -> tuple[str, int]:
    """Find a free TCP port on the specified address. Defaults to localhost.

    Binds a temporary socket to port 0 so the OS picks an unused port, then
    closes the socket and reports the bound (address, port) pair.

    Note: the port is only guaranteed free at the moment of the check; another
    process may claim it before the caller binds to it.

    Args:
        address (str): Host address to bind on. Defaults to "localhost".

    Returns:
        tuple[str, int]: The (address, port) pair reported by the OS.
    """
    # Context manager guarantees the socket is closed even if bind() raises.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind((address, 0))
        # getsockname() on a bound AF_INET socket always returns a
        # (host, port) tuple, so no None check is needed.
        return s.getsockname()

parallel_context(rank=0, world_size=1)

Context manager for torch distributed testing.

Sets up and cleans up the distributed environment, including the device mesh.

Parameters:

    rank (int): The rank of the process. Defaults to 0.
    world_size (int): The world size of the distributed environment. Defaults to 1.

Yields:

    None

Source code in bionemo/moco/testing/parallel_test_utils.py
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
@contextmanager
def parallel_context(
    rank: int = 0,
    world_size: int = 1,
):
    """Context manager for torch distributed testing.

    Sets up and cleans up the distributed environment, including the device mesh.

    Args:
        rank (int): The rank of the process. Defaults to 0.
        world_size (int): The world size of the distributed environment. Defaults to 1.

    Yields:
        None
    """
    with MonkeyPatch.context() as context:
        clean_up_distributed()

        # Distributed and parallel state set up. Only patch the env vars the
        # caller has not already provided; MonkeyPatch restores them on exit.
        if not os.environ.get("MASTER_ADDR", None):
            context.setenv("MASTER_ADDR", DEFAULT_MASTER_ADDR)
        if not os.environ.get("MASTER_PORT", None):
            # The bound address is not needed here, only the port.
            _, free_network_port = find_free_network_port(address=DEFAULT_MASTER_ADDR)
            # MonkeyPatch.setenv requires a string value; the port is an int.
            context.setenv(
                "MASTER_PORT",
                str(free_network_port) if free_network_port is not None else DEFAULT_MASTER_PORT,
            )
        context.setenv("RANK", str(rank))

        dist.init_process_group(backend="nccl", world_size=world_size)

        try:
            yield
        finally:
            # Always tear down the process group, even if the caller's block
            # raised — otherwise a failed test leaks distributed state into
            # the next one.
            clean_up_distributed()