Source code for multistorageclient.config

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import os
import tempfile
from typing import Any, Dict, Optional

import yaml

from .cache import DEFAULT_CACHE_SIZE_MB, CacheConfig, CacheManager
from .instrumentation import setup_opentelemetry
from .providers import ManifestMetadataProvider
from .schema import validate_config
from .types import (
    DEFAULT_POSIX_PROFILE,
    DEFAULT_POSIX_PROFILE_NAME,
    DEFAULT_RETRY_ATTEMPTS,
    DEFAULT_RETRY_DELAY,
    CredentialsProvider,
    MetadataProvider,
    ProviderBundle,
    RetryConfig,
    StorageProvider,
    StorageProviderConfig,
)
from .utils import expand_env_vars, import_class

# Maps the ``storage_provider.type`` value of a profile to the name of the provider
# class that is imported from the package's ``.providers`` module.
STORAGE_PROVIDER_MAPPING = {
    "file": "PosixFileStorageProvider",
    "s3": "S3StorageProvider",
    "gcs": "GoogleStorageProvider",
    "oci": "OracleStorageProvider",
    "azure": "AzureBlobStorageProvider",
    "ais": "AIStoreStorageProvider",
    # Map swiftstack to S3StorageProvider for now
    "s8k": "S3StorageProvider",
}

# Maps the short ``credentials_provider.type`` names to the class names imported from
# the package's ``.providers`` module. Any other value is treated as a fully
# qualified class path.
CREDENTIALS_PROVIDER_MAPPING = {
    "S3Credentials": "StaticS3CredentialsProvider",
    "AzureCredentials": "StaticAzureCredentialsProvider",
    "AISCredentials": "StaticAISCredentialProvider",
}

# Candidate config file locations, listed in lookup order (YAML before JSON).
# NOTE: ``HOME`` is read once at import time; when it is unset the per-user entries
# degrade to paths relative to the current working directory.
DEFAULT_CONFIG_FILE_SEARCH_PATHS = (
    # Yaml
    "/etc/msc_config.yaml",
    os.path.join(os.getenv("HOME", ""), ".config", "msc", "config.yaml"),
    os.path.join(os.getenv("HOME", ""), ".msc_config.yaml"),
    # Json
    "/etc/msc_config.json",
    os.path.join(os.getenv("HOME", ""), ".config", "msc", "config.json"),
    os.path.join(os.getenv("HOME", ""), ".msc_config.json"),
)

# Package used as the anchor when importing provider classes by relative module name.
PACKAGE_NAME = "multistorageclient"

# Obtain the logger through ``logging.getLogger`` so that it is registered in the
# logging hierarchy and inherits the handlers and levels configured by the
# application. Instantiating ``logging.Logger`` directly yields a detached logger
# that ignores that configuration.
logger = logging.getLogger(__name__)


class SimpleProviderBundle(ProviderBundle):
    """
    A :py:class:`ProviderBundle` that simply hands back the objects it was constructed with.
    """

    def __init__(
        self,
        storage_provider_config: StorageProviderConfig,
        credentials_provider: Optional[CredentialsProvider] = None,
        metadata_provider: Optional[MetadataProvider] = None,
    ):
        """
        :param storage_provider_config: Storage provider type and options to expose.
        :param credentials_provider: Optional credentials provider to expose.
        :param metadata_provider: Optional metadata provider to expose.
        """
        # All three values are stored untouched; the optional ones may be None.
        self._metadata_provider = metadata_provider
        self._credentials_provider = credentials_provider
        self._storage_provider_config = storage_provider_config

    @property
    def storage_provider_config(self) -> StorageProviderConfig:
        """The storage provider configuration passed to the constructor."""
        config = self._storage_provider_config
        return config

    @property
    def credentials_provider(self) -> Optional[CredentialsProvider]:
        """The credentials provider passed to the constructor, or ``None``."""
        provider = self._credentials_provider
        return provider

    @property
    def metadata_provider(self) -> Optional[MetadataProvider]:
        """The metadata provider passed to the constructor, or ``None``."""
        provider = self._metadata_provider
        return provider


class StorageClientConfigLoader:
    """
    Builds a :py:class:`StorageClientConfig` for one profile of a configuration dictionary.

    The option dictionaries read from the configuration are copied before anything is
    injected into or removed from them, so building providers does not alter the
    configuration that was passed in.
    """

    def __init__(
        self,
        config_dict: Dict[str, Any],
        profile: str = DEFAULT_POSIX_PROFILE_NAME,
        provider_bundle: Optional[ProviderBundle] = None,
    ) -> None:
        """
        Initializes a :py:class:`StorageClientConfigLoader` to create a
        StorageClientConfig. Components are built using the ``config_dict`` and
        profile, but a pre-built provider_bundle takes precedence.

        :param config_dict: Dictionary of configuration options.
        :param profile: Name of profile in ``config_dict`` to use to build configuration.
        :param provider_bundle: Optional pre-built :py:class:`multistorageclient.types.ProviderBundle`, takes precedence over ``config_dict``.
        :raises ValueError: If the default POSIX profile is redefined with a storage provider
            type other than ``file``, or if ``profile`` is not found among the profiles.
        """
        # ProviderBundle takes precedence
        self._provider_bundle = provider_bundle

        # Interpolates all environment variables into actual values.
        config_dict = expand_env_vars(config_dict)

        self._profiles = config_dict.get("profiles", {})

        if DEFAULT_POSIX_PROFILE_NAME not in self._profiles:
            # Assign the default POSIX profile
            self._profiles[DEFAULT_POSIX_PROFILE_NAME] = DEFAULT_POSIX_PROFILE["profiles"][DEFAULT_POSIX_PROFILE_NAME]
        else:
            # Cannot override default POSIX profile
            storage_provider_type = (
                self._profiles[DEFAULT_POSIX_PROFILE_NAME].get("storage_provider", {}).get("type", None)
            )
            if storage_provider_type != "file":
                raise ValueError(
                    f'Cannot override "{DEFAULT_POSIX_PROFILE_NAME}" profile with storage provider type '
                    f'"{storage_provider_type}"; expected "file".'
                )

        profile_dict = self._profiles.get(profile)

        if not profile_dict:
            raise ValueError(f"Profile {profile} not found; available profiles: {list(self._profiles.keys())}")

        self._profile = profile
        self._profile_dict = profile_dict
        self._opentelemetry_dict = config_dict.get("opentelemetry", None)
        self._cache_dict = config_dict.get("cache", None)

    @staticmethod
    def _import_fully_qualified_class(class_type: str) -> Any:
        """
        Imports a class given as a fully qualified ``module.ClassName`` path.

        :param class_type: Dotted path whose last component is the class name.
        :return: Whatever ``import_class`` returns for the module and class name.
        :raises ValueError: If ``class_type`` contains no ``.`` separator.
        """
        if "." not in class_type:
            raise ValueError(
                "Expected a fully qualified class name (e.g., 'module.ClassName'); " f"got '{class_type}'."
            )
        module_name, class_name = class_type.rsplit(".", 1)
        return import_class(class_name, module_name)

    @staticmethod
    def _get_storage_provider_settings(profile_dict: Dict[str, Any]):
        """
        Extracts the storage provider type and options from a profile.

        :param profile_dict: Profile dictionary to read ``storage_provider`` from.
        :return: A ``(type, options)`` pair; ``options`` defaults to an empty dictionary.
        :raises ValueError: If the profile has no (or an empty) ``storage_provider`` entry.
        """
        storage_provider_dict = profile_dict.get("storage_provider", None)
        if not storage_provider_dict:
            raise ValueError("Missing storage_provider in the config.")
        return storage_provider_dict["type"], storage_provider_dict.get("options", {})

    def _build_storage_provider(
        self,
        storage_provider_name: str,
        storage_options: Optional[Dict[str, Any]],
        credentials_provider: Optional[CredentialsProvider] = None,
    ) -> StorageProvider:
        """
        Instantiates the storage provider class registered for ``storage_provider_name``.

        :param storage_provider_name: Key of :py:data:`STORAGE_PROVIDER_MAPPING`.
        :param storage_options: Keyword arguments for the provider class; ``None`` means no options.
            The dictionary is not modified.
        :param credentials_provider: If truthy, passed to the provider class as the
            ``credentials_provider`` keyword argument.
        :return: The constructed storage provider.
        :raises ValueError: If ``storage_provider_name`` is not a supported provider type.
        """
        if storage_provider_name not in STORAGE_PROVIDER_MAPPING:
            raise ValueError(
                f"Storage provider {storage_provider_name} is not supported. "
                f"Supported providers are: {list(STORAGE_PROVIDER_MAPPING.keys())}"
            )

        # Work on a shallow copy: the options usually come straight from the configuration
        # dictionary, and the credentials provider object must not leak into it.
        options = dict(storage_options) if storage_options else {}
        if credentials_provider:
            options["credentials_provider"] = credentials_provider

        class_name = STORAGE_PROVIDER_MAPPING[storage_provider_name]
        module_name = ".providers"
        cls = import_class(class_name, module_name, PACKAGE_NAME)
        return cls(**options)

    def _build_credentials_provider(
        self, credentials_provider_dict: Optional[Dict[str, Any]]
    ) -> Optional[CredentialsProvider]:
        """
        Initializes the CredentialsProvider based on the provided dictionary.

        :param credentials_provider_dict: ``credentials_provider`` section of a profile, or ``None``.
        :return: The constructed credentials provider, or ``None`` when the section is missing or empty.
        :raises ValueError: If the type is neither a key of :py:data:`CREDENTIALS_PROVIDER_MAPPING`
            nor a fully qualified class name.
        """
        if not credentials_provider_dict:
            return None

        class_type = credentials_provider_dict["type"]
        if class_type in CREDENTIALS_PROVIDER_MAPPING:
            # Mapped class name case
            class_name = CREDENTIALS_PROVIDER_MAPPING[class_type]
            module_name = ".providers"
            cls = import_class(class_name, module_name, PACKAGE_NAME)
        else:
            # Fully qualified class path case
            cls = self._import_fully_qualified_class(class_type)

        options = credentials_provider_dict.get("options", {})
        return cls(**options)

    def _build_manifest_metadata_provider(
        self,
        metadata_provider_dict: Dict[str, Any],
        storage_provider_name: str,
        storage_options: Optional[Dict[str, Any]],
        credentials_provider: Optional[CredentialsProvider],
    ) -> MetadataProvider:
        """
        Builds the :py:class:`ManifestMetadataProvider` for a profile.

        When the metadata options name a ``storage_provider_profile``, the manifest storage
        provider is built from that profile; otherwise it is built from the storage
        settings of the profile being loaded.

        :param metadata_provider_dict: ``metadata_provider`` section whose type is ``manifest``.
        :param storage_provider_name: Storage provider type of the profile being loaded.
        :param storage_options: Storage provider options of the profile being loaded.
        :param credentials_provider: Credentials provider of the profile being loaded.
        :return: The constructed manifest metadata provider.
        :raises ValueError: If the referenced profile does not exist, configures its own
            metadata provider, or lacks a ``storage_provider`` section.
        """
        # Copy before popping so the reference to the storage profile stays in the
        # configuration and is still honored if the same configuration is built again.
        metadata_options = dict(metadata_provider_dict.get("options", {}))

        # If MetadataProvider has a reference to a different storage provider profile
        storage_provider_profile = metadata_options.pop("storage_provider_profile", None)
        if storage_provider_profile:
            storage_profile_dict = self._profiles.get(storage_provider_profile)
            if not storage_profile_dict:
                raise ValueError(
                    f"Profile '{storage_provider_profile}' referenced by "
                    f"storage_provider_profile does not exist."
                )

            # Check if metadata provider is configured for this profile
            # NOTE: The storage profile for manifests does not support metadata provider (at the moment).
            if storage_profile_dict.get("metadata_provider", None):
                raise ValueError(
                    f"Found metadata_provider for profile '{storage_provider_profile}'. "
                    f"This is not supported for storage profiles used by manifests."
                )

            # Initialize CredentialsProvider
            local_creds_provider = self._build_credentials_provider(
                credentials_provider_dict=storage_profile_dict.get("credentials_provider", None)
            )

            # Initialize StorageProvider
            local_name, local_storage_options = self._get_storage_provider_settings(storage_profile_dict)
            storage_provider = self._build_storage_provider(local_name, local_storage_options, local_creds_provider)
        else:
            storage_provider = self._build_storage_provider(
                storage_provider_name, storage_options, credentials_provider
            )

        return ManifestMetadataProvider(storage_provider, **metadata_options)

    def _build_provider_bundle_from_config(self, profile_dict: Dict[str, Any]) -> ProviderBundle:
        """
        Builds a :py:class:`SimpleProviderBundle` from the sections of a profile.

        :param profile_dict: Profile dictionary to build from.
        :return: Bundle holding the storage provider configuration plus the optional
            credentials and metadata providers.
        :raises ValueError: If ``storage_provider`` is missing, or a metadata provider type is
            neither ``manifest`` nor a fully qualified class name.
        """
        # Initialize CredentialsProvider
        credentials_provider = self._build_credentials_provider(
            credentials_provider_dict=profile_dict.get("credentials_provider", None)
        )

        # Resolve the StorageProvider settings (the provider itself is built in build_config)
        storage_provider_name, storage_options = self._get_storage_provider_settings(profile_dict)

        # Initialize MetadataProvider
        metadata_provider = None
        metadata_provider_dict = profile_dict.get("metadata_provider", None)
        if metadata_provider_dict:
            if metadata_provider_dict["type"] == "manifest":
                metadata_provider = self._build_manifest_metadata_provider(
                    metadata_provider_dict, storage_provider_name, storage_options, credentials_provider
                )
            else:
                cls = self._import_fully_qualified_class(metadata_provider_dict["type"])
                options = metadata_provider_dict.get("options", {})
                metadata_provider = cls(**options)

        return SimpleProviderBundle(
            storage_provider_config=StorageProviderConfig(storage_provider_name, storage_options),
            credentials_provider=credentials_provider,
            metadata_provider=metadata_provider,
        )

    def _build_provider_bundle_from_extension(self, provider_bundle_dict: Dict[str, Any]) -> ProviderBundle:
        """
        Instantiates a third-party provider bundle class.

        :param provider_bundle_dict: ``provider_bundle`` section with a fully qualified ``type``
            and optional ``options`` used as keyword arguments.
        :return: The constructed provider bundle.
        :raises ValueError: If ``type`` is not a fully qualified class name.
        """
        cls = self._import_fully_qualified_class(provider_bundle_dict["type"])
        options = provider_bundle_dict.get("options", {})
        return cls(**options)

    def _build_provider_bundle(self) -> ProviderBundle:
        """
        Selects the provider bundle for this loader.

        Order of precedence: the bundle given to the constructor, then a ``provider_bundle``
        extension declared in the profile, then a bundle built from the profile's sections.
        """
        if self._provider_bundle:
            return self._provider_bundle  # Return if previously provided.

        # Load 3rd party extension
        provider_bundle_dict = self._profile_dict.get("provider_bundle", None)
        if provider_bundle_dict:
            return self._build_provider_bundle_from_extension(provider_bundle_dict)

        return self._build_provider_bundle_from_config(self._profile_dict)

    def build_config(self) -> "StorageClientConfig":
        """
        Builds the :py:class:`StorageClientConfig` for the selected profile.

        Creates the cache directory when a ``cache`` section is configured and calls
        ``setup_opentelemetry`` when an ``opentelemetry`` section is configured.

        :return: The assembled configuration.
        """
        bundle = self._build_provider_bundle()
        storage_provider = self._build_storage_provider(
            bundle.storage_provider_config.type, bundle.storage_provider_config.options, bundle.credentials_provider
        )

        # Cache Config
        cache_config: Optional[CacheConfig] = None
        if self._cache_dict is not None:
            tempdir = tempfile.gettempdir()
            default_location = os.path.join(tempdir, ".msc_cache")
            cache_location = self._cache_dict.get("location", default_location)
            size_mb = self._cache_dict.get("size_mb", DEFAULT_CACHE_SIZE_MB)
            os.makedirs(cache_location, exist_ok=True)
            use_etag = self._cache_dict.get("use_etag", False)
            cache_config = CacheConfig(location=cache_location, size_mb=size_mb, use_etag=use_etag)

        # retry options
        retry_config_dict = self._profile_dict.get("retry", None)
        if retry_config_dict:
            attempts = retry_config_dict.get("attempts", DEFAULT_RETRY_ATTEMPTS)
            delay = retry_config_dict.get("delay", DEFAULT_RETRY_DELAY)
            retry_config = RetryConfig(attempts=attempts, delay=delay)
        else:
            retry_config = RetryConfig(attempts=DEFAULT_RETRY_ATTEMPTS, delay=DEFAULT_RETRY_DELAY)

        # set up OpenTelemetry providers once per process
        if self._opentelemetry_dict:
            setup_opentelemetry(self._opentelemetry_dict)

        return StorageClientConfig(
            profile=self._profile,
            storage_provider=storage_provider,
            credentials_provider=bundle.credentials_provider,
            metadata_provider=bundle.metadata_provider,
            cache_config=cache_config,
            retry_config=retry_config,
        )


class StorageClientConfig:
    """
    Configuration class for the :py:class:`multistorageclient.StorageClient`.
    """

    profile: str
    storage_provider: StorageProvider
    credentials_provider: Optional[CredentialsProvider]
    metadata_provider: Optional[MetadataProvider]
    cache_config: Optional[CacheConfig]
    cache_manager: Optional[CacheManager]
    retry_config: Optional[RetryConfig]

    # Configuration dictionary the instance was built from. It is only assigned by
    # ``from_dict`` (and restored by ``__setstate__``); unpickling relies on it.
    _config_dict: Dict[str, Any]

    def __init__(
        self,
        profile: str,
        storage_provider: StorageProvider,
        credentials_provider: Optional[CredentialsProvider] = None,
        metadata_provider: Optional[MetadataProvider] = None,
        cache_config: Optional[CacheConfig] = None,
        retry_config: Optional[RetryConfig] = None,
    ):
        """
        :param profile: Name of the profile this configuration belongs to.
        :param storage_provider: Storage provider to use.
        :param credentials_provider: Optional credentials provider.
        :param metadata_provider: Optional metadata provider.
        :param cache_config: Optional cache configuration; when given, a
            :py:class:`CacheManager` is created for ``profile``.
        :param retry_config: Optional retry configuration.
        """
        self.profile = profile
        self.storage_provider = storage_provider
        self.credentials_provider = credentials_provider
        self.metadata_provider = metadata_provider
        self.cache_config = cache_config
        self.retry_config = retry_config
        self.cache_manager = CacheManager(profile, cache_config) if cache_config else None

    @staticmethod
    def from_json(config_json: str, profile: str = DEFAULT_POSIX_PROFILE_NAME) -> "StorageClientConfig":
        """
        Builds a configuration from a JSON document.

        :param config_json: JSON text of the configuration.
        :param profile: Name of the profile to load.
        """
        config_dict = json.loads(config_json)
        return StorageClientConfig.from_dict(config_dict, profile)

    @staticmethod
    def from_yaml(config_yaml: str, profile: str = DEFAULT_POSIX_PROFILE_NAME) -> "StorageClientConfig":
        """
        Builds a configuration from a YAML document (parsed with ``yaml.safe_load``).

        :param config_yaml: YAML text of the configuration.
        :param profile: Name of the profile to load.
        """
        config_dict = yaml.safe_load(config_yaml)
        return StorageClientConfig.from_dict(config_dict, profile)

    @staticmethod
    def from_dict(config_dict: Dict[str, Any], profile: str = DEFAULT_POSIX_PROFILE_NAME) -> "StorageClientConfig":
        """
        Validates ``config_dict`` and builds the configuration for ``profile``.

        The dictionary is kept on the returned instance so that it can be rebuilt
        after unpickling.

        :param config_dict: Dictionary of configuration options.
        :param profile: Name of the profile to load.
        """
        # Validate the config file with predefined JSON schema
        validate_config(config_dict)

        # Load config
        loader = StorageClientConfigLoader(config_dict, profile)
        config = loader.build_config()
        config._config_dict = config_dict

        return config

    @staticmethod
    def from_file(profile: str = DEFAULT_POSIX_PROFILE_NAME) -> "StorageClientConfig":
        """
        Builds a configuration from a config file.

        The file named by the ``MSC_CONFIG`` environment variable is used when set;
        otherwise the first existing path of :py:data:`DEFAULT_CONFIG_FILE_SEARCH_PATHS`.
        If no file is found, a warning is logged and the default POSIX profile
        configuration is used. Files ending in ``.json`` are parsed as JSON, all
        others as YAML.

        :param profile: Name of the profile to load.
        :raises FileNotFoundError: If the selected config file does not exist.
        """
        config_file = os.getenv("MSC_CONFIG", None)

        # Search config paths
        if config_file is None:
            for filename in DEFAULT_CONFIG_FILE_SEARCH_PATHS:
                if os.path.exists(filename):
                    config_file = filename
                    break

        if config_file is None:
            logger.warning(
                "Cannot find the MSC config file in any of the locations: %s; add a "
                "config file for sending client-side metrics to an OpenTelemetry service",
                DEFAULT_CONFIG_FILE_SEARCH_PATHS,
            )
            return StorageClientConfig.from_dict(DEFAULT_POSIX_PROFILE, profile=profile)

        if not os.path.exists(config_file):
            raise FileNotFoundError(f"Cannot find the config file at {config_file}")

        with open(config_file, "r", encoding="utf-8") as f:
            content = f.read()

        if config_file.endswith(".json"):
            return StorageClientConfig.from_json(content, profile)
        return StorageClientConfig.from_yaml(content, profile)

    @staticmethod
    def from_provider_bundle(config_dict: Dict[str, Any], provider_bundle: ProviderBundle) -> "StorageClientConfig":
        """
        Builds a configuration from a pre-built provider bundle.

        NOTE: ``_config_dict`` is not assigned on this path, so the returned instance
        cannot be restored by ``__setstate__`` after pickling.

        :param config_dict: Dictionary of configuration options.
        :param provider_bundle: Provider bundle that takes precedence over ``config_dict``.
        """
        loader = StorageClientConfigLoader(config_dict, provider_bundle=provider_bundle)
        return loader.build_config()

    def __getstate__(self) -> Dict[str, Any]:
        """
        Returns the picklable state: the instance attributes without the providers and
        the cache manager, which ``__setstate__`` rebuilds from the configuration.
        """
        state = self.__dict__.copy()
        del state["credentials_provider"]
        del state["storage_provider"]
        del state["metadata_provider"]
        del state["cache_manager"]
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        """
        Restores the instance by rebuilding it from the pickled configuration dictionary.

        :param state: State produced by ``__getstate__``; must contain ``_config_dict``
            and ``profile``.
        """
        config_dict = state["_config_dict"]
        profile = state["profile"]

        # Restore the attributes that are not rebuilt below. Without them the restored
        # instance has no ``profile`` and cannot be pickled and restored again.
        self.profile = profile
        self._config_dict = config_dict

        loader = StorageClientConfigLoader(config_dict, profile)
        new_config = loader.build_config()
        self.storage_provider = new_config.storage_provider
        self.credentials_provider = new_config.credentials_provider
        self.metadata_provider = new_config.metadata_provider
        self.cache_config = new_config.cache_config
        self.retry_config = new_config.retry_config
        self.cache_manager = new_config.cache_manager