Source code for multistorageclient.contrib.hydra

  1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  2# SPDX-License-Identifier: Apache-2.0
  3#
  4# Licensed under the Apache License, Version 2.0 (the "License");
  5# you may not use this file except in compliance with the License.
  6# You may obtain a copy of the License at
  7#
  8# http://www.apache.org/licenses/LICENSE-2.0
  9#
 10# Unless required by applicable law or agreed to in writing, software
 11# distributed under the License is distributed on an "AS IS" BASIS,
 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13# See the License for the specific language governing permissions and
 14# limitations under the License.
 15
 16"""
 17Hydra ConfigSource plugin for Multi-Storage Client.
 18
 19This module provides a ConfigSource implementation that allows Hydra to load
 20configuration files from remote storage systems using Multi-Storage Client.
 21"""
 22
 23import logging
 24from typing import List, Optional
 25
 26from hydra.core.config_search_path import ConfigSearchPath, SearchPathElement
 27from hydra.core.object_type import ObjectType
 28from hydra.plugins.config_source import ConfigLoadError, ConfigResult, ConfigSource
 29from hydra.plugins.search_path_plugin import SearchPathPlugin
 30from omegaconf import OmegaConf
 31
 32import multistorageclient as msc
 33from multistorageclient.shortcuts import resolve_storage_client
 34from multistorageclient.types import MSC_PROTOCOL, MSC_PROTOCOL_NAME
 35from multistorageclient.utils import join_paths
 36
 37logger = logging.getLogger(__name__)
 38
 39
class MSCConfigSource(ConfigSource):
    """
    A Hydra :py:class:`hydra.plugins.config_source.ConfigSource` that uses Multi-Storage Client to read configuration files from remote storage systems.

    Supports loading configs from S3, GCS, Azure Blob Storage, and other MSC-supported storage backends.
    Must be used in conjunction with :py:class:`MSCSearchPathPlugin`.
    """

    def __init__(self, provider: str, path: str) -> None:
        """
        Initialize the MSC ConfigSource.

        :param provider: The provider name (should be ``main`` if the user specifies config path, or ``msc-universal`` if the user doesn't specify a config path).
        :param path: The base path for this source. It should be a full MSC URL (e.g., ``msc://dev/configs``) if the user specifies config path, or ``msc://`` if the user doesn't specify a config path and the universal MSC source is used.
        """
        # A path without any scheme is treated as an MSC path.
        if "://" not in path:
            path = f"{MSC_PROTOCOL}{path}"
        super().__init__(provider=provider, path=path)

        # Store the base URL for resolving relative paths
        self.base_url = path
        logger.debug(f"Initialized MSCConfigSource with base URL: {self.base_url}")

    @staticmethod
    def scheme() -> str:
        """
        Return the URL scheme for this ConfigSource.

        :return: The scheme string 'msc'.
        """
        return MSC_PROTOCOL_NAME

    def _resolve_full_url(self, config_path: str) -> str:
        """
        Convert config_path to a full ``msc://`` URL.

        :param config_path: Either a relative path, config group reference (e.g., ``database: postgres``), or full ``msc://`` URL.
        :return: Full ``msc://`` URL that can be passed to ``multistorageclient.open()``.
        """
        # Already a full MSC URL: use it verbatim.
        if config_path.startswith(MSC_PROTOCOL):
            return config_path

        # Handle Hydra defaults syntax: "group: config" -> "group/config"
        if ": " in config_path:
            group, config_name = config_path.split(": ", 1)
            config_path = f"{group}/{config_name}"

        # Relative path - join with base URL using MSC's utility
        return join_paths(self.base_url, config_path)

    def load_config(self, config_path: str) -> ConfigResult:
        """
        Load a configuration file from MSC storage.

        :param config_path: Relative path to the config file, or full ``msc://`` URL.
        :return: The loaded configuration.
        :raises ConfigLoadError: If the config file cannot be loaded. The underlying
            exception is attached as ``__cause__``.
        """
        full_url = self._resolve_full_url(config_path)
        full_url = self._normalize_file_name(full_url)

        try:
            with msc.open(full_url, "r") as f:
                # Only the first 512 characters are inspected for the header;
                # the stream is then rewound so OmegaConf parses the whole file.
                header_text = f.read(512)
                header = ConfigSource._get_header_dict(header_text)
                f.seek(0)
                cfg = OmegaConf.load(f)
                return ConfigResult(
                    provider=self.provider,
                    path=f"{self.scheme()}://{self.path}",
                    config=cfg,
                    header=header,
                )
        except Exception as e:
            # Chain explicitly so the original storage/parsing error stays visible.
            raise ConfigLoadError(f"Failed to load config from {full_url}: {e}") from e

    def available(self) -> bool:
        """
        Check if the MSC config source is available.

        :return: ``True`` if the MSC config source can be accessed, ``False`` otherwise.
        """
        try:
            # Try to resolve the base URL to see if MSC can handle it
            resolve_storage_client(self.base_url)
            return True
        except Exception:
            logger.error("MSC config source not available", exc_info=True)
            return False

    def is_group(self, config_path: str) -> bool:
        """
        Check if the given path is a group (directory).

        :param config_path: Relative path or full ``msc://`` URL to check.
        :return: ``True`` if the path is a directory, ``False`` otherwise
            (including when the path does not exist or cannot be inspected).
        """
        full_url = self._resolve_full_url(config_path)

        # Ensure path ends with "/" for directory check
        full_url = full_url.rstrip("/") + "/"

        try:
            # Use msc.info() directly to check if path is a directory
            metadata = msc.info(full_url)
            return metadata.type == "directory"
        except FileNotFoundError:
            # A missing path is the expected "not a group" answer.
            return False
        except Exception:
            # Best-effort probe: any other failure also means "not a group",
            # but leave a trace so backend problems can be diagnosed.
            logger.debug(f"Failed to check whether '{full_url}' is a group", exc_info=True)
            return False

    def is_config(self, config_path: str) -> bool:
        """
        Check if the given path is a config file.

        :param config_path: Relative path or full ``msc://`` URL to check.
        :return: ``True`` if the path is a config file, ``False`` otherwise.
        """
        # If there's a directory with the same name, directory takes precedence
        if self.is_group(config_path):
            return False

        full_url = self._resolve_full_url(config_path)
        full_url = self._normalize_file_name(full_url)

        try:
            return msc.is_file(full_url)
        except Exception:
            # Best-effort probe: treat failures as "not a config" but keep a trace.
            logger.debug(f"Failed to check whether '{full_url}' is a config", exc_info=True)
            return False

    def list(self, config_path: str, results_filter: Optional[ObjectType]) -> List[str]:
        """
        List items under the specified config path.

        :param config_path: Relative path or full ``msc://`` URL to list.
        :param results_filter: Optional filter for the results.
        :return: Sorted, de-duplicated list of config names and group names under the
            specified path. Empty if the path cannot be listed.
        """
        full_url = self._resolve_full_url(config_path)
        files: List[str] = []

        try:
            # Use MSC to resolve client and list items directly
            # In this case, client.list() is simpler than msc.list() because
            # msc.list() returns keys with full msc:// URLs which would require
            # more complex path manipulation
            client, path = resolve_storage_client(full_url)

            # Add trailing slash to ensure we're listing directory contents
            list_path = path.rstrip("/") + "/" if path else ""

            # Get all items under this path
            for item in client.list(prefix=list_path, include_directories=True):
                # Get the relative path from the base path
                if item.key.startswith(list_path):
                    relative_path = item.key[len(list_path) :]
                elif path and item.key.startswith(path + "/"):
                    relative_path = item.key[len(path + "/") :]
                else:
                    continue

                # Skip empty paths or the directory itself
                if not relative_path or relative_path.endswith("/"):
                    continue

                # Get just the immediate file/directory name (no nested paths)
                file_name = relative_path.split("/")[0]
                if not file_name:
                    continue

                # Build the full config path for this item
                item_path = join_paths(config_path, file_name) if config_path else file_name

                self._list_add_result(
                    files=files,
                    file_path=item_path,
                    file_name=file_name,
                    results_filter=results_filter,
                )

        except Exception:
            logger.error(f"Failed to list MSC path '{full_url}'", exc_info=True)
            # Return empty list if we can't list the directory

        # Nested keys can yield the same immediate child more than once.
        return sorted(set(files))

    def __repr__(self) -> str:
        """Return a debug representation showing the provider and the full source URL."""
        return f"MSCConfigSource(provider={self.provider}, path={self.scheme()}://{self.path})"
228 229
class MSCSearchPathPlugin(SearchPathPlugin):
    """
    A Hydra :py:class:`hydra.plugins.search_path_plugin.SearchPathPlugin` that enables MSC support.

    Fixes MSC URL mangling issues and ensures MSC sources are available for config loading.
    """

    def manipulate_search_path(self, search_path: ConfigSearchPath) -> None:
        """
        Enable MSC support by fixing mangled URLs and adding universal MSC source.

        Performs two operations:

        1. **Fixes mangled MSC URLs**: CLI path normalization can mangle ``msc://dev/configs`` to ``/current/dir/msc:/dev/configs``.
        2. **Adds universal MSC source**: Ensures :py:class:`MSCConfigSource` is available to handle any ``msc://`` URLs in config defaults.

        :param search_path: The :py:class:`hydra.core.config_search_path.ConfigSearchPath` to manipulate.
        """
        path_elements = search_path.get_path()

        # Step 1: Fix any mangled MSC URLs from CLI
        for i, element in enumerate(path_elements):
            path = element.path

            # Detect mangled MSC URLs: contains "msc:/" but doesn't start with "msc://"
            if path and "msc:/" in path and not path.startswith(MSC_PROTOCOL):
                # Extract the MSC URL from the mangled path
                msc_start = path.find("msc:/")
                msc_part = path[msc_start:]  # Everything from "msc:/" onwards

                # Fix the missing slash: "msc:/profile" -> "msc://profile".
                # msc_part always begins with "msc:/" here, so only the
                # already-correct "msc://" case needs to be excluded.
                if msc_part.startswith(MSC_PROTOCOL):
                    fixed_url = msc_part
                else:
                    fixed_url = msc_part.replace("msc:/", MSC_PROTOCOL, 1)

                # Replace the element with a new one containing the fixed URL
                path_elements[i] = SearchPathElement(element.provider, fixed_url)

                logger.debug(f"Fixed mangled MSC URL: {path} -> {fixed_url}")

        # Step 2: Ensure there's a universal MSC source for handling any msc:// URLs
        # Check if there's already an msc:// entry in the search path
        has_msc_source = any(element.path and element.path.startswith(MSC_PROTOCOL) for element in path_elements)

        if not has_msc_source:
            # Add a universal MSC source that can handle any msc:// URL
            # Use a generic base that MSCConfigSource can resolve dynamically
            search_path.append("msc-universal", MSC_PROTOCOL)

            logger.debug("Added universal MSC source to search path")