
Source code for earth2studio.models.auto

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import os
import re
import warnings
from typing import Any

import aiohttp
import fsspec
import s3fs
from fsspec.callbacks import Callback, TqdmCallback
from fsspec.compression import compr
from fsspec.core import BaseCache, split_protocol
from fsspec.implementations.cached import LocalTempFile, WholeFileCacheFileSystem
from fsspec.implementations.http import HTTPFileSystem
from fsspec.spec import AbstractBufferedFile, AbstractFileSystem
from fsspec.utils import infer_compression
from huggingface_hub import HfFileSystem
from loguru import logger
from tqdm import tqdm

# TODO: Make this package wide? Same as in run.py
logger.remove()
logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True)
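

# Illustrative sketch (not part of the original module): with the sink above, loguru
# messages are routed through tqdm.write, so log lines emitted while a progress bar is
# active are printed above the bar instead of corrupting it. The loop below is a
# hypothetical demonstration only.
def _example_logger_with_progress() -> None:
    for i in tqdm(range(3), desc="example"):
        logger.info(f"step {i}")  # printed cleanly above the active bar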


class CallbackWholeFileCacheFileSystem(WholeFileCacheFileSystem):
    """Extension of Fsspec WholeFileCacheFileSystem to include callback function when
    downloading files to cache (progress bar).

    See: https://github.com/fsspec/filesystem_spec/blob/8be9763e5f895073a9f46c8147aebbc64933e013/fsspec/implementations/cached.py#L651
    """

    def _open(self, path, mode="rb", **kwargs):  # type: ignore
        path = self._strip_protocol(path)
        if "r" not in mode:
            hash = self._mapper(path)
            fn = os.path.join(self.storage[-1], hash)
            user_specified_kwargs = {
                k: v
                for k, v in kwargs.items()
                # those kwargs were added by open(), we don't want them
                if k not in ["autocommit", "block_size", "cache_options"]
            }
            return LocalTempFile(self, path, mode=mode, fn=fn, **user_specified_kwargs)
        detail = self._check_file(path)
        if detail:
            detail, fn = detail
            _, blocks = detail["fn"], detail["blocks"]
            if blocks is True:
                f = open(fn, mode)
                f.original = detail.get("original")
                return f
            else:
                raise ValueError(
                    f"Attempt to open partially cached file {path}"
                    f" as a wholly cached file"
                )
        else:
            fn = self._make_local_details(path)
        kwargs["mode"] = mode

        # call target filesystems open
        self._mkcache()
        if self.compression:
            with self.fs._open(path, **kwargs) as f, open(fn, "wb") as f2:
                if isinstance(f, AbstractBufferedFile):
                    # want no type of caching if just downloading whole thing
                    f.cache = BaseCache(0, f.cache.fetcher, f.size)
                comp = (
                    infer_compression(path)
                    if self.compression == "infer"
                    else self.compression
                )
                f = compr[comp](f, mode="rb")
                data = True
                while data:
                    block = getattr(f, "blocksize", 5 * 2**20)
                    data = f.read(block)
                    f2.write(data)  # type: ignore
        else:
            if "callback" in kwargs:  # Patch here
                self.fs.get_file(path, fn, callback=kwargs["callback"])
            else:
                self.fs.get_file(path, fn)
        self.save_cache()
        return self._open(path, mode)  # type: ignore


class TqdmFormat(tqdm):
    """Provides a `total_time` format parameter. Not used.
    See: https://filesystem-spec.readthedocs.io/en/stable/api.html#fsspec.callbacks.TqdmCallback
    """

    @property
    def format_dict(self) -> dict:
        d = super().format_dict
        return d


class TqdmCallbackRelative(TqdmCallback):
    """Simple extention of Tqdm callback to support progress bar on recurrive gets"""

    def branched(self, path_1, path_2, **kwargs) -> Callback:  # type: ignore
        """Child callback for recursive get"""
        tqdm_kwargs = self._tqdm_kwargs
        tqdm_kwargs["unit"] = "B"
        tqdm_kwargs["unit_scale"] = True
        tqdm_kwargs["unit_divisor"] = 1024
        callback = TqdmCallback(
            tqdm_kwargs=tqdm_kwargs,
            tqdm_cls=self._tqdm_cls,
        )
        return callback
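
# Illustrative sketch (not part of the original module): the branched callback above
# lets a recursive fsspec get() show a per-file byte progress bar. The bucket path and
# local directory below are placeholders.
def _example_recursive_get_with_progress() -> None:
    fs = s3fs.S3FileSystem(anon=True)
    with TqdmCallbackRelative(
        tqdm_kwargs={"desc": "Downloading package files"}, tqdm_cls=TqdmFormat
    ) as callback:
        # fsspec calls callback.branched(...) for each child file transfer
        fs.get(
            "some-public-bucket/model/",
            "/tmp/model/",
            recursive=True,
            callback=callback,
        )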


class Package:
    """A generic file system abstraction with local caching, uses Fsspec
    WholeFileCacheFileSystem to manage files. Designed to be used for accessing remote
    resources, particularly checkpoint files for pre-trained models. Presently supports
    public folders on NGC, huggingface repos, s3 and any other built-in file system
    Fsspec supports.

    Parameters
    ----------
    root : str
        Root directory for file system
    fs : AbstractFileSystem | None, optional
        The target filesystem to run underneath. If none is provided a fsspec
        filesystem will get initialized based on the protocol of the root url,
        by default None
    cache : bool, optional
        Toggle local caching, typically you want this to be true unless the package
        is a local file system, by default True
    cache_options : dict, optional
        Caching options provided to Fsspec. See CachingFileSystem in fsspec for valid
        options
        https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.implementations.cached.CachingFileSystem,
        by default {}
    """

    def __init__(
        self,
        root: str,
        fs: AbstractFileSystem | None = None,
        cache: bool = True,
        cache_options: dict = {},
    ):
        self.cache_options = cache_options.copy()
        if "cache_storage" not in self.cache_options:
            self.cache_options["cache_storage"] = self.default_cache()
        if "expiry_time" not in self.cache_options:
            self.cache_options["expiry_time"] = 31622400  # 1 year

        self.root = root
        if fs is not None:
            self.fs = fs
        elif root.startswith("ngc://models/"):
            # Taken from Modulus file utils
            # Strip ngc model url prefix
            suffix = "ngc://models/"
            # The regex check
            pattern = re.compile(rf"{suffix}[\w-]+(/[\w-]+)?/[\w-]+@[A-Za-z0-9.]+")
            if not pattern.match(root):
                raise ValueError(
                    "Invalid URL, should be of form ngc://models/<org_id/team_id/model_id>@<version>\n"
                    + f" got {root}"
                )
            root = root.replace(suffix, "")
            if len(root.split("@")[0].split("/")) == 3:
                (org, team, model_version) = root.split("/", 2)
                (model, version) = model_version.split("@", 1)
            else:
                (org, model_version) = root.split("/", 1)
                (model, version) = model_version.split("@", 1)
                team = None
            if team:
                self.root = f"https://api.ngc.nvidia.com/v2/models/{org}/{team}/{model}/versions/{version}/files/"
            else:
                self.root = f"https://api.ngc.nvidia.com/v2/models/{org}/{model}/versions/{version}/files/"
            self.fs = HTTPFileSystem(
                block_size=Package.default_blocksize(),
                client_kwargs={
                    "timeout": aiohttp.ClientTimeout(total=Package.default_timeout())
                },
            )
        elif root.startswith("hf://"):
            # https://github.com/huggingface/huggingface_hub/blob/v0.23.4/src/huggingface_hub/hf_file_system.py#L816
            if "HF_HUB_DOWNLOAD_TIMEOUT" not in os.environ:
                os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = str(Package.default_timeout())
            self.fs = HfFileSystem(
                target_options={"default_block_size": Package.default_blocksize()}
            )
        elif root.startswith("s3://"):
            self.fs = s3fs.S3FileSystem(
                anon=True,
                client_kwargs={},
                default_block_size=Package.default_blocksize(),
            )
            self.fs.read_timeout = Package.default_timeout()
        else:
            protocol = split_protocol(root)[0]
            self.fs = fsspec.filesystem(protocol)

        if cache:
            self.fs = CallbackWholeFileCacheFileSystem(fs=self.fs, **self.cache_options)

    @classmethod
    def default_cache(cls, path: str = "") -> str:
        """Default cache location for packages located in `~/.cache/earth2studio`

        Parameters
        ----------
        path : str, optional
            Sub-path in cache directory, by default ""

        Returns
        -------
        str
            Local cache path
        """
        default_cache = os.path.join(os.path.expanduser("~"), ".cache", "earth2studio")
        default_cache = os.environ.get("EARTH2STUDIO_CACHE", default_cache)
        return os.path.join(default_cache, path)

    @classmethod
    def default_timeout(cls) -> int:
        """Default remote store timeout in seconds

        Returns
        -------
        int
            Time out in seconds
        """
        default_timeout = 300
        try:
            timeout = os.environ.get("EARTH2STUDIO_PACKAGE_TIMEOUT", default_timeout)
            default_timeout = int(timeout)
        except ValueError:
            pass
        return default_timeout

    @classmethod
    def default_blocksize(cls) -> int:
        """Default remote store block size

        Returns
        -------
        int
            Download block size in bytes
        """
        return 2**20

    @property
    def cache(self) -> str:
        """Cache path"""
        return self.cache_options["cache_storage"]

    def open(self, file_path: str) -> io.BufferedReader:
        """Open file inside package, caching it to local cache store in the process

        Parameters
        ----------
        file_path : str
            Local file path in package directory

        Returns
        -------
        io.BufferedReader
            Opened file, can get file path with BufferedReader.name
        """
        full_path = os.path.join(self.root, file_path)
        filename = os.path.basename(full_path)
        with TqdmCallbackRelative(
            tqdm_kwargs={
                "desc": "Earth2Studio Package Download",
                "bar_format": f"Downloading {filename}: "
                + "{percentage:.0f}%|{bar}{r_bar}",
                "unit": "B",
                "unit_scale": True,
                "unit_divisor": 1024,
            },
            tqdm_cls=TqdmFormat,
        ) as callback:
            return self.fs.open(full_path, callback=callback)

    def resolve(self, file_path: str) -> str:
        """Resolves current relative file path to absolute path inside Package cache

        Parameters
        ----------
        file_path : str
            Local path of file in package directory

        Returns
        -------
        str
            File path inside cache
        """
        # WARNING: THIS CAN FAIL IF FILE DOES NOT HAVE NAME ATTRIB. BUFFERED FILE TYPES
        # ARE NOT SUPPORTED HERE. NEED TO LOOK INTO THIS MORE.
        local_file_path = ""
        with self.open(file_path) as file:
            local_file_path = file.name
        return local_file_path

    def get(self, file_path: str) -> str:
        """Modulus / backwards compatibility

        Parameters
        ----------
        file_path : str
            Local path of file in package directory

        Returns
        -------
        str
            File path inside cache
        """
        warnings.warn(
            "Package.get(path) deprecated. Use Package.resolve(path) instead."
        )
        return self.resolve(file_path)
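

# Illustrative sketch (not part of the original module): typical Package usage. The
# root and file name below are placeholders; the protocol prefix selects the underlying
# filesystem, and the EARTH2STUDIO_CACHE environment variable can redirect where cached
# files land.
def _example_package_usage() -> None:
    os.environ["EARTH2STUDIO_CACHE"] = "/tmp/earth2studio-example-cache"
    # NGC registry (org/[team/]model@version), Hugging Face, S3 and local roots are
    # all accepted; a local directory usually wants cache=False
    package = Package("hf://SomeOrg/some-model-repo")
    # resolve() downloads the file into the local cache and returns its cached path
    checkpoint_path = package.resolve("weights.ckpt")
    logger.info(f"Cached checkpoint at {checkpoint_path}")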


class AutoModelMixin:
    """Abstract class that defines the utils needed for auto loading / instantiating
    models"""

    @classmethod
    def load_default_package(cls) -> Package:
        """Loads the default model package

        Returns
        -------
        Package
            Model package, file system, object
        """
        raise NotImplementedError("No default package supported")

    @classmethod
    def load_model(
        cls,
        package: Package,
    ) -> Any:  # TODO: Fix types here
        """Instantiates and loads default model object from provided model package

        Parameters
        ----------
        package : Package
            Model package, file system, to fetch assets
        """
        raise NotImplementedError("Load model function not implemented")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str | None = None) -> Any:
        """Loads and instantiates a pre-trained Earth2Studio model

        Parameters
        ----------
        pretrained_model_name_or_path : str, optional
            Path to model package (file system). If none is provided, the built-in
            package will be used, by default None. Valid inputs include:

            - A path to a directory containing model weights saved e.g.,
              ./my_model_directory/.
            - A path or url/uri to a remote file system supported by Fsspec
            - A s3 uri supported by s3fs
            - A NGC model registry uri

        Returns
        -------
        Union[PrognosticModel, Diagnostic]
            Instantiated model with loaded checkpoint from loaded model package
        """
        if pretrained_model_name_or_path is None:
            package = cls.load_default_package()
        else:
            package = Package(pretrained_model_name_or_path)

        return cls.load_model(package)
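

# Illustrative sketch (not part of the original module): how a model class is expected
# to plug into AutoModelMixin. The package root, file name, and model class below are
# hypothetical; real models override these two hooks in the same way.
class _ExampleModel(AutoModelMixin):
    def __init__(self, checkpoint_path: str):
        self.checkpoint_path = checkpoint_path

    @classmethod
    def load_default_package(cls) -> Package:
        # Hypothetical default package location
        return Package("hf://SomeOrg/example-model")

    @classmethod
    def load_model(cls, package: Package) -> "_ExampleModel":
        # Fetch the (hypothetical) checkpoint into the local cache and build the model
        return cls(package.resolve("example.ckpt"))


# Users would then load the model through the mixin's classmethod:
#   model = _ExampleModel.from_pretrained()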