Source code for nvidia_resiliency_ext.inprocess.state

# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
import enum
import os
from typing import Optional


[docs] class Mode(enum.Enum): r''' Indicates operational mode of the current distributed rank. Attributes: INITIALIZED: the :py:class:`State` was initialized, :py:class:`RankAssignment` was not yet performed ACTIVE: the rank calls the wrapped function INACTIVE: the rank is waiting idle TERMINATED: the rank was terminated ''' INITIALIZED = enum.auto() ACTIVE = enum.auto() INACTIVE = enum.auto() TERMINATED = enum.auto()
[docs] @dataclasses.dataclass class State: r''' Represents the current state of the :py:class:`inprocess.Wrapper`. Args: rank: a distributed rank index as seen by the :py:class:`inprocess.Wrapper`, :py:obj:`None` for terminated ranks world_size: a total number of distributed ranks controlled by the :py:class:`inprocess.Wrapper` active_rank: a distributed rank index passed to the wrapped function active_world_size: a total number of distributed ranks passed to the wrapped function initial_rank: an distributed rank index, captured when the :py:class:`Wrapper` was invoked initial_world_size: a total number of initial distributed ranks iteration: index of the current restart iteration mode: operational mode fn_exception: an instance of :py:exc:`Exception` raised by the wrapped function in the current restart iteration, :py:obj:`None` if no exception was raised ''' rank: Optional[int] world_size: int active_rank: Optional[int] = None active_world_size: Optional[int] = None initial_rank: Optional[int] = None initial_world_size: Optional[int] = None iteration: int = 0 mode: Mode = Mode.INITIALIZED fn_exception: Optional[Exception] = None def __post_init__(self): if self.initial_rank is None: self.initial_rank = self.rank if self.initial_world_size is None: self.initial_world_size = self.world_size @classmethod def from_env(cls): rank = int(os.getenv('RANK', 0)) world_size = int(os.getenv('WORLD_SIZE', 1)) return cls( rank=rank, world_size=world_size, initial_rank=rank, initial_world_size=world_size, ) def set_distributed_vars(self): os.environ['RANK'] = str(self.active_rank) os.environ['WORLD_SIZE'] = str(self.active_world_size) def advance(self): self.iteration += 1 self.fn_exception = None def freeze(self): frozen = FrozenState(**dataclasses.asdict(self)) return frozen def copy_from(self, other, fields=None): for field in dataclasses.fields(self): if fields is None or field.name in fields: setattr(self, field.name, getattr(other, field.name))
def freeze_dataclass(cls): fields = [(f.name, f.type, f) for f in dataclasses.fields(cls)] FrozenClass = dataclasses.make_dataclass( cls_name=f'Frozen{cls.__name__}', fields=fields, frozen=True ) return FrozenClass FrozenState = freeze_dataclass(State) FrozenState.__doc__ = r''' :py:class:`inprocess.FrozenState` is identical to :py:class:`inprocess.State`, except all fields are read-only. '''