# Source code for megatron.energon.worker

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import hashlib
import json
import multiprocessing
from typing import Any, ClassVar, Dict, List, Optional, TextIO, TypeVar

import torch.distributed
import torch.utils.data

from megatron.energon.dataclass_slots import dataclass_slots
from megatron.energon.epathlib import EPath

# Public API of this module.
__all__ = ("WorkerConfig",)

# Generic type variable; presumably used by typed helpers elsewhere in this module —
# not referenced in the visible code.
T = TypeVar("T")


@dataclass_slots(eq=False)
class WorkerConfig:
    """
    Provides information about the current worker and the global configuration. This gives
    each data parallel rank its proper config. Every `rank` (up to `world_size-1`) must be
    used. If set wrong, the datasets might yield the same data or data might be missing, as
    data is split over the data parallel ranks with this config! You may set the same rank,
    if you need multiple ranks to retrieve the same data.
    """

    #: The data parallel rank/id of the current process.
    rank: int
    #: The total number of data parallel processes.
    world_size: int
    #: The number of workers per rank. May be 0 to disable worker processes.
    num_workers: int
    #: If not using all ranks for data parallel, set this to the corresponding group.
    data_parallel_group: Optional[torch.distributed.ProcessGroup] = None

    #: The id offset of the current worker. e.g. the worker may live as `worker_info.id=0`,
    # but actually yield samples for id=1 (i.e. worker_id_offset=1). Required to support
    # restoring the worker state if last emitted sample was not for worker_id=0. Required by
    # SavableDataLoader to restore the worker state. Is only set to nonzero within a worker
    # process.
    worker_id_offset: ClassVar[int] = 0

    #: The following seed_offset is used at two points in the code.
    # 1. The seed_offset in the worker_config that is passed to the dataset initialization,
    #    is used to set the seed for the dataset shuffling and shuffled blending (all code
    #    that uses WorkerRng).
    # 2. The worker_config passed to the data loader initialization, is used to set the seed
    #    for the torch, numpy and random libraries. This does not affect the dataset
    #    shuffling, but only the user code (e.g. code in TaskEncoder).
    seed_offset: int = 0

    #: The path to the debug file for the current worker. Should contain "{worker_id}" and
    # "{pid}" to separate the workers.
    worker_debug_path: Optional[str] = None
    #: Log level for worker logging.
    worker_log_level: int = 0

    #: The opened file for the current worker. Should not be set from outside.
    _worker_debug_file: Optional[TextIO] = None
    #: worker_id of the opened worker debug file.
    _worker_debug_file_worker_id: Optional[int] = None

    #: The stack of sample indices within the current iterating worker.
    _sample_index_stack: ClassVar[Optional[List[int]]] = None
    #: The current worker config within the current iterating worker.
    active_worker_config: ClassVar[Optional["WorkerConfig"]] = None
    #: The global rank override for the worker. Required for restoring samples.
    # NOTE: annotation corrected from Optional[List[int]] — it is assigned an Optional[int]
    # in worker_activate and used as an int in rank_worker_id/global_worker_id.
    _worker_override_global_rank: ClassVar[Optional[int]] = None
[docs] def worker_activate(self, sample_index: int, override_global_rank: Optional[int] = None): """Activates the worker config for the current worker and sets it as actively iterating. Must be called before next() call on the datasets.""" assert WorkerConfig.active_worker_config is None WorkerConfig._sample_index_stack = [sample_index] WorkerConfig.active_worker_config = self WorkerConfig._worker_override_global_rank = override_global_rank
[docs] def worker_push_sample_index(self, sample_index: int): """Pushes a new sample index to the sample index stack. Should be set by wrapping datasets before calling inners.""" assert WorkerConfig.active_worker_config is not None WorkerConfig._sample_index_stack.append(sample_index)
    def worker_pop_sample_index(self):
        """Pops and returns the top sample index from the sample index stack. Counterpart
        to :meth:`worker_push_sample_index`; called by wrapping datasets after the inner
        dataset call returns. (Original docstring was a copy-paste of the push method.)"""
        assert WorkerConfig.active_worker_config is not None
        return WorkerConfig._sample_index_stack.pop()
[docs] def worker_deactivate(self): """Deactivates the worker config for the current worker and deactivates it for iterating. Must be called after next() call on the datasets.""" if WorkerConfig.active_worker_config is not None: assert len(WorkerConfig._sample_index_stack) == 1, ( f"Sample index stack not empty: {WorkerConfig._sample_index_stack}" ) WorkerConfig._sample_index_stack = None WorkerConfig.active_worker_config = None WorkerConfig._worker_override_global_rank = None
@property def active_worker_sample_index(self) -> int: """Returns the current sample index for the actively iterating worker.""" # Internal sample index is for the local worker. If using multiple workers per rank, this # must be multiplied by the number of workers and offset by the local worker index. return ( WorkerConfig._sample_index_stack[-1] * max(self.num_workers, 1) + self.rank_worker_id() ) @property def active_worker_batch_index(self) -> int: """Returns the current batch index for the actively iterating worker.""" # Internal batch index is for the local worker. If using multiple workers per rank, this # must be multiplied by the number of workers and offset by the local worker index. return ( WorkerConfig._sample_index_stack[0] * max(self.num_workers, 1) + self.rank_worker_id() )
[docs] def global_rank(self) -> int: """Returns the global rank of this worker config but as a global rank, not as a rank within the data parallel group.""" if self.data_parallel_group is None: return self.rank return torch.distributed.get_global_rank(self.data_parallel_group, self.rank)
def __eq__(self, other): """Do not compare everything to check for equal config""" if not isinstance(other, WorkerConfig): return NotImplementedError() return all( [ self.rank == other.rank, self.world_size == other.world_size, self.num_workers == other.num_workers, ] )
[docs] @staticmethod def default_worker_config( num_workers: int = 4, data_parallel_group: Optional[torch.distributed.ProcessGroup] = None ) -> "WorkerConfig": """Returns the default worker config using torch distributed if available. If torch distributed is not available, a single local rank is assumed.""" if torch.distributed.is_available() and torch.distributed.is_initialized(): rank = torch.distributed.get_rank(data_parallel_group) world_size = torch.distributed.get_world_size(data_parallel_group) else: rank = 0 world_size = 1 return WorkerConfig( rank=rank, world_size=world_size, num_workers=num_workers, data_parallel_group=data_parallel_group, )
[docs] def rank_worker_id(self) -> int: """Returns the self worker id within the current rank.""" if self._worker_override_global_rank: assert self.worker_id_offset == 0 return self._worker_override_global_rank % self.num_workers worker_info = torch.utils.data.get_worker_info() if worker_info is None: return self.worker_id_offset assert worker_info.num_workers == self.num_workers return ( worker_info.id + worker_info.num_workers - self.worker_id_offset ) % worker_info.num_workers
[docs] def assert_worker(self): """Checks if the current process is a worker (if configured so), and that the workers are properly configured.""" if self.num_workers <= 1: assert self.rank_worker_id() == 0 else: worker_info = torch.utils.data.get_worker_info() assert worker_info is not None, "Cannot iterate out of worker context" assert worker_info.num_workers == self.num_workers, ( f"Actual number of workers for this rank ({worker_info.num_workers}) does not " f"match the configured number of workers ({self.num_workers})" )
[docs] def global_worker_id(self, override_local_worker_id: Optional[int] = None) -> int: """Returns the global worker index by multiplying the rank with the number of workers. Alternatively, you can override the local worker id. Args: override_local_worker_id (int, optional): The local worker id to override. None means the current worker, which is the default. """ if self._worker_override_global_rank is not None: assert override_local_worker_id is None return self._worker_override_global_rank if override_local_worker_id is not None: return self.rank * self.num_workers + override_local_worker_id else: self.assert_worker() return self.rank * self.num_workers + self.rank_worker_id()
[docs] def worker_seed(self, override_local_worker_id: Optional[int] = None) -> int: """Returns the seed for the current worker (or a specified worker). Base on the current worker id and the seed offset, compute a seed. Alternatively, you can override the local worker id with a fixed one to pregenerate seeds for multiple workers. Args: override_local_worker_id (int, optional): The local worker id to override. None means the current worker, which is the default. """ if self.num_workers == 0: # If we are not using workers, different ranks should still get a different seed global_worker_id = self.rank else: global_worker_id = self.global_worker_id(override_local_worker_id) seed_offset = self.seed_offset seed_hash = hashlib.sha1(f"{global_worker_id},{seed_offset}".encode("utf-8")).digest() return int.from_bytes(seed_hash, byteorder="big", signed=False) & 0xFFFFFFFF
[docs] def config(self) -> Dict[str, Any]: return { "rank": self.rank, "world_size": self.world_size, "num_workers": self.num_workers, "data_parallel_group": ( self.data_parallel_group.size() if self.data_parallel_group else None ), }
[docs] def should_log(self, level: int) -> bool: return level <= self.worker_log_level
    def worker_log(self, data: dict) -> None:
        """Logs the given data to the worker debug file.

        If no `worker_debug_path` is configured, the data is printed to stdout as one JSON
        line instead. Otherwise, the debug file is opened lazily and cached on the instance;
        it is re-opened whenever the computed worker id changes (e.g. main process vs.
        worker process).
        """
        if self.worker_debug_path is None:
            print(json.dumps(data) + "\n", end="", flush=True)
        else:
            in_worker = torch.utils.data.get_worker_info() is not None
            # Additional "worker" with rank_worker_id=0 is the main process. All workers
            # have +1 as their worker_id.
            worker_id = (
                self.rank * (self.num_workers + 1)
                + self.rank_worker_id()
                + (1 if in_worker else 0)
            )
            if self._worker_debug_file is None or self._worker_debug_file_worker_id != worker_id:
                # First write, or the effective worker id changed: close the stale file
                # (if any) and open the one for the current worker id / pid.
                if self._worker_debug_file is not None:
                    self._worker_debug_file.close()
                path = EPath(
                    self.worker_debug_path.format(
                        worker_id=worker_id, pid=multiprocessing.current_process().ident
                    )
                )
                path.parent.mkdir(exist_ok=True, parents=True)
                self._worker_debug_file = path.open("w")
                self._worker_debug_file_worker_id = worker_id
            # One JSON object per line; flush so logs survive worker crashes.
            self._worker_debug_file.write(json.dumps(data) + "\n")
            self._worker_debug_file.flush()