
Megatron utils

average_losses_across_data_parallel_group(losses, with_context_parallel=False)

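Average a sequence of scalar loss tensors across the data-parallel group, or optionally across the flattened data-parallel + context-parallel group. The result is a 1-D tensor with one entry per input loss. The summed losses are always divided by the data-parallel world size only, not by the context-parallel size.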

Source code in bionemo/llm/utils/megatron_utils.py
import torch
from megatron.core import parallel_state


def average_losses_across_data_parallel_group(losses, with_context_parallel: bool = False):
    """Reduce a tensor of losses across all GPUs."""
    # Stack the scalar losses into one 1-D tensor, detached from the autograd graph.
    averaged_losses = torch.cat([loss.clone().detach().view(1) for loss in losses])
    # Reduce across the DP (or optionally, the flattened DP + CP) group.
    # Refer to the ring attention algorithm on why we always must reduce across the CP group.
    torch.distributed.all_reduce(
        averaged_losses, group=parallel_state.get_data_parallel_group(with_context_parallel=with_context_parallel)
    )
    # Only average losses across the data parallel group, not the context parallel group!
    dp_world_size = torch.distributed.get_world_size(group=parallel_state.get_data_parallel_group())
    averaged_losses = averaged_losses / dp_world_size
    return averaged_losses
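The sum-then-divide arithmetic can be checked without GPUs. The sketch below is a hypothetical pure-Python model, not part of BioNeMo or Megatron. It assumes ranks are laid out on a DP × CP grid, with each rank holding a list of scalar losses. It mimics a SUM all-reduce over either the DP group (ranks sharing a CP index) or the flattened DP + CP group, followed by division by the DP world size only.

```python
from typing import List, Sequence


def simulate_average_losses(
    per_rank_losses: Sequence[Sequence[Sequence[float]]],
    with_context_parallel: bool = False,
) -> List[List[List[float]]]:
    """Hypothetical model of average_losses_across_data_parallel_group.

    per_rank_losses[dp][cp] is the list of scalar losses held by the rank at
    data-parallel index dp and context-parallel index cp (an assumed layout).
    Returns the reduced losses that each rank would end up holding.
    """
    dp_size = len(per_rank_losses)
    cp_size = len(per_rank_losses[0])
    num_losses = len(per_rank_losses[0][0])
    result = [[[0.0] * num_losses for _ in range(cp_size)] for _ in range(dp_size)]
    for dp in range(dp_size):
        for cp in range(cp_size):
            # Members of this rank's reduction group: all DP ranks with the same
            # CP index, or every rank when DP and CP are flattened together.
            if with_context_parallel:
                group = [(d, c) for d in range(dp_size) for c in range(cp_size)]
            else:
                group = [(d, cp) for d in range(dp_size)]
            for i in range(num_losses):
                # Emulates torch.distributed.all_reduce with the default SUM op.
                total = sum(per_rank_losses[d][c][i] for d, c in group)
                # Divide by the DP world size only, never by the CP size.
                result[dp][cp][i] = total / dp_size
    return result


grid = [[[1.0], [2.0]], [[3.0], [4.0]]]  # 2 DP ranks x 2 CP ranks, one loss each
print(simulate_average_losses(grid))
print(simulate_average_losses(grid, with_context_parallel=True))
```

Under this model, the CP = 1 case reduces to a plain mean over DP replicas. With the flattened DP + CP group, contributions from every CP rank are summed but still divided by the DP size, which matches the "not the context parallel group" comment in the source.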

is_only_data_parallel()

Checks whether the current distributed Megatron environment uses data parallelism only.

This is useful when a model, loss, or other component does not yet support Megatron model parallelism. Call this function to confirm that data parallelism is the only kind of parallelism in use.

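Returns:

| Type | Description |
| --- | --- |
| bool | True if data parallelism is the only parallel mode, False otherwise. |

Raises:

| Type | Description |
| --- | --- |
| RuntimeError | If called outside an initialized Megatron parallel environment. |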

Source code in bionemo/llm/utils/megatron_utils.py
import torch
from megatron.core import parallel_state


def is_only_data_parallel() -> bool:
    """Checks to see if you are in a distributed megatron environment with only data parallelism active.

    This is useful if you are working on a model, loss, etc., and you know that you do not yet support megatron model
    parallelism. You can test that the only kind of parallelism in use is data parallelism.

    Returns:
        True if data parallel is the only parallel mode, False otherwise.

    Raises:
        RuntimeError: If called outside an initialized megatron parallel environment.
    """
    if not (torch.distributed.is_available() and parallel_state.is_initialized()):
        raise RuntimeError("This function is only defined within an initialized megatron parallel environment.")
    # Idea: when world_size == data_parallel_world_size, then you know that you are fully DDP, which means you are not
    # using model parallelism (meaning virtual GPUs composed of several underlying GPUs that you need to reduce over).
    world_size: int = torch.distributed.get_world_size()
    dp_world_size: int = parallel_state.get_data_parallel_world_size()
    return world_size == dp_world_size
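The reasoning in the code comment can be made concrete with arithmetic. This sketch assumes that the world size factors as TP × PP × CP × DP and that the DP world size excludes context-parallel ranks. It also ignores expert parallelism. Under those assumptions, `world_size == dp_world_size` holds exactly when tensor, pipeline, and context parallel sizes are all 1. The helper `only_data_parallel_from_sizes` is hypothetical and takes plain integers instead of querying `parallel_state`.

```python
def only_data_parallel_from_sizes(world_size: int, tp: int = 1, pp: int = 1, cp: int = 1) -> bool:
    """Hypothetical pure-arithmetic analogue of is_only_data_parallel().

    Assumes world_size = tp * pp * cp * dp (no expert parallelism).
    """
    model_parallel = tp * pp * cp
    if world_size % model_parallel != 0:
        raise ValueError("world_size must be divisible by tp * pp * cp")
    # Ranks left over after carving out model-parallel groups form the DP replicas.
    dp_world_size = world_size // model_parallel
    # Equal only when model_parallel == 1, i.e. every rank is a pure DDP replica.
    return world_size == dp_world_size


print(only_data_parallel_from_sizes(8))        # 8 GPUs, DP only
print(only_data_parallel_from_sizes(8, tp=2))  # 2-way tensor parallelism
```

Comparing the two world sizes, rather than checking each parallel degree separately, means the check does not need to enumerate every kind of model parallelism. Any extra dimension that shrinks the DP group makes the comparison fail.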