Source code for flashdreams.infra.config.base

# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Config primitives: ``InstantiateConfig`` (``_target`` + ``setup``) and ``derive_config`` patching."""

import copy
from dataclasses import dataclass
from typing import Any, TypeVar


[docs] class PrintableConfig: """Config base class providing a multi-line ``__str__`` for human-readable dumps.""" def __str__(self): lines = [self.__class__.__name__ + ":"] for key, val in vars(self).items(): if isinstance(val, tuple): flattened_val = "[" for item in val: flattened_val += str(item) + "\n" flattened_val = flattened_val.rstrip("\n") val = flattened_val + "]" lines += f"{key}: {str(val)}".split("\n") return "\n ".join(lines)
[docs] @dataclass class InstantiateConfig(PrintableConfig): """Config carrying a ``_target`` class plus its kwargs, instantiable via ``setup``.""" _target: type
[docs] def setup(self, **kwargs: Any) -> Any: """Instantiate the configured object.""" return self._target(self, **kwargs)
T = TypeVar("T")
[docs] def derive_config(base_config: T, **changes: Any) -> T: """Deep-copy a base config and apply nested keyword overrides. Nested ``dict`` values walk into both dataclass attributes and nested dicts; leaf values overwrite directly. Raises ``KeyError`` on unknown paths. Example: .. code-block:: python new_config = derive_config( base_config, tokenizer=WanVAEInterfaceConfig(checkpoint_path=...), dit=dict(len_t=3, checkpoint_path=...), ) """ def _is_patchable_object(x: Any) -> bool: # Object is patchable if it has attribute storage (__dict__). return hasattr(x, "__dict__") def _get_field(target: Any, key: str, path: str) -> Any: if isinstance(target, dict): if key not in target: raise KeyError(f"Unknown key at {path}: {key}") return target[key] if hasattr(target, key): return getattr(target, key) raise KeyError(f"Unknown field at {path}: {type(target).__name__}.{key}") def _set_field(target: Any, key: str, value: Any, path: str) -> None: if isinstance(target, dict): if key not in target: raise KeyError(f"Unknown key at {path}: {key}") target[key] = value return if hasattr(target, key): setattr(target, key, value) return raise KeyError(f"Unknown field at {path}: {type(target).__name__}.{key}") def _recursive_patch( target: Any, patch: dict[str, Any], path: str = "root" ) -> None: # Apply patch recursively to dict/object target. for key, value in patch.items(): current = _get_field(target, key, path) current_path = f"{path}.{key}" if isinstance(value, dict): # Nested patch: current can be dict or object. if isinstance(current, dict) or _is_patchable_object(current): _recursive_patch(current, value, current_path) else: raise TypeError( f"Cannot apply nested dict patch to non-nested field at {current_path} " f"(type={type(current).__name__})" ) else: # Leaf patch: direct assignment. _set_field(target, key, value, path) # Deep-copy base config so original config is never mutated. cfg = copy.deepcopy(base_config) _recursive_patch(cfg, changes, path=type(cfg).__name__) return cfg