Source code for earth2studio.models.auto.package

# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import os
import re
import warnings

import aiohttp
import fsspec
import s3fs
from fsspec.callbacks import Callback, TqdmCallback
from fsspec.compression import compr
from fsspec.core import BaseCache, split_protocol
from fsspec.implementations.cached import LocalTempFile, WholeFileCacheFileSystem
from fsspec.spec import AbstractBufferedFile, AbstractFileSystem
from fsspec.utils import infer_compression
from huggingface_hub import HfFileSystem
from loguru import logger
from tqdm import tqdm

from earth2studio.models.auto.ngc import NGCModelFileSystem

# TODO: Make this package wide? Same as in run.py
# Route loguru output through tqdm. The order matters: remove() is called without
# a handler id, which drops every sink registered so far, and only then is the
# single replacement sink added. The sink forwards each message to tqdm.write with
# end="" so no extra line ending is appended to the message.
# NOTE(review): presumably this keeps log lines from garbling active progress
# bars -- confirm against the equivalent setup in run.py.
logger.remove()
logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True)


class CallbackWholeFileCacheFileSystem(WholeFileCacheFileSystem):
    """Fsspec WholeFileCacheFileSystem variant that forwards a download callback.

    The only difference from the upstream `_open` is that a `callback` keyword, when
    present, is handed to the target filesystem's `get_file` while the file is copied
    into the local cache (used for progress bars).

    See: https://github.com/fsspec/filesystem_spec/blob/8be9763e5f895073a9f46c8147aebbc64933e013/fsspec/implementations/cached.py#L651
    """

    def _open(self, path, mode="rb", **kwargs):  # type: ignore
        """Open `path`, populating the local cache first when reading.

        Parameters
        ----------
        path : str
            Path on the target filesystem, protocol is stripped before use
        mode : str, optional
            File mode, any mode without "r" yields a LocalTempFile, by default "rb"
        **kwargs
            Extra open arguments, an optional `callback` entry is passed to the target
            filesystem's `get_file` on an uncompressed download

        Returns
        -------
        LocalTempFile or file object
            A LocalTempFile for non-read modes, otherwise the locally cached file
            opened with the builtin `open`

        Raises
        ------
        ValueError
            If the cache metadata marks the file as only partially cached
        """
        path = self._strip_protocol(path)

        # Non-read modes never touch the remote, they write into a local temp file
        if "r" not in mode:
            cache_key = self._mapper(path)
            local_path = os.path.join(self.storage[-1], cache_key)
            # those kwargs were added by open(), we don't want them
            injected = ["autocommit", "block_size", "cache_options"]
            passthrough = {
                key: value for key, value in kwargs.items() if key not in injected
            }
            return LocalTempFile(
                self, path, mode=mode, fn=local_path, **passthrough
            )

        cached = self._check_file(path)
        if cached:
            meta, local_path = cached
            _, whole_file = meta["fn"], meta["blocks"]
            if whole_file is not True:
                raise ValueError(
                    f"Attempt to open partially cached file {path}"
                    f" as a wholly cached file"
                )
            handle = open(local_path, mode)
            handle.original = meta.get("original")
            return handle

        # Cache miss: register the file in the cache metadata, then download it
        local_path = self._make_local_details(path)
        kwargs["mode"] = mode

        # call target filesystems open
        self._mkcache()
        if not self.compression:
            # Patch here: forward the callback only when the caller supplied one
            forwarded = (
                {"callback": kwargs["callback"]} if "callback" in kwargs else {}
            )
            self.fs.get_file(path, local_path, **forwarded)
        else:
            with self.fs._open(path, **kwargs) as remote, open(
                local_path, "wb"
            ) as local:
                if isinstance(remote, AbstractBufferedFile):
                    # want no type of caching if just downloading whole thing
                    remote.cache = BaseCache(0, remote.cache.fetcher, remote.size)
                if self.compression == "infer":
                    codec = infer_compression(path)
                else:
                    codec = self.compression
                stream = compr[codec](remote, mode="rb")
                # Copy block by block, the final empty read is written as well and
                # then ends the loop
                while True:
                    block = getattr(stream, "blocksize", 5 * 2**20)
                    chunk = stream.read(block)
                    local.write(chunk)  # type: ignore
                    if not chunk:
                        break
        self.save_cache()
        # The file is cached now, so re-enter to take the cache hit path above
        return self._open(path, mode)  # type: ignore


class TqdmFormat(tqdm):
    """Tqdm subclass reserved for a custom `total_time` format parameter. Not used,
    the format dictionary is currently returned without any additions.
    See: https://filesystem-spec.readthedocs.io/en/stable/api.html#fsspec.callbacks.TqdmCallback
    """

    @property
    def format_dict(self) -> dict:
        """Format parameters of the parent tqdm class, passed through unchanged"""
        return super().format_dict


class TqdmCallbackRelative(TqdmCallback):
    """Simple extension of Tqdm callback to support progress bar on recursive gets"""

    def branched(self, path_1, path_2, **kwargs) -> Callback:  # type: ignore
        """Child callback for recursive get

        Parameters
        ----------
        path_1 : str
            Source path of the child transfer, not used
        path_2 : str
            Destination path of the child transfer, not used
        **kwargs
            Additional arguments of the transfer, not used

        Returns
        -------
        Callback
            New TqdmCallback using this callback's tqdm class and tqdm arguments,
            with the unit settings overridden to report bytes
        """
        # Build a new dict instead of editing self._tqdm_kwargs in place, otherwise
        # the byte unit settings leak into this (parent) callback's own arguments
        # and every child ends up sharing one mutable dict with the parent
        child_kwargs = {
            **self._tqdm_kwargs,
            "unit": "B",
            "unit_scale": True,
            "unit_divisor": 1024,
        }
        return TqdmCallback(
            tqdm_kwargs=child_kwargs,
            tqdm_cls=self._tqdm_cls,
        )


class Package:
    """A generic file system abstraction with local caching, uses Fsspec
    WholeFileCacheFileSystem to manage files. Designed to be used for accessing remote
    resources, particularly checkpoint files for pre-trained models. Presently supports
    public folders on NGC, huggingface repos, s3 and any other built in file system
    Fsspec supports.

    Parameters
    ----------
    root : str
        Root directory for file system
    fs : AbstractFileSystem | None, optional
        The target filesystem to run underneath. If none is provided a fsspec filesystem
        will get initialized based on the protocol of the root url, by default None
    cache : bool, optional
        Toggle local caching, typically you want this to be true unless the package is
        a local file system, by default True
    cache_options : dict | None, optional
        Caching options provided to Fsspec. See CachingFileSystem in fsspec for valid
        options
        https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.implementations.cached.CachingFileSystem,
        by default None which is treated as an empty dictionary

    Raises
    ------
    ValueError
        If root starts with ngc://models/ but does not match the expected NGC model
        url pattern
    """

    def __init__(
        self,
        root: str,
        fs: AbstractFileSystem | None = None,
        cache: bool = True,
        cache_options: dict | None = None,
    ):
        # Work on a copy so the caller's dictionary is never modified
        self.cache_options = cache_options.copy() if cache_options is not None else {}
        if "cache_storage" not in self.cache_options:
            self.cache_options["cache_storage"] = self.default_cache()
        if "expiry_time" not in self.cache_options:
            self.cache_options["expiry_time"] = 31622400  # 1 year

        self.root = root
        if fs is not None:
            self.fs = fs
        elif root.startswith("ngc://models/"):
            # Taken from PhysicsNeMo file utils
            prefix = "ngc://models/"
            # The regex check, only the start of the url has to match
            pattern = re.compile(rf"{prefix}[\w-]+(/[\w-]+)?/[\w-]+@[A-Za-z0-9.]+")
            if not pattern.match(root):
                raise ValueError(
                    "Invalid URL, should be of form ngc://models/<org_id/team_id/model_id>@<version>\n"
                    + f" got {root}"
                )
            self.fs = NGCModelFileSystem(  # type: ignore
                block_size=Package.default_blocksize(),
                client_kwargs={
                    "timeout": aiohttp.ClientTimeout(total=Package.default_timeout())
                },
            )
        elif root.startswith("hf://"):
            # https://github.com/huggingface/huggingface_hub/blob/v0.23.4/src/huggingface_hub/hf_file_system.py#L816
            # Only set the timeout when the user has not configured one already
            if "HF_HUB_DOWNLOAD_TIMEOUT" not in os.environ:
                os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = str(Package.default_timeout())
            self.fs = HfFileSystem(
                target_options={"default_block_size": Package.default_blocksize()}
            )
        elif root.startswith("s3://"):
            self.fs = s3fs.S3FileSystem(
                anon=True,
                client_kwargs={},
                default_block_size=Package.default_blocksize(),
            )
            self.fs.read_timeout = Package.default_timeout()
        else:
            protocol = split_protocol(root)[0]
            self.fs = fsspec.filesystem(protocol)

        # Caching wraps whichever filesystem was selected, including a user provided one
        if cache:
            self.fs = CallbackWholeFileCacheFileSystem(fs=self.fs, **self.cache_options)

    @classmethod
    def default_cache(cls, path: str = "") -> str:
        """Default cache location for packages located in `~/.cache/earth2studio`

        The base location can be overridden with the environment variable
        `EARTH2STUDIO_CACHE`.

        Parameters
        ----------
        path : str, optional
            Sub-path in cache directory, by default ""

        Returns
        -------
        str
            Local cache path
        """
        default_cache = os.path.join(os.path.expanduser("~"), ".cache", "earth2studio")
        default_cache = os.environ.get("EARTH2STUDIO_CACHE", default_cache)
        return os.path.join(default_cache, path)

    @classmethod
    def default_timeout(cls) -> int:
        """Default remote store timeout in seconds

        Read from the environment variable `EARTH2STUDIO_PACKAGE_TIMEOUT`, falling
        back to 300 when the variable is unset or is not a valid integer.

        Returns
        -------
        int
            Time out in seconds
        """
        default_timeout = 300
        timeout = os.environ.get("EARTH2STUDIO_PACKAGE_TIMEOUT")
        if timeout is None:
            return default_timeout
        try:
            return int(timeout)
        except ValueError:
            # Still fall back to the default, but tell the user their setting was
            # ignored rather than dropping it silently
            logger.warning(
                f"Invalid EARTH2STUDIO_PACKAGE_TIMEOUT value '{timeout}', "
                + f"using default of {default_timeout}s."
            )
            return default_timeout

    @classmethod
    def default_blocksize(cls) -> int:
        """Default remote store block size

        Returns
        -------
        int
            Download block size in bytes
        """
        return 2**20

    @property
    def cache(self) -> str:
        """Cache path"""
        return self.cache_options["cache_storage"]

    def open(self, file_path: str) -> io.BufferedReader:
        """Open file inside package, caching it to local cache store in the process

        Parameters
        ----------
        file_path : str
            Local file path in package directory

        Returns
        -------
        io.BufferedReader
            Opened file, can get file path with BufferedReader.name

        Raises
        ------
        fsspec.exceptions.FSTimeoutError
            If the underlying file system times out, re-raised after logging a hint
        """
        full_path = os.path.join(self.root, file_path)
        filename = os.path.basename(full_path)
        # bar_format is a str.format template, so braces in the file name must be
        # escaped or tqdm would treat them as format fields
        bar_name = filename.replace("{", "{{").replace("}", "}}")
        with TqdmCallbackRelative(
            tqdm_kwargs={
                "desc": "Earth2Studio Package Download",
                "bar_format": f"Downloading {bar_name}: "
                + "{percentage:.0f}%|{bar}{r_bar}",
                "unit": "B",
                "unit_scale": True,
                "unit_divisor": 1024,
            },
            tqdm_cls=TqdmFormat,
        ) as callback:
            try:
                return self.fs.open(full_path, callback=callback)
            except fsspec.exceptions.FSTimeoutError:
                logger.error(
                    f"Package fetch timeout. Consider increasing timeout through environment variable 'EARTH2STUDIO_PACKAGE_TIMEOUT'. Currently {self.default_timeout()}s."
                )
                # Bare raise keeps the original traceback intact
                raise

    def resolve(self, file_path: str) -> str:
        """Resolves current relative file path to absolute path inside Package cache

        Parameters
        ----------
        file_path : str
            local path of file in package directory

        Returns
        -------
        str
            File path inside cache
        """
        # WARNING: THIS CAN FAIL IF FILE DOES NOT HAVE NAME ATTRIB. BUFFERED FILE TYPES
        # ARE NOT SUPPORTED HERE. NEED TO LOOK INTO THIS MORE.
        with self.open(file_path) as file:
            local_file_path = file.name
        return local_file_path

    def get(self, file_path: str) -> str:
        """PhysicsNeMo / backwards compatibility

        Deprecated, use :py:meth:`resolve` instead.

        Parameters
        ----------
        file_path : str
            local path of file in package directory

        Returns
        -------
        str
            File path inside cache
        """
        # stacklevel=2 attributes the warning to the calling code, not this method
        warnings.warn(
            "Package.get(path) deprecated. Use Package.resolve(path) instead.",
            stacklevel=2,
        )
        return self.resolve(file_path)