
Megatron parallel state utils

This package contains utilities for managing the state of distributed model parallelism in Megatron and Apex.

In general, you should just use the distributed_model_parallel_state context manager to manage the state for your test; it handles setup and teardown of the distributed model parallel state for you.

Example usage:

from bionemo.testing import megatron_parallel_state_utils

def my_test():
    with megatron_parallel_state_utils.distributed_model_parallel_state():
        ...  # your test code that requires megatron/apex parallel state to be set up here
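
A slightly fuller sketch of a test, assuming pytest-style test discovery, a visible CUDA device, and the default single-rank NCCL setup; the test name and assertions here are illustrative, not part of the package:

import torch
from megatron.core import parallel_state

from bionemo.testing import megatron_parallel_state_utils

def test_parallel_state_is_set_up_and_torn_down():
    with megatron_parallel_state_utils.distributed_model_parallel_state(seed=42):
        # Inside the context, torch.distributed and Megatron's model parallel state are live.
        assert torch.distributed.is_initialized()
        assert parallel_state.get_tensor_model_parallel_world_size() == 1
    # On exit, the context manager tears everything down again.
    assert not torch.distributed.is_initialized()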

_MockMegatronParallelStateSingleton

Source code in bionemo/testing/megatron_parallel_state_utils.py
class _MockMegatronParallelStateSingleton:
    _instance = None

    def __init__(
        self,
        world_size=torch.cuda.device_count(),
        rank=int(os.getenv("LOCAL_RANK", 0)),
        inited=False,
        store=FakeStore(),
    ):
        """A singleton to deal with global megatron state for simulating a fake cluster.

        Args:
            world_size: the cluster size. Defaults to torch.cuda.device_count().
            rank: rank of this node. Defaults to int(os.getenv("LOCAL_RANK", 0)).
            inited: if this global cluster has been initiated. Defaults to False.
            store: the FakeStore for process groups. Defaults to FakeStore().
        """
        self.world_size = world_size
        self.rank = rank
        self.inited = inited
        # Fake store idea: see https://github.com/pytorch/pytorch/blob/main/test/distributed/test_fake_pg.py
        self.store = store

    def __new__(cls):
        # Makes this a singleton
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def initialize_distributed(self):
        torch.cuda.set_device(self.rank % self.world_size)
        # Fake store idea: see https://github.com/pytorch/pytorch/blob/main/test/distributed/test_fake_pg.py
        torch.distributed.init_process_group(
            backend="fake",
            world_size=self.world_size,
            rank=self.rank,
            store=self.store,
        )
        self.inited = True

    def set_world_size(self, world_size=None, rank=None):
        self.world_size = torch.cuda.device_count() if world_size is None else world_size
        if torch.distributed.is_initialized() and self.world_size != torch.distributed.get_world_size():
            torch.distributed.destroy_process_group()

        if rank is None:
            self.rank = int(os.environ.get("LOCAL_RANK", 0))
            if self.rank >= self.world_size:
                self.rank = -1
        else:
            self.rank = rank

    def destroy_model_parallel(self):
        if not self.inited:
            return
        # torch.distributed.barrier()
        parallel_state.destroy_model_parallel()
        self.inited = False
        torch.distributed.destroy_process_group()

    def initialize_model_parallel(
        self,
        tensor_model_parallel_size=1,
        pipeline_model_parallel_size=1,
        virtual_pipeline_model_parallel_size=None,
        **kwargs,
    ):
        parallel_state.destroy_model_parallel()
        self.initialize_distributed()
        parallel_state.initialize_model_parallel(
            tensor_model_parallel_size,
            pipeline_model_parallel_size,
            virtual_pipeline_model_parallel_size,
            **kwargs,
        )
        self.inited = True

__init__(world_size=torch.cuda.device_count(), rank=int(os.getenv('LOCAL_RANK', 0)), inited=False, store=FakeStore())

A singleton to deal with global megatron state for simulating a fake cluster.

Parameters:

    world_size: the cluster size. Default: torch.cuda.device_count()
    rank: rank of this node. Default: int(os.getenv("LOCAL_RANK", 0))
    inited: whether this global cluster has already been initialized. Default: False
    store: the FakeStore for process groups. Default: FakeStore()
Source code in bionemo/testing/megatron_parallel_state_utils.py
def __init__(
    self,
    world_size=torch.cuda.device_count(),
    rank=int(os.getenv("LOCAL_RANK", 0)),
    inited=False,
    store=FakeStore(),
):
    """A singleton to deal with global megatron state for simulating a fake cluster.

    Args:
        world_size: the cluster size. Defaults to torch.cuda.device_count().
        rank: rank of this node. Defaults to int(os.getenv("LOCAL_RANK", 0)).
        inited: if this global cluster has been initiated. Defaults to False.
        store: the FakeStore for process groups. Defaults to FakeStore().
    """
    self.world_size = world_size
    self.rank = rank
    self.inited = inited
    # Fake store idea: see https://github.com/pytorch/pytorch/blob/main/test/distributed/test_fake_pg.py
    self.store = store

_reset_microbatch_calculator()

Resets _GLOBAL_NUM_MICROBATCHES_CALCULATOR in Megatron, which NeMo uses to initialize model parallelism in nemo.collections.nlp.modules.common.megatron.megatron_init.initialize_model_parallel_for_nemo.

Source code in bionemo/testing/megatron_parallel_state_utils.py
def _reset_microbatch_calculator():
    """Resets _GLOBAL_NUM_MICROBATCHES_CALCULATOR in megatron which is used in NeMo to initilised model parallel in
    nemo.collections.nlp.modules.common.megatron.megatron_init.initialize_model_parallel_for_nemo
    """  # noqa: D205, D415
    megatron.core.num_microbatches_calculator._GLOBAL_NUM_MICROBATCHES_CALCULATOR = None

clean_up_distributed_and_parallel_states()

Clean up parallel states, torch.distributed and torch cuda cache.

Source code in bionemo/testing/megatron_parallel_state_utils.py
def clean_up_distributed_and_parallel_states():
    """Clean up parallel states, torch.distributed and torch cuda cache."""
    _reset_microbatch_calculator()
    parallel_state.destroy_model_parallel()  # destroy parallel state before distributed
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
    torch.cuda.empty_cache()
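
One way to use this helper, sketched below with an illustrative (hypothetical) fixture name, is as an autouse pytest fixture so each test starts and ends with clean global state:

import pytest

from bionemo.testing.megatron_parallel_state_utils import clean_up_distributed_and_parallel_states

@pytest.fixture(autouse=True)
def _clean_megatron_state():
    # Run the test, then tear down any distributed/parallel state it left behind.
    yield
    clean_up_distributed_and_parallel_states()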

distributed_model_parallel_state(seed=42, rank=0, world_size=1, backend='nccl', **initialize_model_parallel_kwargs)

Context manager for torch distributed and parallel state testing.

Parameters:

    seed (int): random seed passed into tensor_parallel.random
        (https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/random.py). Default: 42
    rank (int): global rank of the current cuda device. Default: 0
    world_size (int): world size, i.e. the number of devices. Default: 1
    backend (str): backend passed to torch.distributed.init_process_group. Default: 'nccl'
    **initialize_model_parallel_kwargs: kwargs passed into initialize_model_parallel
        (https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py). Default: {}
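
A minimal usage sketch, assuming a single visible GPU (so all parallel sizes stay at 1); extra keyword arguments are forwarded to initialize_model_parallel, and the tensor-parallel RNG tracker is seeded by the context manager:

from megatron.core import parallel_state
from megatron.core.tensor_parallel import random as tp_random

from bionemo.testing.megatron_parallel_state_utils import distributed_model_parallel_state

def test_seeded_single_rank_state():
    # Kwargs beyond `backend` are forwarded to megatron.core.parallel_state.initialize_model_parallel;
    # with a single rank, all parallel sizes must remain 1.
    with distributed_model_parallel_state(seed=1234, pipeline_model_parallel_size=1):
        assert parallel_state.get_pipeline_model_parallel_world_size() == 1
        # The tensor-parallel RNG tracker was seeded on entry.
        assert tp_random.get_cuda_rng_tracker().is_initialized()
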
Source code in bionemo/testing/megatron_parallel_state_utils.py
@contextmanager
def distributed_model_parallel_state(
    seed: int = 42,
    rank: int = 0,
    world_size: int = 1,
    backend: str = "nccl",
    **initialize_model_parallel_kwargs,
):
    """Context manager for torch distributed and parallel state testing.

    Args:
        seed (int): random seed to be passed into tensor_parallel.random (https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/random.py). default to 42.
        rank (int): global rank of the current cuda device. default to 0.
        world_size (int): world size or number of devices. default to 1.
        backend (str): backend to torch.distributed.init_process_group. default to 'nccl'.
        **initialize_model_parallel_kwargs: kwargs to be passed into initialize_model_parallel (https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py).
    """
    with MonkeyPatch.context() as context:
        initial_states = None
        try:
            clean_up_distributed_and_parallel_states()

            # distributed and parallel state set up
            if not os.environ.get("MASTER_ADDR", None):
                context.setenv("MASTER_ADDR", DEFAULT_MASTER_ADDR)
            if not os.environ.get("MASTER_PORT", None):
                context.setenv("MASTER_PORT", DEFAULT_MASTER_PORT)
            if not os.environ.get("NCCL_TIMEOUT", None):
                context.setenv("NCCL_TIMEOUT", DEFAULT_NCCL_TIMEOUT)
            context.setenv("RANK", str(rank))

            torch.distributed.init_process_group(backend=backend, world_size=world_size)
            parallel_state.initialize_model_parallel(**initialize_model_parallel_kwargs)

            # tensor parallel random seed set up
            # do not call torch.cuda.manual_seed after so!
            if tp_random.get_cuda_rng_tracker().is_initialized():
                initial_states = tp_random.get_cuda_rng_tracker().get_states()
            if seed is not None:
                tp_random.model_parallel_cuda_manual_seed(seed)

            yield
        finally:
            # restore/unset tensor parallel random seed
            if initial_states is not None:
                tp_random.get_cuda_rng_tracker().set_states(initial_states)
            else:
                # Reset to the unset state
                tp_random.get_cuda_rng_tracker().reset()

            clean_up_distributed_and_parallel_states()

mock_distributed_parallel_state(world_size=8, rank=0, tensor_model_parallel_size=1, pipeline_model_parallel_size=1, virtual_pipeline_model_parallel_size=None, context_parallel_size=1, expert_model_parallel_size=1, seed=42)

A context manager that facilitates easy mocking of torch.distributed for an arbitrary GPU in a simulated cluster.

Key functions that are mocked:
  • torch.distributed.new_group when backend="gloo", since the "gloo" backend does not support backend="fake"
  • torch.distributed.destroy_process_group when backend="gloo", since no real "gloo" groups are actually created
  • torch._C._cuda_setDevice, which changes the current device behind the scenes. Devices are assigned round-robin to support world_size > torch.cuda.device_count().

Outside of this mocking, a fake cluster is initialized using backend="fake" in torch.distributed. This sets up enough global state and environment for Megatron to believe it is initializing a larger cluster in which the current context holds a user-defined rank. You can then test Megatron state for a hypothetical rank in an arbitrarily large world size.

Parameters:

    world_size (int): the world size (cluster size). Default: 8
    rank (int): the global rank of this GPU in the cluster. Default: 0
    tensor_model_parallel_size (int): tensor model parallel setting for megatron. Default: 1
    pipeline_model_parallel_size (int): pipeline model parallel setting for megatron. Default: 1
    virtual_pipeline_model_parallel_size (Optional[int]): virtual pipeline model parallel size for megatron. Default: None
    context_parallel_size (int): context parallel size. Default: 1
    expert_model_parallel_size (int): expert model parallel size. Default: 1
    seed (int | None): seed for the RNG state. Default: 42
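
A sketch of checking rank placement on a simulated 8-GPU cluster with 2-way tensor and 2-way pipeline parallelism, without needing 8 physical GPUs. The expected tensor-parallel rank below assumes Megatron's default rank ordering, in which consecutive global ranks form a tensor-parallel group:

from megatron.core import parallel_state

from bionemo.testing.megatron_parallel_state_utils import mock_distributed_parallel_state

def test_rank_placement_in_fake_cluster():
    # Pretend to be global rank 3 of an 8-process job with TP=2 and PP=2.
    with mock_distributed_parallel_state(
        world_size=8,
        rank=3,
        tensor_model_parallel_size=2,
        pipeline_model_parallel_size=2,
    ):
        assert parallel_state.get_tensor_model_parallel_world_size() == 2
        assert parallel_state.get_pipeline_model_parallel_world_size() == 2
        # Ranks [2, 3] form a tensor-parallel group, so global rank 3 should be TP rank 1.
        assert parallel_state.get_tensor_model_parallel_rank() == 1
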
Source code in bionemo/testing/megatron_parallel_state_utils.py
@contextmanager
def mock_distributed_parallel_state(
    world_size: int = 8,
    rank: int = 0,
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    virtual_pipeline_model_parallel_size: Optional[int] = None,
    context_parallel_size: int = 1,
    expert_model_parallel_size: int = 1,
    seed: int | None = 42,
):
    """A context manager that facilitates easy mocking of torch.distributed for an arbitrary GPU in a simulated cluster.

    Key functions that are mocked:
        * `torch.distributed.new_group` when `backend="gloo"` which doesn't support a `backend="fake"`
        * `torch.distributed.destroy_process_group` when `backend="gloo"` since new "gloo" groups are not actually made
        * `torch._C._cuda_setDevice` which changes the current device behind the scenes. We assign devices round-robin
            to support `world_size > torch.cuda.device_count()`.

    Outside of this mocking, a fake cluster is initialized using `backend="fake"` in `torch.distributed`. This sets up
        enough global state and environment for megatron to think that it is initializing a larger cluster with some
        settings where the current context has some user defined rank. You can then test the megatron state on a
        hypothetical rank in some large world size.

    Args:
        world_size: The world size (cluster size). Defaults to 8.
        rank: the GPU number globally in the cluster. Defaults to 0.
        tensor_model_parallel_size: tensor model parallel setting for megatron. Defaults to 1.
        pipeline_model_parallel_size: pipeline model parallel setting for megatron. Defaults to 1.
        virtual_pipeline_model_parallel_size: virtual pipeline model parallel size for megatron. Defaults to None.
        context_parallel_size: context parallel size. Defaults to 1.
        expert_model_parallel_size: expert model parallel size. Defaults to 1.
        seed: seed for RNG state. Defaults to 42.
    """
    # First set up mocks for torch.distributed state/info
    ori_device_count = torch.cuda.device_count()
    # Conditionally mock torch.distributed.new_group based on backend argument
    ori_dist_new_group = torch.distributed.new_group

    def mock_new_group(*args, **kwargs):
        if kwargs.get("backend") == "gloo":
            # Return a specific mock if backend is 'gloo'
            return MagicMock(name="gloo_group")
        else:
            # Return another mock or a different behavior for other backends
            return ori_dist_new_group(*args, **kwargs)

    ori_destroy_pg = torch.distributed.destroy_process_group

    def mock_destroy_gloo_group(pg=None):
        if isinstance(pg, MagicMock):
            return None
        ori_destroy_pg(pg)

    # The next mock is required to "set the device" to one that is greater than the number of actual GPUs
    #  the consequence of this mock is that the device is always dev 0
    ori_set_device = torch._C._cuda_setDevice

    def mock_set_device(device):
        if ori_device_count > 0:
            ori_set_device(device % ori_device_count)  # wrap around the request

    with (
        mock.patch("torch.distributed.new_group", side_effect=mock_new_group),
        mock.patch("torch.distributed.destroy_process_group", side_effect=mock_destroy_gloo_group),
        mock.patch("torch._C._cuda_setDevice", side_effect=mock_set_device),
    ):
        # Next set up state etc
        state_util = _MockMegatronParallelStateSingleton()  # static singleton class
        state_util.world_size = world_size
        state_util.rank = rank
        initial_states: Optional[Any] = None
        try:
            state_util.set_world_size(world_size=world_size, rank=rank)
            state_util.initialize_model_parallel(
                tensor_model_parallel_size=tensor_model_parallel_size,
                pipeline_model_parallel_size=pipeline_model_parallel_size,
                virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
                context_parallel_size=context_parallel_size,
                expert_model_parallel_size=expert_model_parallel_size,
            )
            # Our goal is to set required state on entry, and then restore current state on exit for the RNGs.
            #  there are two possibilities that are handled below:
            # 1. If the RNG state is not initialized, we need to set it up and then
            #     unset it on exit to restore the current state. We track that this is the case when `initial_states` is `None`.
            # 2. If the RNG state is initialized, we need to track this state and reset it on exit to be what it was on entry.
            #    We track that this is the case when `initial_states` is not `None`.
            if tp_random.get_cuda_rng_tracker().is_initialized():
                initial_states = tp_random.get_cuda_rng_tracker().get_states()
            if seed is not None:
                # Set the seed if provided, this case is valid whether or not the RNG had state previously.
                #  on exit the RNG state will be restored to what it was on entry.
                tp_random.model_parallel_cuda_manual_seed(seed)
            else:
                # This is the case where the RNG state is not initialized and no seed was provided.
                #  We need to raise an error in this case, as we cannot restore the RNG state on exit and we need a seed
                #  to initialize the RNG state to. This only happens if the user overrides the default seed and sets it
                #  to None, and additionally if the RNG state was not initialized externally, as there is a default seed of 42.
                if initial_states is None:
                    raise ValueError(
                        "You must provide a seed if the initial parallel state is unset. "
                        "Either provide a seed or leave the default seed (rather setting to None) "
                        "or initialize the RNG state externally."
                    )
            yield
        finally:
            if initial_states is not None:
                tp_random.get_cuda_rng_tracker().set_states(initial_states)
            else:
                # Reset to the unset state
                tp_random.get_cuda_rng_tracker().reset()
            state_util.destroy_model_parallel()