1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2# SPDX-License-Identifier: Apache-2.0
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16import json
17import logging
18import os
19import tempfile
20from typing import Any, Dict, Optional
21
22import yaml
23
24from .cache import DEFAULT_CACHE_SIZE_MB, CacheConfig, CacheManager
25from .instrumentation import setup_opentelemetry
26from .providers.manifest_metadata import ManifestMetadataProvider
27from .schema import validate_config
28from .types import (
29 DEFAULT_POSIX_PROFILE,
30 DEFAULT_POSIX_PROFILE_NAME,
31 DEFAULT_RETRY_ATTEMPTS,
32 DEFAULT_RETRY_DELAY,
33 CredentialsProvider,
34 MetadataProvider,
35 ProviderBundle,
36 RetryConfig,
37 StorageProvider,
38 StorageProviderConfig,
39)
40from .utils import expand_env_vars, import_class, merge_dictionaries_no_overwrite
41from .rclone import read_rclone_config
42
# Maps the ``type`` of a ``storage_provider`` configuration entry to the name
# of the class implementing it. The class is imported from the ``.providers``
# sub-module of ``PACKAGE_NAME``.
STORAGE_PROVIDER_MAPPING = {
    "file": "PosixFileStorageProvider",
    "s3": "S3StorageProvider",
    "gcs": "GoogleStorageProvider",
    "oci": "OracleStorageProvider",
    "azure": "AzureBlobStorageProvider",
    "ais": "AIStoreStorageProvider",
    "s8k": "S8KStorageProvider",
}

# Maps the short ``type`` of a ``credentials_provider`` configuration entry to
# the name of a built-in class in ``.providers``. Any type not listed here is
# treated as a fully qualified class path.
CREDENTIALS_PROVIDER_MAPPING = {
    "S3Credentials": "StaticS3CredentialsProvider",
    "AzureCredentials": "StaticAzureCredentialsProvider",
    "AISCredentials": "StaticAISCredentialProvider",
}

# Locations probed, in order, for a configuration file when the ``MSC_CONFIG``
# environment variable is not set. The first existing path wins.
DEFAULT_MSC_CONFIG_FILE_SEARCH_PATHS = (
    # Yaml
    "/etc/msc_config.yaml",
    os.path.join(os.getenv("HOME", ""), ".config", "msc", "config.yaml"),
    os.path.join(os.getenv("HOME", ""), ".msc_config.yaml"),
    # Json
    "/etc/msc_config.json",
    os.path.join(os.getenv("HOME", ""), ".config", "msc", "config.json"),
    os.path.join(os.getenv("HOME", ""), ".msc_config.json"),
)

# Package against which the relative ``.providers`` module is resolved.
PACKAGE_NAME = "multistorageclient"

# Obtain the logger from the logging registry rather than instantiating
# ``logging.Logger`` directly: a directly constructed logger has no parent and
# therefore ignores handlers and levels configured on the root/package loggers.
logger = logging.getLogger(__name__)
73
74
class SimpleProviderBundle(ProviderBundle):
    """
    A :py:class:`ProviderBundle` that simply hands back the components it was
    constructed with.
    """

    def __init__(
        self,
        storage_provider_config: StorageProviderConfig,
        credentials_provider: Optional[CredentialsProvider] = None,
        metadata_provider: Optional[MetadataProvider] = None,
    ):
        """
        :param storage_provider_config: Storage provider type and options to expose.
        :param credentials_provider: Optional credentials provider to expose.
        :param metadata_provider: Optional metadata provider to expose.
        """
        self._metadata_provider = metadata_provider
        self._credentials_provider = credentials_provider
        self._storage_provider_config = storage_provider_config

    @property
    def storage_provider_config(self) -> StorageProviderConfig:
        """The storage provider configuration given at construction."""
        return self._storage_provider_config

    @property
    def credentials_provider(self) -> Optional[CredentialsProvider]:
        """The credentials provider given at construction, or ``None``."""
        return self._credentials_provider

    @property
    def metadata_provider(self) -> Optional[MetadataProvider]:
        """The metadata provider given at construction, or ``None``."""
        return self._metadata_provider
97
98
class StorageClientConfigLoader:
    """
    Builds a :py:class:`StorageClientConfig` from a configuration dictionary
    and the name of one of its profiles.
    """

    def __init__(
        self,
        config_dict: Dict[str, Any],
        profile: str = DEFAULT_POSIX_PROFILE_NAME,
        provider_bundle: Optional[ProviderBundle] = None,
    ) -> None:
        """
        Initializes a :py:class:`StorageClientConfigLoader` to create a
        StorageClientConfig. Components are built using the ``config_dict`` and
        profile, but a pre-built provider_bundle takes precedence.

        :param config_dict: Dictionary of configuration options.
        :param profile: Name of profile in ``config_dict`` to use to build configuration.
        :param provider_bundle: Optional pre-built :py:class:`multistorageclient.types.ProviderBundle`, takes precedence over ``config_dict``.
        :raises ValueError: If the default POSIX profile is overridden with a
            non-``file`` storage provider, or if ``profile`` is not found.
        """
        # ProviderBundle takes precedence
        self._provider_bundle = provider_bundle

        # Interpolates all environment variables into actual values.
        config_dict = expand_env_vars(config_dict)

        # Work on a shallow copy so that registering the default profile below
        # never modifies the mapping supplied by the caller.
        self._profiles = dict(config_dict.get("profiles") or {})

        if DEFAULT_POSIX_PROFILE_NAME not in self._profiles:
            # Assign the default POSIX profile
            self._profiles[DEFAULT_POSIX_PROFILE_NAME] = DEFAULT_POSIX_PROFILE["profiles"][DEFAULT_POSIX_PROFILE_NAME]
        else:
            # Cannot override default POSIX profile
            storage_provider_type = (
                self._profiles[DEFAULT_POSIX_PROFILE_NAME].get("storage_provider", {}).get("type", None)
            )
            if storage_provider_type != "file":
                raise ValueError(
                    f'Cannot override "{DEFAULT_POSIX_PROFILE_NAME}" profile with storage provider type '
                    f'"{storage_provider_type}"; expected "file".'
                )

        profile_dict = self._profiles.get(profile)

        if not profile_dict:
            raise ValueError(f"Profile {profile} not found; available profiles: {list(self._profiles.keys())}")

        self._profile = profile
        self._profile_dict = profile_dict
        self._opentelemetry_dict = config_dict.get("opentelemetry", None)
        self._cache_dict = config_dict.get("cache", None)

    @staticmethod
    def _import_fully_qualified_class(class_type: str) -> Any:
        """
        Imports the class named by a fully qualified ``module.ClassName`` path.

        :param class_type: Fully qualified class path.
        :return: The object returned by ``import_class`` for that path.
        :raises ValueError: If ``class_type`` contains no module part.
        """
        if "." not in class_type:
            raise ValueError(
                f"Expected a fully qualified class name (e.g., 'module.ClassName'); got '{class_type}'."
            )
        module_name, class_name = class_type.rsplit(".", 1)
        return import_class(class_name, module_name)

    def _build_storage_provider(
        self,
        storage_provider_name: str,
        storage_options: Optional[Dict[str, Any]],
        credentials_provider: Optional[CredentialsProvider] = None,
    ) -> StorageProvider:
        """
        Instantiates the storage provider registered under ``storage_provider_name``.

        :param storage_provider_name: Key of :py:data:`STORAGE_PROVIDER_MAPPING`.
        :param storage_options: Keyword arguments for the provider constructor; not modified.
        :param credentials_provider: Optional credentials provider, passed to the
            constructor as ``credentials_provider`` when truthy.
        :return: The constructed storage provider.
        :raises ValueError: If ``storage_provider_name`` is not supported.
        """
        if storage_provider_name not in STORAGE_PROVIDER_MAPPING:
            raise ValueError(
                f"Storage provider {storage_provider_name} is not supported. "
                f"Supported providers are: {list(STORAGE_PROVIDER_MAPPING.keys())}"
            )

        # Copy before injecting the credentials provider: the incoming dict
        # usually belongs to the configuration, which must stay free of live
        # objects so it can be reused (and pickled) later.
        options = dict(storage_options) if storage_options else {}
        if credentials_provider:
            options["credentials_provider"] = credentials_provider

        cls = import_class(STORAGE_PROVIDER_MAPPING[storage_provider_name], ".providers", PACKAGE_NAME)
        return cls(**options)

    def _build_credentials_provider(
        self,
        credentials_provider_dict: Optional[Dict[str, Any]],
        storage_options: Optional[Dict[str, Any]] = None,
    ) -> Optional[CredentialsProvider]:
        """
        Initializes the CredentialsProvider based on the provided dictionary.

        :param credentials_provider_dict: Dictionary containing credentials provider
            configuration (``type`` and optional ``options``); not modified.
        :param storage_options: Storage provider options required by some credentials
            providers to scope the credentials.
        :return: The credentials provider, or ``None`` when no configuration is given.
        :raises ValueError: If ``type`` is neither a mapped name nor a fully
            qualified class path.
        """
        if not credentials_provider_dict:
            return None

        provider_type = credentials_provider_dict["type"]
        if provider_type in CREDENTIALS_PROVIDER_MAPPING:
            # Mapped class name case
            cls = import_class(CREDENTIALS_PROVIDER_MAPPING[provider_type], ".providers", PACKAGE_NAME)
        else:
            # Fully qualified class path case
            cls = self._import_fully_qualified_class(provider_type)

        # Propagate storage provider options to credentials provider since they may be
        # required by some credentials providers to scope the credentials.
        import inspect

        init_params = list(inspect.signature(cls.__init__).parameters)[1:]  # skip 'self'

        # Copy so that propagated storage options are not written back into the config.
        options = dict(credentials_provider_dict.get("options") or {})
        if storage_options:
            for option_name, option_value in storage_options.items():
                # Explicit credentials options win over propagated storage options.
                if option_name in init_params and option_name not in options:
                    options[option_name] = option_value

        return cls(**options)

    def _build_storage_provider_for_manifest(self, storage_provider_profile: str) -> StorageProvider:
        """
        Builds the storage provider of another profile for use by a manifest
        metadata provider.

        :param storage_provider_profile: Name of the referenced profile.
        :return: Storage provider built from that profile's ``storage_provider``
            and ``credentials_provider`` entries.
        :raises ValueError: If the profile does not exist, defines its own
            ``metadata_provider``, or has no ``storage_provider``.
        """
        storage_profile_dict = self._profiles.get(storage_provider_profile)
        if not storage_profile_dict:
            raise ValueError(
                f"Profile '{storage_provider_profile}' referenced by "
                f"storage_provider_profile does not exist."
            )

        # NOTE: The storage profile for manifests does not support metadata provider (at the moment).
        if storage_profile_dict.get("metadata_provider", None):
            raise ValueError(
                f"Found metadata_provider for profile '{storage_provider_profile}'. "
                f"This is not supported for storage profiles used by manifests."
            )

        local_storage_provider_dict = storage_profile_dict.get("storage_provider", None)
        if not local_storage_provider_dict:
            raise ValueError("Missing storage_provider in the config.")
        local_name = local_storage_provider_dict["type"]
        local_storage_options = local_storage_provider_dict.get("options", {})

        # We need to pass the storage_provider options to the credentials provider.
        local_creds_provider = self._build_credentials_provider(
            credentials_provider_dict=storage_profile_dict.get("credentials_provider", None),
            storage_options=local_storage_options,
        )

        return self._build_storage_provider(local_name, local_storage_options, local_creds_provider)

    def _build_provider_bundle_from_config(self, profile_dict: Dict[str, Any]) -> ProviderBundle:
        """
        Builds a provider bundle from the ``storage_provider``,
        ``credentials_provider`` and ``metadata_provider`` entries of a profile.

        :param profile_dict: Profile configuration; not modified.
        :return: A :py:class:`SimpleProviderBundle` describing the profile.
        :raises ValueError: If ``storage_provider`` is missing or a referenced
            class/profile is invalid.
        """
        # Initialize StorageProvider
        storage_provider_dict = profile_dict.get("storage_provider", None)
        if not storage_provider_dict:
            raise ValueError("Missing storage_provider in the config.")
        storage_provider_name = storage_provider_dict["type"]
        storage_options = storage_provider_dict.get("options", {})

        # Initialize CredentialsProvider
        # It is prudent to assume that in some cases, the credentials provider
        # will provide credentials scoped to specific base_path.
        # So we need to pass the storage_options to the credentials provider.
        credentials_provider = self._build_credentials_provider(
            credentials_provider_dict=profile_dict.get("credentials_provider", None),
            storage_options=storage_options,
        )

        # Initialize MetadataProvider
        metadata_provider = None
        metadata_provider_dict = profile_dict.get("metadata_provider", None)
        if metadata_provider_dict:
            if metadata_provider_dict["type"] == "manifest":
                # Copy before popping: removing ``storage_provider_profile`` from
                # the configuration itself would make a later rebuild from the
                # same dictionary silently use the wrong storage provider.
                metadata_options = dict(metadata_provider_dict.get("options") or {})
                # If MetadataProvider has a reference to a different storage provider profile
                storage_provider_profile = metadata_options.pop("storage_provider_profile", None)
                if storage_provider_profile:
                    storage_provider = self._build_storage_provider_for_manifest(storage_provider_profile)
                else:
                    storage_provider = self._build_storage_provider(
                        storage_provider_name, storage_options, credentials_provider
                    )

                metadata_provider = ManifestMetadataProvider(storage_provider, **metadata_options)
            else:
                cls = self._import_fully_qualified_class(metadata_provider_dict["type"])
                metadata_provider = cls(**(metadata_provider_dict.get("options") or {}))

        return SimpleProviderBundle(
            storage_provider_config=StorageProviderConfig(storage_provider_name, storage_options),
            credentials_provider=credentials_provider,
            metadata_provider=metadata_provider,
        )

    def _build_provider_bundle_from_extension(self, provider_bundle_dict: Dict[str, Any]) -> ProviderBundle:
        """
        Instantiates a third-party provider bundle.

        :param provider_bundle_dict: Dictionary with a fully qualified class path
            under ``type`` and optional constructor keyword arguments under ``options``.
        :return: The constructed provider bundle.
        :raises ValueError: If ``type`` is not a fully qualified class path.
        """
        cls = self._import_fully_qualified_class(provider_bundle_dict["type"])
        return cls(**(provider_bundle_dict.get("options") or {}))

    def _build_provider_bundle(self) -> ProviderBundle:
        """
        Returns the provider bundle to use, in order of precedence: the bundle
        given to the constructor, the profile's ``provider_bundle`` extension,
        or a bundle built from the profile's individual provider entries.
        """
        if self._provider_bundle:
            return self._provider_bundle  # Return if previously provided.

        # Load 3rd party extension
        provider_bundle_dict = self._profile_dict.get("provider_bundle", None)
        if provider_bundle_dict:
            return self._build_provider_bundle_from_extension(provider_bundle_dict)

        return self._build_provider_bundle_from_config(self._profile_dict)

    def _build_cache_config(self) -> Optional[CacheConfig]:
        """
        Builds the cache configuration from the top-level ``cache`` entry.

        Creates the cache directory if it does not exist.

        :return: The cache configuration, or ``None`` when no ``cache`` entry exists.
        """
        if self._cache_dict is None:
            return None

        default_location = os.path.join(tempfile.gettempdir(), ".msc_cache")
        cache_location = self._cache_dict.get("location", default_location)
        os.makedirs(cache_location, exist_ok=True)

        return CacheConfig(
            location=cache_location,
            size_mb=self._cache_dict.get("size_mb", DEFAULT_CACHE_SIZE_MB),
            use_etag=self._cache_dict.get("use_etag", False),
            eviction_policy=self._cache_dict.get("eviction_policy", "fifo"),
        )

    def _build_retry_config(self) -> RetryConfig:
        """
        Builds the retry configuration from the profile's ``retry`` entry,
        falling back to the default attempts and delay for anything unset.
        """
        retry_config_dict = self._profile_dict.get("retry", None) or {}
        return RetryConfig(
            attempts=retry_config_dict.get("attempts", DEFAULT_RETRY_ATTEMPTS),
            delay=retry_config_dict.get("delay", DEFAULT_RETRY_DELAY),
        )

    def build_config(self) -> "StorageClientConfig":
        """
        Builds the :py:class:`StorageClientConfig` for the selected profile.

        :return: Configuration holding the storage, credentials and metadata
            providers together with the cache and retry settings.
        """
        bundle = self._build_provider_bundle()
        storage_provider = self._build_storage_provider(
            bundle.storage_provider_config.type,
            bundle.storage_provider_config.options,
            bundle.credentials_provider,
        )

        cache_config = self._build_cache_config()
        retry_config = self._build_retry_config()

        # set up OpenTelemetry providers once per process
        if self._opentelemetry_dict:
            setup_opentelemetry(self._opentelemetry_dict)

        return StorageClientConfig(
            profile=self._profile,
            storage_provider=storage_provider,
            credentials_provider=bundle.credentials_provider,
            metadata_provider=bundle.metadata_provider,
            cache_config=cache_config,
            retry_config=retry_config,
        )
356
357
[docs]
class StorageClientConfig:
    """
    Configuration class for the :py:class:`multistorageclient.StorageClient`.
    """

    profile: str
    storage_provider: StorageProvider
    credentials_provider: Optional[CredentialsProvider]
    metadata_provider: Optional[MetadataProvider]
    cache_config: Optional[CacheConfig]
    cache_manager: Optional[CacheManager]
    retry_config: Optional[RetryConfig]

    # Dictionary the configuration was built from; needed to rebuild the
    # providers when unpickling. ``None`` means the instance cannot be pickled.
    _config_dict: Optional[Dict[str, Any]]

    def __init__(
        self,
        profile: str,
        storage_provider: StorageProvider,
        credentials_provider: Optional[CredentialsProvider] = None,
        metadata_provider: Optional[MetadataProvider] = None,
        cache_config: Optional[CacheConfig] = None,
        retry_config: Optional[RetryConfig] = None,
    ):
        """
        :param profile: Name of the profile this configuration belongs to.
        :param storage_provider: Storage provider instance.
        :param credentials_provider: Optional credentials provider instance.
        :param metadata_provider: Optional metadata provider instance.
        :param cache_config: Optional cache configuration; when truthy a
            :py:class:`CacheManager` is created for it.
        :param retry_config: Optional retry configuration.
        """
        self.profile = profile
        self.storage_provider = storage_provider
        self.credentials_provider = credentials_provider
        self.metadata_provider = metadata_provider
        self.cache_config = cache_config
        self.retry_config = retry_config
        self.cache_manager = CacheManager(profile, cache_config) if cache_config else None
        # Always define the attribute so it can be read on any instance; the
        # ``from_*`` constructors overwrite it where a source dictionary exists.
        self._config_dict = None

    @staticmethod
    def from_json(config_json: str, profile: str = DEFAULT_POSIX_PROFILE_NAME) -> "StorageClientConfig":
        """
        Builds a configuration from a JSON document.

        :param config_json: Configuration as a JSON string.
        :param profile: Name of the profile to build.
        """
        config_dict = json.loads(config_json)
        return StorageClientConfig.from_dict(config_dict, profile)

    @staticmethod
    def from_yaml(config_yaml: str, profile: str = DEFAULT_POSIX_PROFILE_NAME) -> "StorageClientConfig":
        """
        Builds a configuration from a YAML document.

        :param config_yaml: Configuration as a YAML string.
        :param profile: Name of the profile to build.
        """
        config_dict = yaml.safe_load(config_yaml)
        return StorageClientConfig.from_dict(config_dict, profile)

    @staticmethod
    def from_dict(config_dict: Dict[str, Any], profile: str = DEFAULT_POSIX_PROFILE_NAME) -> "StorageClientConfig":
        """
        Validates ``config_dict`` and builds the configuration for ``profile``.

        The dictionary is retained on the returned object so that it can be
        pickled and rebuilt.

        :param config_dict: Dictionary of configuration options.
        :param profile: Name of the profile to build.
        """
        # Validate the config file with predefined JSON schema
        validate_config(config_dict)

        # Load config
        loader = StorageClientConfigLoader(config_dict, profile)
        config = loader.build_config()
        config._config_dict = config_dict

        return config

    @staticmethod
    def from_file(profile: str = DEFAULT_POSIX_PROFILE_NAME) -> "StorageClientConfig":
        """
        Builds a configuration from the configuration files found on this machine.

        The MSC config file is taken from the ``MSC_CONFIG`` environment
        variable or, when unset, from the first existing entry of
        :py:data:`DEFAULT_MSC_CONFIG_FILE_SEARCH_PATHS`. Files ending in
        ``.json`` are parsed as JSON, anything else as YAML. The result is
        merged with the rclone configuration; when neither file is found the
        default POSIX profile is used.

        :param profile: Name of the profile to build.
        :raises ValueError: If the MSC and rclone configurations define conflicting keys.
        """
        msc_config_file = os.getenv("MSC_CONFIG", None)

        # Search config paths
        if msc_config_file is None:
            for filename in DEFAULT_MSC_CONFIG_FILE_SEARCH_PATHS:
                if os.path.exists(filename):
                    msc_config_file = filename
                    break

        msc_config_dict: Dict[str, Any] = {}

        # Parse MSC config file.
        if msc_config_file:
            with open(msc_config_file, "r", encoding="utf-8") as f:
                content = f.read()
            if msc_config_file.endswith(".json"):
                msc_config_dict = json.loads(content)
            else:
                # An empty YAML document loads as ``None``; treat it as an
                # empty configuration so the merge below receives a dict.
                msc_config_dict = yaml.safe_load(content) or {}

        # Parse rclone config file.
        rclone_config_dict, rclone_config_file = read_rclone_config()

        # If no config file is found, use a default profile.
        if not msc_config_file and not rclone_config_file:
            return StorageClientConfig.from_dict(DEFAULT_POSIX_PROFILE, profile=profile)

        # Merge config files.
        merged_config, conflicted_keys = merge_dictionaries_no_overwrite(msc_config_dict, rclone_config_dict)
        if conflicted_keys:
            raise ValueError(
                f'Conflicting keys found in configuration files "{msc_config_file}" '
                f'and "{rclone_config_file}": {conflicted_keys}'
            )

        return StorageClientConfig.from_dict(merged_config, profile)

    @staticmethod
    def from_provider_bundle(config_dict: Dict[str, Any], provider_bundle: ProviderBundle) -> "StorageClientConfig":
        """
        Builds a configuration whose providers come from a pre-built bundle.

        The resulting configuration cannot be pickled because the bundle
        cannot be rebuilt from ``config_dict``.

        :param config_dict: Dictionary of configuration options.
        :param provider_bundle: Pre-built provider bundle that takes precedence over ``config_dict``.
        """
        loader = StorageClientConfigLoader(config_dict, provider_bundle=provider_bundle)
        config = loader.build_config()
        config._config_dict = None  # Explicitly mark as None to avoid confusing pickling errors
        return config

    def __getstate__(self) -> Dict[str, Any]:
        """
        Returns the picklable state: everything except the live providers and
        the cache manager, which are rebuilt from ``_config_dict`` on load.

        :raises ValueError: If the instance has no source configuration dictionary.
        """
        state = self.__dict__.copy()
        if not state.get("_config_dict"):
            raise ValueError("StorageClientConfig is not serializable")
        for transient in ("credentials_provider", "storage_provider", "metadata_provider", "cache_manager"):
            state.pop(transient, None)
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        """
        Restores the instance by rebuilding all components from the pickled
        configuration dictionary and profile name.
        """
        self.profile = state["profile"]
        # Keep the source dictionary so the restored object can be pickled again.
        self._config_dict = state["_config_dict"]
        loader = StorageClientConfigLoader(self._config_dict, self.profile)
        new_config = loader.build_config()
        self.storage_provider = new_config.storage_provider
        self.credentials_provider = new_config.credentials_provider
        self.metadata_provider = new_config.metadata_provider
        self.cache_config = new_config.cache_config
        self.retry_config = new_config.retry_config
        self.cache_manager = new_config.cache_manager