1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2# SPDX-License-Identifier: Apache-2.0
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""
17Hydra ConfigSource plugin for Multi-Storage Client.
18
19This module provides a ConfigSource implementation that allows Hydra to load
20configuration files from remote storage systems using Multi-Storage Client.
21"""
22
23import logging
24from typing import List, Optional
25
26from hydra.core.config_search_path import ConfigSearchPath, SearchPathElement
27from hydra.core.object_type import ObjectType
28from hydra.plugins.config_source import ConfigLoadError, ConfigResult, ConfigSource
29from hydra.plugins.search_path_plugin import SearchPathPlugin
30from omegaconf import OmegaConf
31
32import multistorageclient as msc
33from multistorageclient.shortcuts import resolve_storage_client
34from multistorageclient.types import MSC_PROTOCOL, MSC_PROTOCOL_NAME
35from multistorageclient.utils import join_paths
36
37logger = logging.getLogger(__name__)
38
39
[docs]
class MSCConfigSource(ConfigSource):
    """
    A Hydra :py:class:`hydra.plugins.config_source.ConfigSource` that uses Multi-Storage Client to read configuration files from remote storage systems.

    Supports loading configs from S3, GCS, Azure Blob Storage, and other MSC-supported storage backends.
    Must be used in conjunction with :py:class:`MSCSearchPathPlugin`.
    """

    def __init__(self, provider: str, path: str) -> None:
        """
        Initialize the MSC ConfigSource.

        :param provider: The provider name (should be ``main`` if the user specifies config path, or ``msc-universal`` if the user doesn't specify a config path).
        :param path: The base path for this source. It should be a full MSC URL (e.g., ``msc://dev/configs``) if the user specifies config path, or ``msc://`` if the user doesn't specify a config path and the universal MSC source is used.
        """
        # A path without any scheme separator is treated as an MSC path.
        if "://" not in path:
            path = f"{MSC_PROTOCOL}{path}"
        super().__init__(provider=provider, path=path)

        # Keep the complete URL (scheme included) for resolving relative config paths.
        self.base_url = path
        logger.debug("Initialized MSCConfigSource with base URL: %s", self.base_url)

    @staticmethod
    def scheme() -> str:
        """
        Return the URL scheme for this ConfigSource.

        :return: The scheme string 'msc'.
        """
        return MSC_PROTOCOL_NAME

    def _resolve_full_url(self, config_path: str) -> str:
        """
        Convert config_path to a full ``msc://`` URL.

        :param config_path: Either a relative path, config group reference (e.g., ``database: postgres``), or full ``msc://`` URL.
        :return: Full ``msc://`` URL that can be passed to ``multistorageclient.open()``.
        """
        # Already a full MSC URL: nothing to resolve.
        if config_path.startswith(MSC_PROTOCOL):
            return config_path

        # Hydra defaults syntax "group: config" maps to "group/config";
        # only the first separator is rewritten.
        config_path = config_path.replace(": ", "/", 1)

        # Relative path: anchor it at this source's base URL.
        return join_paths(self.base_url, config_path)

    def load_config(self, config_path: str) -> ConfigResult:
        """
        Load a configuration file from MSC storage.

        :param config_path: Relative path to the config file, or full ``msc://`` URL.
        :return: The loaded configuration.
        :raises ConfigLoadError: If the config file cannot be loaded. The underlying exception is attached as ``__cause__``.
        """
        full_url = self._normalize_file_name(self._resolve_full_url(config_path))

        try:
            with msc.open(full_url, "r") as f:
                # The header is parsed from the first 512 characters only,
                # then the stream is rewound so the whole file is loaded.
                header = ConfigSource._get_header_dict(f.read(512))
                f.seek(0)
                cfg = OmegaConf.load(f)
                return ConfigResult(
                    provider=self.provider,
                    path=f"{self.scheme()}://{self.path}",
                    config=cfg,
                    header=header,
                )
        except Exception as e:
            # Surface every failure as ConfigLoadError, but keep the original
            # exception chained so the root cause stays visible in tracebacks.
            raise ConfigLoadError(f"Failed to load config from {full_url}: {e}") from e

    def available(self) -> bool:
        """
        Check if the MSC config source is available.

        :return: ``True`` if the MSC config source can be accessed, ``False`` otherwise.
        """
        try:
            # Resolving the base URL is the availability probe; the result is not needed.
            resolve_storage_client(self.base_url)
        except Exception:
            logger.exception("MSC config source not available")
            return False
        return True

    def is_group(self, config_path: str) -> bool:
        """
        Check if the given path is a group (directory).

        :param config_path: Relative path or full ``msc://`` URL to check.
        :return: ``True`` if the path is a directory, ``False`` otherwise.
        """
        # Query the path with exactly one trailing "/" for the directory check.
        full_url = self._resolve_full_url(config_path).rstrip("/") + "/"

        try:
            return msc.info(full_url).type == "directory"
        except Exception:
            # A missing path (FileNotFoundError) and any other lookup failure
            # are both reported as "not a group".
            return False

    def is_config(self, config_path: str) -> bool:
        """
        Check if the given path is a config file.

        :param config_path: Relative path or full ``msc://`` URL to check.
        :return: ``True`` if the path is a config file, ``False`` otherwise.
        """
        # If there's a directory with the same name, directory takes precedence.
        if self.is_group(config_path):
            return False

        full_url = self._normalize_file_name(self._resolve_full_url(config_path))

        try:
            return msc.is_file(full_url)
        except Exception:
            # Any failure while probing is reported as "not a config".
            return False

    def list(self, config_path: str, results_filter: Optional[ObjectType]) -> List[str]:
        """
        List items under the specified config path.

        :param config_path: Relative path or full ``msc://`` URL to list.
        :param results_filter: Optional filter for the results.
        :return: Sorted, de-duplicated list of config names and group names under the specified path. Empty if the listing fails.
        """
        full_url = self._resolve_full_url(config_path)
        files: List[str] = []

        try:
            # The storage client is used directly so item keys can be sliced
            # against the resolved path (the original code notes that
            # msc.list() returns full msc:// URLs instead).
            client, path = resolve_storage_client(full_url)

            # Exactly one trailing slash so only entries under the directory match;
            # an empty path lists from the root with an empty prefix.
            list_path = path.rstrip("/") + "/" if path else ""

            for item in client.list(prefix=list_path, include_directories=True):
                if not item.key.startswith(list_path):
                    continue
                relative_path = item.key[len(list_path) :]

                # Skip the listed directory itself and keys that end with "/".
                if not relative_path or relative_path.endswith("/"):
                    continue

                # Keep only the immediate child name (first path segment).
                file_name = relative_path.split("/")[0]
                if not file_name:
                    continue

                item_path = join_paths(config_path, file_name) if config_path else file_name

                self._list_add_result(
                    files=files,
                    file_path=item_path,
                    file_name=file_name,
                    results_filter=results_filter,
                )

        except Exception:
            # Best effort: log the failure and return whatever was collected.
            logger.exception("Failed to list MSC path '%s'", full_url)

        return sorted(set(files))

    def __repr__(self) -> str:
        return f"MSCConfigSource(provider={self.provider}, path={self.scheme()}://{self.path})"
228
229
[docs]
class MSCSearchPathPlugin(SearchPathPlugin):
    """
    A Hydra :py:class:`hydra.plugins.search_path_plugin.SearchPathPlugin` that enables MSC support.

    Repairs MSC URLs that were mangled before reaching the search path and makes sure
    at least one MSC entry is present so MSC sources can be used for config loading.
    """

    def manipulate_search_path(self, search_path: ConfigSearchPath) -> None:
        """
        Enable MSC support by repairing mangled URLs and adding a universal MSC source.

        Two passes are made over the search path:

        1. **Repair mangled MSC URLs**: CLI path normalization can mangle ``msc://dev/configs`` to ``/current/dir/msc:/dev/configs``. Such entries are replaced by the embedded MSC URL with its double slash restored.
        2. **Add a universal MSC source**: If no entry starts with the MSC protocol after the repair pass, an ``msc-universal`` entry is appended so :py:class:`MSCConfigSource` can handle any ``msc://`` URLs in config defaults.

        :param search_path: The :py:class:`hydra.core.config_search_path.ConfigSearchPath` to manipulate.
        """
        mangled_marker = "msc:/"
        elements = search_path.get_path()

        # Pass 1: rewrite entries that embed an MSC URL but do not start with one.
        for index, entry in enumerate(elements):
            raw_path = entry.path
            if not raw_path or raw_path.startswith(MSC_PROTOCOL):
                continue

            marker_pos = raw_path.find(mangled_marker)
            if marker_pos == -1:
                continue

            # Everything from the marker onwards is the (possibly damaged) MSC URL.
            candidate = raw_path[marker_pos:]
            if candidate.startswith(MSC_PROTOCOL):
                # Only a foreign prefix was added; the URL itself is intact.
                repaired = candidate
            else:
                # Restore the missing slash: "msc:/profile" -> "msc://profile".
                repaired = candidate.replace(mangled_marker, MSC_PROTOCOL, 1)

            elements[index] = SearchPathElement(entry.provider, repaired)
            logger.debug(f"Fixed mangled MSC URL: {raw_path} → {repaired}")

        # Pass 2: nothing to add when some entry already points at MSC.
        if any(entry.path and entry.path.startswith(MSC_PROTOCOL) for entry in elements):
            return

        # Generic base that MSCConfigSource resolves dynamically.
        search_path.append("msc-universal", MSC_PROTOCOL)
        logger.debug("Added universal MSC source to search path")
279 logger.debug("Added universal MSC source to search path")