# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
16import json
17import logging
18import os
19import tempfile
20from typing import Any, Dict, Optional
21
22import yaml
23
24from .cache import DEFAULT_CACHE_SIZE_MB, CacheConfig, CacheManager
25from .instrumentation import setup_opentelemetry
26from .providers import ManifestMetadataProvider
27from .schema import validate_config
28from .types import (
29 DEFAULT_POSIX_PROFILE,
30 DEFAULT_POSIX_PROFILE_NAME,
31 DEFAULT_RETRY_ATTEMPTS,
32 DEFAULT_RETRY_DELAY,
33 CredentialsProvider,
34 MetadataProvider,
35 ProviderBundle,
36 RetryConfig,
37 StorageProvider,
38 StorageProviderConfig,
39)
40from .utils import expand_env_vars, import_class, merge_dictionaries_no_overwrite
41from .rclone import read_rclone_config, DEFAULT_RCLONE_CONFIG_FILE_SEARCH_PATHS
42
# Maps the ``type`` of a ``storage_provider`` config entry to the class name looked up in ``.providers``.
STORAGE_PROVIDER_MAPPING = {
    "file": "PosixFileStorageProvider",
    "s3": "S3StorageProvider",
    "gcs": "GoogleStorageProvider",
    "oci": "OracleStorageProvider",
    "azure": "AzureBlobStorageProvider",
    "ais": "AIStoreStorageProvider",
    # Map swiftstack to S3StorageProvider for now
    "s8k": "S3StorageProvider",
}

# Maps short ``credentials_provider`` type names to the class name looked up in ``.providers``.
# Types not listed here are treated as fully qualified ``module.ClassName`` paths.
CREDENTIALS_PROVIDER_MAPPING = {
    "S3Credentials": "StaticS3CredentialsProvider",
    "AzureCredentials": "StaticAzureCredentialsProvider",
    "AISCredentials": "StaticAISCredentialProvider",
}

# Candidate MSC config file locations, probed in order when ``MSC_CONFIG`` is not set.
# NOTE(review): if HOME is unset these collapse to paths relative to the working directory.
DEFAULT_MSC_CONFIG_FILE_SEARCH_PATHS = (
    # Yaml
    "/etc/msc_config.yaml",
    os.path.join(os.getenv("HOME", ""), ".config", "msc", "config.yaml"),
    os.path.join(os.getenv("HOME", ""), ".msc_config.yaml"),
    # Json
    "/etc/msc_config.json",
    os.path.join(os.getenv("HOME", ""), ".config", "msc", "config.json"),
    os.path.join(os.getenv("HOME", ""), ".msc_config.json"),
)

# Package used as the anchor for relative imports such as ``.providers``.
PACKAGE_NAME = "multistorageclient"

# Use getLogger (not the Logger constructor) so this logger joins the logging hierarchy and
# honors handlers/levels configured by the application.
logger = logging.getLogger(__name__)
74
75
class SimpleProviderBundle(ProviderBundle):
    """
    Minimal :py:class:`ProviderBundle` that hands back exactly the objects it was built with.

    :param storage_provider_config: Configuration describing which storage provider to build.
    :param credentials_provider: Optional credentials provider exposed unchanged.
    :param metadata_provider: Optional metadata provider exposed unchanged.
    """

    def __init__(
        self,
        storage_provider_config: StorageProviderConfig,
        credentials_provider: Optional[CredentialsProvider] = None,
        metadata_provider: Optional[MetadataProvider] = None,
    ):
        (
            self._storage_provider_config,
            self._credentials_provider,
            self._metadata_provider,
        ) = (storage_provider_config, credentials_provider, metadata_provider)

    @property
    def storage_provider_config(self) -> StorageProviderConfig:
        """Storage provider configuration given to the constructor."""
        return self._storage_provider_config

    @property
    def credentials_provider(self) -> Optional[CredentialsProvider]:
        """Credentials provider given to the constructor, or ``None``."""
        return self._credentials_provider

    @property
    def metadata_provider(self) -> Optional[MetadataProvider]:
        """Metadata provider given to the constructor, or ``None``."""
        return self._metadata_provider
98
99
class StorageClientConfigLoader:
    """
    Turns a configuration dictionary (or a pre-built provider bundle) into a
    :py:class:`StorageClientConfig` for a single profile.
    """

    def __init__(
        self,
        config_dict: Dict[str, Any],
        profile: str = DEFAULT_POSIX_PROFILE_NAME,
        provider_bundle: Optional[ProviderBundle] = None,
    ) -> None:
        """
        Initializes a :py:class:`StorageClientConfigLoader` to create a
        StorageClientConfig. Components are built using the ``config_dict`` and
        profile, but a pre-built provider_bundle takes precedence.

        :param config_dict: Dictionary of configuration options.
        :param profile: Name of profile in ``config_dict`` to use to build configuration.
        :param provider_bundle: Optional pre-built :py:class:`multistorageclient.types.ProviderBundle`, takes precedence over ``config_dict``.
        :raises ValueError: If the default POSIX profile is redefined with a storage provider type
            other than ``file``, or if ``profile`` is not defined.
        """
        # ProviderBundle takes precedence
        self._provider_bundle = provider_bundle

        # Interpolates all environment variables into actual values.
        config_dict = expand_env_vars(config_dict)

        # Shallow copy so injecting the default POSIX profile below never writes into the dict
        # returned by expand_env_vars (which may share structure with the caller's dict).
        self._profiles = dict(config_dict.get("profiles", {}))

        if DEFAULT_POSIX_PROFILE_NAME not in self._profiles:
            # Assign the default POSIX profile
            self._profiles[DEFAULT_POSIX_PROFILE_NAME] = DEFAULT_POSIX_PROFILE["profiles"][DEFAULT_POSIX_PROFILE_NAME]
        else:
            # Cannot override default POSIX profile
            storage_provider_type = (
                self._profiles[DEFAULT_POSIX_PROFILE_NAME].get("storage_provider", {}).get("type", None)
            )
            if storage_provider_type != "file":
                raise ValueError(
                    f'Cannot override "{DEFAULT_POSIX_PROFILE_NAME}" profile with storage provider type '
                    f'"{storage_provider_type}"; expected "file".'
                )

        profile_dict = self._profiles.get(profile)

        if not profile_dict:
            raise ValueError(f"Profile {profile} not found; available profiles: {list(self._profiles.keys())}")

        self._profile = profile
        self._profile_dict = profile_dict
        self._opentelemetry_dict = config_dict.get("opentelemetry", None)
        self._cache_dict = config_dict.get("cache", None)

    @staticmethod
    def _import_fully_qualified_class(class_type: str) -> Any:
        """
        Imports a class given as a fully qualified ``module.ClassName`` path.

        :raises ValueError: If ``class_type`` contains no ``.`` separator.
        """
        if "." not in class_type:
            raise ValueError(
                f"Expected a fully qualified class name (e.g., 'module.ClassName'); got '{class_type}'."
            )
        module_name, class_name = class_type.rsplit(".", 1)
        return import_class(class_name, module_name)

    def _build_storage_provider(
        self,
        storage_provider_name: str,
        storage_options: Optional[Dict[str, Any]],
        credentials_provider: Optional[CredentialsProvider] = None,
    ) -> StorageProvider:
        """
        Instantiates a built-in storage provider from its short type name.

        :param storage_provider_name: Key into ``STORAGE_PROVIDER_MAPPING`` (e.g. ``"s3"``).
        :param storage_options: Keyword arguments for the provider class; may be ``None``.
        :param credentials_provider: Optional credentials provider passed as ``credentials_provider=``.
        :raises ValueError: If ``storage_provider_name`` is not a supported type.
        """
        if storage_provider_name not in STORAGE_PROVIDER_MAPPING:
            raise ValueError(
                f"Storage provider {storage_provider_name} is not supported. "
                f"Supported providers are: {list(STORAGE_PROVIDER_MAPPING.keys())}"
            )
        # Copy so the caller's options (config dict / StorageProviderConfig) are never mutated.
        options = dict(storage_options) if storage_options else {}
        if credentials_provider is not None:
            options["credentials_provider"] = credentials_provider
        class_name = STORAGE_PROVIDER_MAPPING[storage_provider_name]
        cls = import_class(class_name, ".providers", PACKAGE_NAME)
        return cls(**options)

    def _build_credentials_provider(
        self, credentials_provider_dict: Optional[Dict[str, Any]]
    ) -> Optional[CredentialsProvider]:
        """
        Initializes the CredentialsProvider based on the provided dictionary.

        ``type`` is either a key of ``CREDENTIALS_PROVIDER_MAPPING`` or a fully qualified
        ``module.ClassName`` path; ``options`` are passed to the class as keyword arguments.
        Returns ``None`` when no credentials provider is configured.

        :raises ValueError: If ``type`` is neither a mapped name nor a fully qualified class path.
        """
        if not credentials_provider_dict:
            return None

        provider_type = credentials_provider_dict["type"]
        if provider_type in CREDENTIALS_PROVIDER_MAPPING:
            # Mapped class name case
            cls = import_class(CREDENTIALS_PROVIDER_MAPPING[provider_type], ".providers", PACKAGE_NAME)
        else:
            # Fully qualified class path case
            cls = self._import_fully_qualified_class(provider_type)

        options = credentials_provider_dict.get("options") or {}
        return cls(**options)

    def _build_manifest_storage_provider(self, storage_provider_profile: str) -> StorageProvider:
        """
        Builds the storage provider for the profile named by a manifest's ``storage_provider_profile``.

        :raises ValueError: If the profile does not exist, defines its own ``metadata_provider``,
            or lacks a ``storage_provider`` section.
        """
        storage_profile_dict = self._profiles.get(storage_provider_profile)
        if not storage_profile_dict:
            raise ValueError(
                f"Profile '{storage_provider_profile}' referenced by "
                f"storage_provider_profile does not exist."
            )

        # NOTE: The storage profile for manifests does not support metadata provider (at the moment).
        if storage_profile_dict.get("metadata_provider", None):
            raise ValueError(
                f"Found metadata_provider for profile '{storage_provider_profile}'. "
                f"This is not supported for storage profiles used by manifests."
            )

        local_creds_provider = self._build_credentials_provider(
            credentials_provider_dict=storage_profile_dict.get("credentials_provider", None)
        )

        local_storage_provider_dict = storage_profile_dict.get("storage_provider", None)
        if not local_storage_provider_dict:
            raise ValueError("Missing storage_provider in the config.")

        return self._build_storage_provider(
            local_storage_provider_dict["type"],
            local_storage_provider_dict.get("options", {}),
            local_creds_provider,
        )

    def _build_manifest_metadata_provider(
        self,
        metadata_provider_dict: Dict[str, Any],
        storage_provider_name: str,
        storage_options: Optional[Dict[str, Any]],
        credentials_provider: Optional[CredentialsProvider],
    ) -> MetadataProvider:
        """
        Builds a :py:class:`ManifestMetadataProvider`, reading manifests either from the current
        profile's storage or from the profile named by ``storage_provider_profile``.
        """
        # Copy before popping so the configuration dictionary itself is left unmodified.
        metadata_options = dict(metadata_provider_dict.get("options") or {})
        storage_provider_profile = metadata_options.pop("storage_provider_profile", None)
        if storage_provider_profile:
            storage_provider = self._build_manifest_storage_provider(storage_provider_profile)
        else:
            storage_provider = self._build_storage_provider(
                storage_provider_name, storage_options, credentials_provider
            )
        return ManifestMetadataProvider(storage_provider, **metadata_options)

    def _build_provider_bundle_from_config(self, profile_dict: Dict[str, Any]) -> ProviderBundle:
        """
        Builds a :py:class:`SimpleProviderBundle` from a profile's ``credentials_provider``,
        ``storage_provider`` and optional ``metadata_provider`` sections.

        :raises ValueError: If the profile has no ``storage_provider`` section, or a metadata
            provider ``type`` is neither ``manifest`` nor a fully qualified class path.
        """
        # Initialize CredentialsProvider
        credentials_provider = self._build_credentials_provider(
            credentials_provider_dict=profile_dict.get("credentials_provider", None)
        )

        # Storage provider configuration (the provider itself is instantiated in build_config)
        storage_provider_dict = profile_dict.get("storage_provider", None)
        if not storage_provider_dict:
            raise ValueError("Missing storage_provider in the config.")
        storage_provider_name = storage_provider_dict["type"]
        storage_options = storage_provider_dict.get("options", {})

        # Initialize MetadataProvider
        metadata_provider = None
        metadata_provider_dict = profile_dict.get("metadata_provider", None)
        if metadata_provider_dict:
            if metadata_provider_dict["type"] == "manifest":
                metadata_provider = self._build_manifest_metadata_provider(
                    metadata_provider_dict, storage_provider_name, storage_options, credentials_provider
                )
            else:
                cls = self._import_fully_qualified_class(metadata_provider_dict["type"])
                metadata_provider = cls(**(metadata_provider_dict.get("options") or {}))

        return SimpleProviderBundle(
            storage_provider_config=StorageProviderConfig(storage_provider_name, storage_options),
            credentials_provider=credentials_provider,
            metadata_provider=metadata_provider,
        )

    def _build_provider_bundle_from_extension(self, provider_bundle_dict: Dict[str, Any]) -> ProviderBundle:
        """
        Instantiates a third-party provider bundle from its fully qualified ``type`` and ``options``.

        :raises ValueError: If ``type`` is not a fully qualified ``module.ClassName`` path.
        """
        cls = self._import_fully_qualified_class(provider_bundle_dict["type"])
        return cls(**(provider_bundle_dict.get("options") or {}))

    def _build_provider_bundle(self) -> ProviderBundle:
        """
        Returns the provider bundle to use: the one passed to the constructor if any, otherwise a
        ``provider_bundle`` extension from the profile, otherwise one built from the profile itself.
        """
        if self._provider_bundle:
            return self._provider_bundle  # Return if previously provided.

        # Load 3rd party extension
        provider_bundle_dict = self._profile_dict.get("provider_bundle", None)
        if provider_bundle_dict:
            return self._build_provider_bundle_from_extension(provider_bundle_dict)

        return self._build_provider_bundle_from_config(self._profile_dict)

    def _build_cache_config(self) -> Optional[CacheConfig]:
        """
        Builds the cache configuration from the top-level ``cache`` section, creating the cache
        directory if needed. Returns ``None`` when no ``cache`` section is present.
        """
        if self._cache_dict is None:
            return None
        default_location = os.path.join(tempfile.gettempdir(), ".msc_cache")
        cache_location = self._cache_dict.get("location", default_location)
        size_mb = self._cache_dict.get("size_mb", DEFAULT_CACHE_SIZE_MB)
        use_etag = self._cache_dict.get("use_etag", False)
        os.makedirs(cache_location, exist_ok=True)
        return CacheConfig(location=cache_location, size_mb=size_mb, use_etag=use_etag)

    def _build_retry_config(self) -> RetryConfig:
        """Builds the retry configuration from the profile's ``retry`` section, falling back to defaults."""
        retry_config_dict = self._profile_dict.get("retry", None) or {}
        return RetryConfig(
            attempts=retry_config_dict.get("attempts", DEFAULT_RETRY_ATTEMPTS),
            delay=retry_config_dict.get("delay", DEFAULT_RETRY_DELAY),
        )

    def build_config(self) -> "StorageClientConfig":
        """
        Builds the :py:class:`StorageClientConfig` for the selected profile: instantiates the
        storage provider, cache and retry configuration, and sets up OpenTelemetry when configured.
        """
        bundle = self._build_provider_bundle()
        storage_provider = self._build_storage_provider(
            bundle.storage_provider_config.type, bundle.storage_provider_config.options, bundle.credentials_provider
        )

        cache_config = self._build_cache_config()
        retry_config = self._build_retry_config()

        # set up OpenTelemetry providers once per process
        if self._opentelemetry_dict:
            setup_opentelemetry(self._opentelemetry_dict)

        return StorageClientConfig(
            profile=self._profile,
            storage_provider=storage_provider,
            credentials_provider=bundle.credentials_provider,
            metadata_provider=bundle.metadata_provider,
            cache_config=cache_config,
            retry_config=retry_config,
        )
325
326
class StorageClientConfig:
    """
    Configuration class for the :py:class:`multistorageclient.StorageClient`.
    """

    profile: str
    storage_provider: StorageProvider
    credentials_provider: Optional[CredentialsProvider]
    metadata_provider: Optional[MetadataProvider]
    cache_config: Optional[CacheConfig]
    cache_manager: Optional[CacheManager]
    retry_config: Optional[RetryConfig]

    # Raw configuration dictionary; set by from_dict and used by __setstate__ to rebuild providers.
    _config_dict: Dict[str, Any]

    def __init__(
        self,
        profile: str,
        storage_provider: StorageProvider,
        credentials_provider: Optional[CredentialsProvider] = None,
        metadata_provider: Optional[MetadataProvider] = None,
        cache_config: Optional[CacheConfig] = None,
        retry_config: Optional[RetryConfig] = None,
    ):
        self.profile = profile
        self.storage_provider = storage_provider
        self.credentials_provider = credentials_provider
        self.metadata_provider = metadata_provider
        self.cache_config = cache_config
        self.retry_config = retry_config
        # A cache manager exists only when caching is configured.
        self.cache_manager = CacheManager(profile, cache_config) if cache_config else None

    @staticmethod
    def from_json(config_json: str, profile: str = DEFAULT_POSIX_PROFILE_NAME) -> "StorageClientConfig":
        """Parses ``config_json`` as JSON and builds the configuration for ``profile``."""
        config_dict = json.loads(config_json)
        return StorageClientConfig.from_dict(config_dict, profile)

    @staticmethod
    def from_yaml(config_yaml: str, profile: str = DEFAULT_POSIX_PROFILE_NAME) -> "StorageClientConfig":
        """Parses ``config_yaml`` with ``yaml.safe_load`` and builds the configuration for ``profile``."""
        config_dict = yaml.safe_load(config_yaml)
        return StorageClientConfig.from_dict(config_dict, profile)

    @staticmethod
    def from_dict(config_dict: Dict[str, Any], profile: str = DEFAULT_POSIX_PROFILE_NAME) -> "StorageClientConfig":
        """
        Validates ``config_dict`` against the config schema and builds the configuration for ``profile``.

        The original dictionary is stored on the result so the config can be rebuilt after unpickling.
        """
        # Validate the config file with predefined JSON schema
        validate_config(config_dict)

        # Load config
        loader = StorageClientConfigLoader(config_dict, profile)
        config = loader.build_config()
        config._config_dict = config_dict

        return config

    @staticmethod
    def from_file(profile: str = DEFAULT_POSIX_PROFILE_NAME) -> "StorageClientConfig":
        """
        Builds a configuration from the MSC config file and/or the rclone config file.

        The MSC config file is taken from the ``MSC_CONFIG`` environment variable when set, otherwise
        the first regular file found in ``DEFAULT_MSC_CONFIG_FILE_SEARCH_PATHS``. Files ending in
        ``.json`` are parsed as JSON, anything else as YAML. If neither an MSC nor an rclone config
        file is found, a warning is logged and the default POSIX profile is used.

        :param profile: Name of the profile to build.
        :raises ValueError: If the MSC and rclone config files define conflicting keys.
        """
        msc_config_file = os.getenv("MSC_CONFIG", None)

        # Search config paths; isfile skips directories that happen to share a candidate name.
        if msc_config_file is None:
            msc_config_file = next(
                (path for path in DEFAULT_MSC_CONFIG_FILE_SEARCH_PATHS if os.path.isfile(path)),
                None,
            )

        msc_config_dict: Dict[str, Any] = {}

        # Parse MSC config file.
        if msc_config_file:
            with open(msc_config_file, "r", encoding="utf-8") as f:
                content = f.read()
            if msc_config_file.endswith(".json"):
                msc_config_dict = json.loads(content)
            else:
                msc_config_dict = yaml.safe_load(content)
            # An empty YAML document parses to None; treat it as an empty configuration.
            msc_config_dict = msc_config_dict or {}

        # Parse rclone config file.
        rclone_config_dict, rclone_config_file = read_rclone_config()

        # If no config file is found, use a default profile.
        if not msc_config_file and not rclone_config_file:
            search_paths = DEFAULT_MSC_CONFIG_FILE_SEARCH_PATHS + DEFAULT_RCLONE_CONFIG_FILE_SEARCH_PATHS
            logger.warning(
                "Cannot find the MSC config or rclone config file in any of the locations: %s",
                search_paths,
            )

            return StorageClientConfig.from_dict(DEFAULT_POSIX_PROFILE, profile=profile)

        # Merge config files.
        merged_config, conflicted_keys = merge_dictionaries_no_overwrite(msc_config_dict, rclone_config_dict)
        if conflicted_keys:
            raise ValueError(
                f'Conflicting keys found in configuration files "{msc_config_file}" and "{rclone_config_file}": {conflicted_keys}'
            )

        return StorageClientConfig.from_dict(merged_config, profile)

    @staticmethod
    def from_provider_bundle(config_dict: Dict[str, Any], provider_bundle: ProviderBundle) -> "StorageClientConfig":
        """
        Builds a configuration whose providers come from ``provider_bundle`` instead of a profile.

        NOTE: ``_config_dict`` is not recorded here, so such configs cannot be restored by
        ``__setstate__`` (unpickling raises ``KeyError``).
        """
        loader = StorageClientConfigLoader(config_dict, provider_bundle=provider_bundle)
        return loader.build_config()

    def __getstate__(self) -> Dict[str, Any]:
        """Returns the picklable state, dropping live provider and cache-manager objects."""
        state = self.__dict__.copy()
        del state["credentials_provider"]
        del state["storage_provider"]
        del state["metadata_provider"]
        del state["cache_manager"]
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        """Restores plain attributes, then rebuilds providers from the stored config dictionary."""
        # pickle does not update __dict__ itself when __setstate__ is defined, so restore
        # profile/_config_dict (and the rest) explicitly; otherwise the object lacks them.
        self.__dict__.update(state)
        loader = StorageClientConfigLoader(state["_config_dict"], state["profile"])
        new_config = loader.build_config()
        self.storage_provider = new_config.storage_provider
        self.credentials_provider = new_config.credentials_provider
        self.metadata_provider = new_config.metadata_provider
        self.cache_config = new_config.cache_config
        self.retry_config = new_config.retry_config
        self.cache_manager = new_config.cache_manager