
Megatron parallel state utils

This package contains utilities for managing the state of distributed model parallelism in Megatron and Apex.

In general, you should just use the distributed_model_parallel_state context manager to manage the state of your test. It handles the setup and teardown of the distributed model parallel state for you.

Example usage:

from bionemo.testing import megatron_parallel_state_utils

def my_test():
    with megatron_parallel_state_utils.distributed_model_parallel_state():
        ...  # your test code that requires megatron/apex parallel state to be set up here

clean_parallel_state_context()

Puts you into a clean parallel state, and again tears it down at the end.
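
For example, a test that deliberately touches global megatron/apex state can use this context manager to guarantee it both starts from and leaves behind a clean slate. A minimal sketch (the test name is hypothetical):

from bionemo.testing import megatron_parallel_state_utils

def my_state_mutating_test():
    with megatron_parallel_state_utils.clean_parallel_state_context():
        ...  # test code that assumes no megatron/apex parallel state is set up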

Source code in bionemo/testing/megatron_parallel_state_utils.py
@contextmanager
def clean_parallel_state_context() -> Iterator[None]:
    """Puts you into a clean parallel state, and again tears it down at the end."""
    try:
        _teardown_apex_megatron_cuda()
        yield
    except Exception as e:
        # TODO (@skothenhill) verify this is a problem and that this is a solution. Had issues with keyboard interrupts being ignored inside context manager.
        raise Exception from e
    finally:
        _teardown_apex_megatron_cuda()

distributed_model_parallel_state(seed=42, devices=1, tensor_model_parallel_size=1, pipeline_model_parallel_size=1, pipeline_model_parallel_split_rank=0, context_parallel_size=1, interactive=False)

Context manager for handling creating and cleaning up distributed model parallel state for tests.

Use like:

with distributed_model_parallel_state():
    ...  # your test code here

After the block your state is cleaned up.
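
For a test that needs non-default parallel settings, the sizes can be passed explicitly. A minimal sketch (the test name is hypothetical, and devices=2 with tensor_model_parallel_size=2 assumes two visible GPUs):

from bionemo.testing import megatron_parallel_state_utils

def my_tensor_parallel_test():
    with megatron_parallel_state_utils.distributed_model_parallel_state(
        seed=42,
        devices=2,
        tensor_model_parallel_size=2,
    ):
        ...  # test code that expects tensor-model-parallel state of size 2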

Source code in bionemo/testing/megatron_parallel_state_utils.py
@contextmanager
def distributed_model_parallel_state(
    seed: Optional[int] = 42,
    devices: int = 1,
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    pipeline_model_parallel_split_rank: int = 0,
    context_parallel_size: int = 1,
    interactive: bool = False,
) -> Iterator[None]:
    """Context manager for handling creating and cleaning up distributed model parallel state for tests.
    Use like:
    with distributed_model_parallel_state():
        # your test code here
    # After the block your state is cleaned up.
    """  # noqa: D205
    initial_states: Optional[Any] = None

    try:
        _teardown_apex_megatron_cuda()
        _initialize_distributed_parallel_state(
            devices=devices,
            tensor_model_parallel_size=tensor_model_parallel_size,
            pipeline_model_parallel_size=pipeline_model_parallel_size,
            pipeline_model_parallel_split_rank=pipeline_model_parallel_split_rank,
            context_parallel_size=context_parallel_size,
            interactive=interactive,
        )
        # Our goal is to set required state on entry, and then restore current state on exit for the RNGs.
        #  there are two possibilities that are handled below:
        # 1. If the RNG state is not initialized, we need to set it up and then
        #     unset it on exit to restore the current state. We track that this is the case when `initial_states` is `None`.
        # 2. If the RNG state is initialized, we need to track this state and reset it on exit to be what it was on entry.
        #    We track that this is the case when `initial_states` is not `None`.
        if tp_random.get_cuda_rng_tracker().is_initialized():
            initial_states = tp_random.get_cuda_rng_tracker().get_states()
        if seed is not None:
            # Set the seed if provided, this case is valid whether or not the RNG had state previously.
            #  on exit the RNG state will be restored to what it was on entry.
            tp_random.model_parallel_cuda_manual_seed(seed)
        else:
            # This is the case where the RNG state is not initialized and no seed was provided.
            #  We need to raise an error in this case, as we cannot restore the RNG state on exit and we need a seed
            #  to initialize the RNG state to. This only happens if the user overrides the default seed and sets it
            #  to None, and additionally if the RNG state was not initialized externally, as there is a default seed of 42.
            if initial_states is None:
                raise ValueError(
                    "You must provide a seed if the initial parallel state is unset. "
                    "Either provide a seed or leave the default seed (rather setting to None) "
                    "or initialize the RNG state externally."
                )
        yield
    finally:
        if initial_states is not None:
            tp_random.get_cuda_rng_tracker().set_states(initial_states)
        else:
            # Reset to the unset state
            tp_random.get_cuda_rng_tracker().reset()
        _teardown_apex_megatron_cuda()
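
The RNG handling above means that seeding inside the context does not leak outside of it. A minimal sketch of that behavior, assuming a CUDA device is available and that no RNG state was initialized before entering the block:

from megatron.core.tensor_parallel import random as tp_random

from bionemo.testing import megatron_parallel_state_utils

def my_rng_isolation_test():
    with megatron_parallel_state_utils.distributed_model_parallel_state(seed=123):
        # Inside the block the CUDA RNG tracker has been seeded with 123.
        assert tp_random.get_cuda_rng_tracker().is_initialized()
    # On exit the tracker is reset, because it was uninitialized on entry (assumption).
    assert not tp_random.get_cuda_rng_tracker().is_initialized()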

mock_distributed_parallel_state(world_size=8, rank=0, tensor_model_parallel_size=1, pipeline_model_parallel_size=1, virtual_pipeline_model_parallel_size=None, context_parallel_size=1, expert_model_parallel_size=1, seed=42)

A context manager that facilitates easy mocking of torch.distributed for an arbitrary GPU in a simulated cluster.

Key functions that are mocked:
  • torch.distributed.new_group when backend="gloo", which is not supported by the "fake" backend
  • torch.distributed.destroy_process_group when backend="gloo", since new "gloo" groups are never actually created
  • torch._C._cuda_setDevice, which changes the current device behind the scenes. Devices are assigned round-robin to support world_size > torch.cuda.device_count().

Outside of this mocking, a fake cluster is initialized using backend="fake" in torch.distributed. This sets up enough global state and environment for megatron to believe it is initializing a larger cluster in which the current process has a user-defined rank. You can then test megatron state for a hypothetical rank in an arbitrarily large world size.

Parameters:

  • world_size (int): The world size (cluster size). Defaults to 8.
  • rank (int): The GPU number globally in the cluster. Defaults to 0.
  • tensor_model_parallel_size (int): Tensor model parallel setting for megatron. Defaults to 1.
  • pipeline_model_parallel_size (int): Pipeline model parallel setting for megatron. Defaults to 1.
  • virtual_pipeline_model_parallel_size (Optional[int]): Virtual pipeline model parallel size for megatron. Defaults to None.
  • context_parallel_size (int): Context parallel size. Defaults to 1.
  • expert_model_parallel_size (int): Expert model parallel size. Defaults to 1.
  • seed (int | None): Seed for RNG state. Defaults to 42.
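
For example, you can pretend to be an arbitrary rank in a large cluster and check how megatron maps it into parallel groups. A minimal sketch (the test name is hypothetical, and the assertions assume megatron's default rank-to-group ordering):

from megatron.core import parallel_state

from bionemo.testing import megatron_parallel_state_utils

def my_mocked_rank_test():
    # Pretend to be global rank 5 in an 8-rank cluster with tensor parallelism of 2.
    with megatron_parallel_state_utils.mock_distributed_parallel_state(
        world_size=8,
        rank=5,
        tensor_model_parallel_size=2,
    ):
        assert parallel_state.get_tensor_model_parallel_world_size() == 2
        # With the default ordering, rank 5 falls in tensor-parallel group {4, 5},
        # so its tensor-parallel rank is 1 (assumption about the default rank layout).
        assert parallel_state.get_tensor_model_parallel_rank() == 1
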
Source code in bionemo/testing/megatron_parallel_state_utils.py
@contextmanager
def mock_distributed_parallel_state(
    world_size: int = 8,
    rank: int = 0,
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    virtual_pipeline_model_parallel_size: Optional[int] = None,
    context_parallel_size: int = 1,
    expert_model_parallel_size: int = 1,
    seed: int | None = 42,
):
    """A context manager that facilitates easy mocking of torch.distributed for an arbitrary GPU in a simulated cluster.

    Key functions that are mocked:
        * `torch.distributed.new_group` when `backend="gloo"` which doesn't support a `backend="fake"`
        * `torch.distributed.destroy_process_group` when `backend="gloo"` since new "gloo" groups are not actually made
        * `torch._C._cuda_setDevice` which changes the current device behind the scenes. We assign devices round-robin
            to support `world_size > torch.cuda.device_count()`.

    Outside of this mocking, a fake cluster is initialized using `backend="fake"` in `torch.distributed`. This sets up
        enough global state and environment for megatron to think that it is initializing a larger cluster with some
        settings where the current context has some user defined rank. You can then test the megatron state on a
        hypothetical rank in some large world size.

    Args:
        world_size: The world size (cluster size). Defaults to 8.
        rank: the GPU number globally in the cluster. Defaults to 0.
        tensor_model_parallel_size: tensor model parallel setting for megatron. Defaults to 1.
        pipeline_model_parallel_size: pipeline model parallel setting for megatron. Defaults to 1.
        virtual_pipeline_model_parallel_size: virtual pipeline model parallel size for megatron. Defaults to None.
        context_parallel_size: context parallel size. Defaults to 1.
        expert_model_parallel_size: expert model parallel size. Defaults to 1.
        seed: seed for RNG state. Defaults to 42.
    """
    # First set up mocks for torch.distributed state/info
    ori_device_count = torch.cuda.device_count()
    # Conditionally mock torch.distributed.new_group based on backend argument
    ori_dist_new_group = torch.distributed.new_group

    def mock_new_group(*args, **kwargs):
        if kwargs.get("backend") == "gloo":
            # Return a specific mock if backend is 'gloo'
            return MagicMock(name="gloo_group")
        else:
            # Return another mock or a different behavior for other backends
            return ori_dist_new_group(*args, **kwargs)

    ori_destroy_pg = torch.distributed.destroy_process_group

    def mock_destroy_gloo_group(pg=None):
        if isinstance(pg, MagicMock):
            return None
        ori_destroy_pg(pg)

    # The next mock is required to "set the device" to one that is greater than the number of actual GPUs
    #  the consequence of this mock is that the device is always dev 0
    ori_set_device = torch._C._cuda_setDevice

    def mock_set_device(device):
        if ori_device_count > 0:
            ori_set_device(device % ori_device_count)  # wrap around the request

    with (
        mock.patch("torch.distributed.new_group", side_effect=mock_new_group),
        mock.patch("torch.distributed.destroy_process_group", side_effect=mock_destroy_gloo_group),
        mock.patch("torch._C._cuda_setDevice", side_effect=mock_set_device),
    ):
        # Next set up state etc
        state_util = _MockMegatronParallelStateSingleton()  # static singleton class
        state_util.world_size = world_size
        state_util.rank = rank
        initial_states: Optional[Any] = None
        try:
            state_util.set_world_size(world_size=world_size, rank=rank)
            state_util.initialize_model_parallel(
                tensor_model_parallel_size=tensor_model_parallel_size,
                pipeline_model_parallel_size=pipeline_model_parallel_size,
                virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
                context_parallel_size=context_parallel_size,
                expert_model_parallel_size=expert_model_parallel_size,
            )
            # Our goal is to set required state on entry, and then restore current state on exit for the RNGs.
            #  there are two possibilities that are handled below:
            # 1. If the RNG state is not initialized, we need to set it up and then
            #     unset it on exit to restore the current state. We track that this is the case when `initial_states` is `None`.
            # 2. If the RNG state is initialized, we need to track this state and reset it on exit to be what it was on entry.
            #    We track that this is the case when `initial_states` is not `None`.
            if tp_random.get_cuda_rng_tracker().is_initialized():
                initial_states = tp_random.get_cuda_rng_tracker().get_states()
            if seed is not None:
                # Set the seed if provided, this case is valid whether or not the RNG had state previously.
                #  on exit the RNG state will be restored to what it was on entry.
                tp_random.model_parallel_cuda_manual_seed(seed)
            else:
                # This is the case where the RNG state is not initialized and no seed was provided.
                #  We need to raise an error in this case, as we cannot restore the RNG state on exit and we need a seed
                #  to initialize the RNG state to. This only happens if the user overrides the default seed and sets it
                #  to None, and additionally if the RNG state was not initialized externally, as there is a default seed of 42.
                if initial_states is None:
                    raise ValueError(
                        "You must provide a seed if the initial parallel state is unset. "
                        "Either provide a seed or leave the default seed (rather setting to None) "
                        "or initialize the RNG state externally."
                    )
            yield
        finally:
            if initial_states is not None:
                tp_random.get_cuda_rng_tracker().set_states(initial_states)
            else:
                # Reset to the unset state
                tp_random.get_cuda_rng_tracker().reset()
            state_util.destroy_model_parallel()