Source code for multistorageclient.providers.huggingface

  1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  2# SPDX-License-Identifier: Apache-2.0
  3#
  4# Licensed under the Apache License, Version 2.0 (the "License");
  5# you may not use this file except in compliance with the License.
  6# You may obtain a copy of the License at
  7#
  8# http://www.apache.org/licenses/LICENSE-2.0
  9#
 10# Unless required by applicable law or agreed to in writing, software
 11# distributed under the License is distributed on an "AS IS" BASIS,
 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13# See the License for the specific language governing permissions and
 14# limitations under the License.
 15
 16import importlib.util
 17import io
 18import os
 19import tempfile
 20from collections.abc import Callable, Iterator
 21from typing import IO, Any, Optional, TypeVar, Union
 22
 23from huggingface_hub import CommitOperationCopy, HfApi
 24from huggingface_hub.errors import EntryNotFoundError, HfHubHTTPError, RepositoryNotFoundError, RevisionNotFoundError
 25from huggingface_hub.hf_api import RepoFile, RepoFolder
 26
 27from ..telemetry import Telemetry
 28from ..types import AWARE_DATETIME_MIN, Credentials, CredentialsProvider, ObjectMetadata, Range, RetryableError
 29from .base import BaseStorageProvider
 30
 31_T = TypeVar("_T")
 32
 33PROVIDER = "huggingface"
 34
 35HF_TRANSFER_UNAVAILABLE_ERROR_MESSAGE = (
 36    "Fast transfer using 'hf_transfer' is enabled (HF_HUB_ENABLE_HF_TRANSFER=1) "
 37    "but 'hf_transfer' package is not available in your environment. "
 38    "Either install hf_transfer with 'pip install hf_transfer' or "
 39    "disable it by setting HF_HUB_ENABLE_HF_TRANSFER=0"
 40)
 41
 42
class HuggingFaceCredentialsProvider(CredentialsProvider):
    """
    :py:class:`multistorageclient.types.CredentialsProvider` implementation that serves a static HuggingFace access token.
    """

    def __init__(self, access_token: str):
        """
        Store the HuggingFace access token used for authentication.

        :param access_token: The HuggingFace access token for authentication.
        """
        self.token = access_token

    def get_credentials(self) -> Credentials:
        """
        Wrap the stored token in a :py:class:`Credentials` object.

        The access and secret keys are left empty and no expiration is set.

        :return: The credentials used for HuggingFace authentication.
        """
        token = self.token
        return Credentials(access_key="", secret_key="", token=token, expiration=None)

    def refresh_credentials(self) -> None:
        """
        Intentionally does nothing.

        Note: HuggingFace tokens typically don't expire, so there is nothing to refresh.
        """
        return None
class HuggingFaceStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with HuggingFace Hub repositories.

    Every write (upload, copy, delete) is issued as its own commit on ``repo_revision``.
    """

    def __init__(
        self,
        repository_id: str,
        repo_type: str = "model",
        base_path: str = "",
        repo_revision: str = "main",
        credentials_provider: Optional[CredentialsProvider] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
    ):
        """
        Initializes the :py:class:`HuggingFaceStorageProvider` with repository information and optional credentials provider.

        :param repository_id: The HuggingFace repository ID (e.g., 'username/repo-name').
        :param repo_type: The type of repository ('dataset', 'model', 'space'). Defaults to 'model'.
        :param base_path: The root prefix path within the repository where all operations will be scoped.
        :param repo_revision: The git revision (branch, tag, or commit) to use. Defaults to 'main'.
        :param credentials_provider: The provider to retrieve HuggingFace credentials.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        :raises ValueError: If ``repo_type`` or ``repository_id`` is invalid, or hf_transfer is enabled but missing.
        """
        allowed_repo_types = {"dataset", "model", "space"}
        if repo_type not in allowed_repo_types:
            raise ValueError(f"Invalid repo_type '{repo_type}'. Must be one of: {allowed_repo_types}")

        if not repository_id or "/" not in repository_id:
            raise ValueError(f"Invalid repository_id '{repository_id}'. Expected format: 'username/repo-name'")

        self._validate_hf_transfer_availability()

        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        self._repository_id = repository_id
        self._repo_type = repo_type
        self._repo_revision = repo_revision
        self._credentials_provider = credentials_provider

        self._hf_client: HfApi = self._create_hf_api_client()

    def _create_hf_api_client(self) -> HfApi:
        """
        Creates the HuggingFace API client.

        Uses the token from the credentials provider when one is configured; otherwise the
        client is unauthenticated (suitable for public repositories).

        :return: Configured HfApi client instance.
        """
        token = None
        if self._credentials_provider:
            token = self._credentials_provider.get_credentials().token

        return HfApi(token=token)

    def _validate_hf_transfer_availability(self) -> None:
        """
        Validates that hf_transfer is importable if it is enabled via HF_HUB_ENABLE_HF_TRANSFER.

        :raises ValueError: If hf_transfer is enabled but not available.
        """
        hf_transfer_enabled = os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "").lower() in ("1", "on", "true", "yes")

        if hf_transfer_enabled and importlib.util.find_spec("hf_transfer") is None:
            raise ValueError(HF_TRANSFER_UNAVAILABLE_ERROR_MESSAGE)

    @staticmethod
    def _parse_header_int_params(header_value: str) -> dict[str, int]:
        """
        Parses ``name=<int>`` parameters out of a ``;``-separated header value.

        Segments without ``=`` (e.g. the quoted ``"api"`` token) or with a non-integer value are
        skipped. If a name repeats, the last valid value wins.

        :param header_value: Raw header value, e.g. ``"api";r=0;t=142``.
        :return: Mapping of parameter name to integer value.
        """
        params: dict[str, int] = {}
        for part in header_value.split(";"):
            name, sep, raw_value = part.strip().partition("=")
            if not sep:
                continue
            try:
                params[name] = int(raw_value)
            except ValueError:
                continue
        return params

    def _parse_rate_limit_headers(self, response) -> str:
        """
        Parses HuggingFace rate limit headers and returns formatted information.

        HuggingFace returns rate limit information in these headers:

        - RateLimit: "api";r=0;t=142

          - r = requests remaining in the current window
          - t = seconds until rate limit resets

        - RateLimit-Policy: "fixed window";"api";q=10000;w=300

          - q = total requests allowed per window
          - w = window size in seconds

        Reference: https://huggingface.co/docs/hub/rate-limits

        :param response: The HTTP response object containing rate limit headers (may be ``None``).
        :return: Formatted string with rate limit information, or empty string if headers not found.
        """
        try:
            headers = response.headers
        except Exception:
            # response may be None or lack headers entirely; rate-limit info is best-effort.
            return ""

        rate_limit_info: list[str] = []

        # Note: HTTP headers are case-insensitive, but we use the canonical casing from HF docs
        if "RateLimit" in headers:
            params = self._parse_header_int_params(headers["RateLimit"])
            if "r" in params:
                rate_limit_info.append(f"Requests remaining in current window: {params['r']}")
            if "t" in params:
                rate_limit_info.append(f"Rate limit resets in: {params['t']} seconds")

        if "RateLimit-Policy" in headers:
            params = self._parse_header_int_params(headers["RateLimit-Policy"])
            if "q" in params and "w" in params:
                window_minutes = params["w"] / 60
                rate_limit_info.append(
                    f"Rate limit policy: {params['q']} requests per {window_minutes:.0f}-minute window"
                )

        return " | ".join(rate_limit_info)

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        repo_id: str,
        path: str,
    ) -> _T:
        """
        Runs ``func`` and translates HuggingFace errors into standardized exceptions.

        Parses HuggingFace rate limit headers (RateLimit and RateLimit-Policy) to provide
        detailed information about rate limiting to users. See https://huggingface.co/docs/hub/rate-limits

        :param func: The function that performs the actual HuggingFace operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param repo_id: The HuggingFace repository ID.
        :param path: The path of the object within the repository.
        :return: The result of the HuggingFace operation.
        :raises RetryableError: For transient errors (HTTP 408, 409, 429, 500, 502, 503, 504, connection errors).
        :raises FileNotFoundError: When the repository, revision, or entry is not found (or HTTP 404).
        :raises RuntimeError: For other non-retryable errors.
        """
        try:
            return func()
        except RepositoryNotFoundError as error:
            raise FileNotFoundError(
                f"Repository not found or access denied: {repo_id}. "
                f"Verify the repository exists and you have access permissions."
            ) from error
        except RevisionNotFoundError as error:
            raise FileNotFoundError(
                f"Revision '{self._repo_revision}' not found in repository {repo_id}. "
                f"Verify the branch, tag, or commit exists."
            ) from error
        except EntryNotFoundError as error:
            raise FileNotFoundError(f"File not found in HuggingFace repository: {path}") from error
        except FileNotFoundError:
            raise
        except RetryableError:
            # Already classified (e.g. by a nested translation); don't demote it to RuntimeError.
            raise
        except HfHubHTTPError as error:
            # Don't use hasattr() - it's unreliable with response objects
            status_code = None
            response = None

            try:
                response = error.response
                if response is not None:
                    status_code = response.status_code
            except AttributeError:
                pass

            rate_limit_info = self._parse_rate_limit_headers(response)
            quota_suffix = f" | {rate_limit_info}" if rate_limit_info else ""

            error_info = f"repo_id: {repo_id}, path: {path}, status_code: {status_code}, error: {error}"

            if status_code == 404:
                raise FileNotFoundError(f"Object {repo_id}/{path} does not exist. {error_info}") from error
            elif status_code == 409:
                raise RetryableError(f"Conflict Error for {repo_id}. {error_info}{quota_suffix}") from error
            elif status_code == 429:
                base_message = f"Rate limit exceeded when {operation} object(s) at {repo_id}/{path}. {error_info}"
                raise RetryableError(f"{base_message}{quota_suffix}") from error
            elif status_code == 503:
                raise RetryableError(
                    f"Service unavailable when {operation} object(s) at {repo_id}/{path}. {error_info}{quota_suffix}"
                ) from error
            elif status_code in (408, 500, 502, 504):
                raise RetryableError(
                    f"Transient error ({status_code}) when {operation} object(s) at {repo_id}/{path}. {error_info}{quota_suffix}"
                ) from error
            else:
                raise RuntimeError(
                    f"HuggingFace API error during {operation} of {path}: {error}{quota_suffix}"
                ) from error
        except (ConnectionError, TimeoutError, OSError) as error:
            # NOTE(review): OSError also covers local filesystem failures (e.g. PermissionError while
            # writing a download target), which are classified as retryable here — confirm intended.
            raise RetryableError(
                f"Connection error when {operation} object(s) at {repo_id}/{path}, error type: {type(error).__name__}"
            ) from error
        except Exception as error:
            raise RuntimeError(f"Unexpected error during {operation} of {path}: {error}") from error

    def _upload_bytes_via_temp_file(self, data: bytes, path_in_repo: str) -> None:
        """
        Uploads in-memory data by staging it in a temporary local file.

        The temporary file is always removed, including when the local write itself fails.
        NOTE(review): upload is kept path-based (rather than passing bytes directly), presumably
        so huggingface_hub can use hf_transfer — confirm before switching to in-memory uploads.

        :param data: Content to upload.
        :param path_in_repo: Normalized destination path within the repository.
        """
        temp_file = tempfile.NamedTemporaryFile(delete=False)
        try:
            # Close before uploading so the file can be reopened by path on every platform.
            with temp_file:
                temp_file.write(data)

            self._hf_client.upload_file(
                path_or_fileobj=temp_file.name,
                path_in_repo=path_in_repo,
                repo_id=self._repository_id,
                repo_type=self._repo_type,
                revision=self._repo_revision,
                commit_message=f"Upload {path_in_repo}",
                commit_description=None,
                create_pr=False,
            )
        finally:
            os.unlink(temp_file.name)

    def _download_to_dir(self, path_in_repo: str, local_dir: str) -> str:
        """
        Downloads one repository file into ``local_dir``.

        :param path_in_repo: Normalized path of the file within the repository.
        :param local_dir: Local directory passed to ``hf_hub_download`` as ``local_dir``.
        :return: The local path returned by ``hf_hub_download``.
        """
        return self._hf_client.hf_hub_download(
            repo_id=self._repository_id,
            filename=path_in_repo,
            repo_type=self._repo_type,
            revision=self._repo_revision,
            local_dir=local_dir,
        )

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads an object to the HuggingFace repository as a single commit.

        :param path: The path where the object will be stored in the repository.
        :param body: The content of the object to store.
        :param if_match: Optional ETag for conditional uploads (not supported by HuggingFace).
        :param if_none_match: Optional ETag for conditional uploads (not supported by HuggingFace).
        :param attributes: Optional attributes for the object (not supported by HuggingFace).
        :return: Data size in bytes.
        :raises RuntimeError: If HuggingFace client is not initialized or API errors occur.
        :raises ValueError: If client attempts to create a directory.
        :raises ValueError: If conditional upload parameters or attributes are provided (not supported).
        """
        if not self._hf_client:
            raise RuntimeError("HuggingFace client not initialized")

        if if_match is not None or if_none_match is not None:
            raise ValueError(
                "HuggingFace provider does not support conditional uploads. "
                "if_match and if_none_match parameters are not supported."
            )

        if attributes is not None:
            raise ValueError(
                "HuggingFace provider does not support custom object attributes. "
                "Use commit messages or repository metadata instead."
            )

        if path.endswith("/"):
            raise ValueError(
                "HuggingFace Storage Provider does not support explicit directory creation. "
                "Directories are created implicitly when files are uploaded to paths within them."
            )

        path = self._normalize_path(path)

        def _invoke_api() -> int:
            self._upload_bytes_via_temp_file(body, path)
            return len(body)

        return self._translate_errors(_invoke_api, "PUT", self._repository_id, path)

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        """
        Retrieves an object from the HuggingFace repository.

        :param path: The path of the object to retrieve from the repository.
        :param byte_range: Optional byte range for partial content (not supported by HuggingFace).
        :return: The content of the retrieved object.
        :raises RuntimeError: If HuggingFace client is not initialized or API errors occur.
        :raises ValueError: If a byte range is requested (HuggingFace doesn't support range reads).
        :raises FileNotFoundError: If the file doesn't exist in the repository.
        """
        if not self._hf_client:
            raise RuntimeError("HuggingFace client not initialized")

        if byte_range is not None:
            raise ValueError(
                "HuggingFace provider does not support partial range reads. "
                f"Requested range: offset={byte_range.offset}, size={byte_range.size}. "
                "To read the entire file, call get_object() without the byte_range parameter."
            )

        path = self._normalize_path(path)

        def _invoke_api() -> bytes:
            # Download into a throwaway directory and read the bytes back into memory.
            with tempfile.TemporaryDirectory() as temp_dir:
                downloaded_path = self._download_to_dir(path, temp_dir)
                with open(downloaded_path, "rb") as downloaded:
                    return downloaded.read()

        return self._translate_errors(_invoke_api, "GET", self._repository_id, path)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        """
        Copies an object within the HuggingFace repository using a copy commit.

        .. note::
            Copy behavior is size-dependent: files ≥10MB are copied remotely via
            metadata (LFS), while files <10MB are downloaded and re-uploaded.

        :param src_path: The source path of the object to copy.
        :param dest_path: The destination path for the copied object.
        :return: Data size in bytes (the source object's content length).
        :raises RuntimeError: If HuggingFace client is not initialized or API errors occur.
        :raises FileNotFoundError: If the source file doesn't exist.
        """
        if not self._hf_client:
            raise RuntimeError("HuggingFace client not initialized")

        src_path = self._normalize_path(src_path)
        dest_path = self._normalize_path(dest_path)

        # Resolve the source first so a missing source fails before any commit is attempted.
        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            operations = [
                CommitOperationCopy(
                    src_path_in_repo=src_path,
                    path_in_repo=dest_path,
                )
            ]

            self._hf_client.create_commit(
                repo_id=self._repository_id,
                operations=operations,
                commit_message=f"Copy {src_path} to {dest_path}",
                repo_type=self._repo_type,
                revision=self._repo_revision,
            )

            return src_object.content_length

        return self._translate_errors(_invoke_api, "COPY", self._repository_id, f"{src_path} to {dest_path}")

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        """
        Deletes an object from the HuggingFace repository as a single commit.

        :param path: The path of the object to delete from the repository.
        :param if_match: Optional ETag for conditional deletion (not supported by HuggingFace).
        :raises RuntimeError: If HuggingFace client is not initialized or API errors occur.
        :raises ValueError: If conditional deletion parameters are provided (not supported).
        :raises FileNotFoundError: If the file doesn't exist in the repository.
        """
        if not self._hf_client:
            raise RuntimeError("HuggingFace client not initialized")

        if if_match is not None:
            raise ValueError(
                "HuggingFace provider does not support conditional deletion. if_match parameter is not supported."
            )

        path = self._normalize_path(path)

        def _invoke_api() -> None:
            self._hf_client.delete_file(
                path_in_repo=path,
                repo_id=self._repository_id,
                repo_type=self._repo_type,
                revision=self._repo_revision,
                commit_message=f"Delete {path}",
            )

        self._translate_errors(_invoke_api, "DELETE", self._repository_id, path)

    def _item_to_metadata(self, item: Union[RepoFile, RepoFolder]) -> ObjectMetadata:
        """
        Convert a RepoFile or RepoFolder into ObjectMetadata.

        Files use ``blob_id`` as the ETag; folders use ``tree_id`` and report a size of 0.
        ``last_modified`` is not populated from the Hub; ``AWARE_DATETIME_MIN`` is a placeholder.

        :param item: The RepoFile or RepoFolder item from HuggingFace API.
        :return: ObjectMetadata representing the item.
        """
        last_modified = AWARE_DATETIME_MIN

        if isinstance(item, RepoFile):
            return ObjectMetadata(
                key=item.path,
                type="file",
                content_length=item.size,
                last_modified=last_modified,
                etag=item.blob_id,
                content_type=None,
                storage_class=None,
                metadata=None,
            )

        return ObjectMetadata(
            key=item.path,
            type="directory",
            content_length=0,
            last_modified=last_modified,
            etag=item.tree_id,
            content_type=None,
            storage_class=None,
            metadata=None,
        )

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        """
        Retrieves metadata for an object in the HuggingFace repository.

        :param path: The path of the object to get metadata for.
        :param strict: When True and no entry exists at ``path``, additionally check whether
            ``path`` is a directory and, if so, return directory metadata (key with a trailing
            ``/``). When False, a missing entry raises immediately.
        :return: Metadata about the object.
        :raises RuntimeError: If HuggingFace client is not initialized or API errors occur.
        :raises FileNotFoundError: If nothing exists at ``path`` (and, with strict=True, it is not a directory).
        """
        if not self._hf_client:
            raise RuntimeError("HuggingFace client not initialized")

        path = self._normalize_path(path)

        def _invoke_api() -> ObjectMetadata:
            items = self._hf_client.get_paths_info(
                repo_id=self._repository_id,
                paths=[path],
                repo_type=self._repo_type,
                revision=self._repo_revision,
                expand=True,
            )

            if not items:
                raise FileNotFoundError(f"File not found in HuggingFace repository: {path}")

            return self._item_to_metadata(items[0])

        try:
            return self._translate_errors(_invoke_api, "HEAD", self._repository_id, path)
        except FileNotFoundError as error:
            if strict:
                dir_path = path.rstrip("/") + "/"
                if self._is_dir(dir_path):
                    return ObjectMetadata(
                        key=dir_path,
                        type="directory",
                        content_length=0,
                        last_modified=AWARE_DATETIME_MIN,
                        etag=None,
                        content_type=None,
                        storage_class=None,
                        metadata=None,
                    )
            raise error

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        """
        Lists objects in the HuggingFace repository under the specified path.

        :param path: The path to list objects under. If it names a single file, only that file is listed.
        :param start_after: Exclusive lower bound on keys. An object with this key doesn't have to exist.
        :param end_at: Inclusive upper bound on keys. An object with this key doesn't have to exist.
        :param include_directories: Whether to include directories in the listing. When True, only
            the immediate children of ``path`` are listed; otherwise all files below it are listed recursively.
        :param follow_symlinks: Accepted for interface compatibility; not used by this provider.
        :return: An iterator over object metadata, in lexicographical key order.
        :raises RuntimeError: If HuggingFace client is not initialized or API errors occur.

        .. note::
            HuggingFace Hub API does not natively support pagination parameters, and its tree
            listing is not in lexicographical order. This implementation fetches all items,
            sorts them by key, and applies ``start_after``/``end_at`` as key bounds, which may
            impact performance for large repositories.
        """
        if not self._hf_client:
            raise RuntimeError("HuggingFace client not initialized")

        path = self._normalize_path(path)
        dir_path = path.rstrip("/")

        def _within_bounds(key: str) -> bool:
            # start_after is exclusive and end_at inclusive; neither needs to be an existing key.
            if start_after is not None and key <= start_after:
                return False
            return end_at is None or key <= end_at

        # The repository root can never be a file, so only probe non-empty paths.
        if dir_path:
            try:
                metadata = self._get_object_metadata(dir_path, strict=False)
            except FileNotFoundError:
                metadata = None
            if metadata is not None and metadata.type == "file":
                if _within_bounds(metadata.key):
                    yield metadata
                return

        def _invoke_api():
            repo_items = self._hf_client.list_repo_tree(
                repo_id=self._repository_id,
                path_in_repo=dir_path + "/" if dir_path else None,
                repo_type=self._repo_type,
                revision=self._repo_revision,
                expand=True,
                recursive=not include_directories,
            )

            return list(repo_items)

        try:
            items = self._translate_errors(_invoke_api, "LIST", self._repository_id, path)
        except FileNotFoundError:
            # Directory doesn't exist - return empty (matches POSIX behavior)
            return

        selected = [
            self._item_to_metadata(item)
            for item in items
            if isinstance(item, RepoFile) or (include_directories and isinstance(item, RepoFolder))
        ]
        selected.sort(key=lambda entry: entry.key)

        for metadata in selected:
            if end_at is not None and metadata.key > end_at:
                break
            if _within_bounds(metadata.key):
                yield metadata

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        """
        Uploads a file to the HuggingFace repository as a single commit.

        :param remote_path: The remote path where the file will be stored in the repository.
        :param f: File path or file object to upload. Text read from a file object is encoded as UTF-8.
        :param attributes: Optional attributes for the file (not supported by HuggingFace).
        :return: Data size in bytes.
        :raises RuntimeError: If HuggingFace client is not initialized or API errors occur.
        :raises ValueError: If client attempts to create a directory.
        :raises ValueError: If custom attributes are provided (not supported).
        """
        if not self._hf_client:
            raise RuntimeError("HuggingFace client not initialized")

        if attributes is not None:
            raise ValueError(
                "HuggingFace provider does not support custom file attributes. "
                "Use commit messages or repository metadata instead."
            )

        if remote_path.endswith("/"):
            raise ValueError(
                "HuggingFace Storage Provider does not support explicit directory creation. "
                "Directories are created implicitly when files are uploaded to paths within them."
            )

        remote_path = self._normalize_path(remote_path)

        def _invoke_api() -> int:
            if isinstance(f, str):
                file_size = os.path.getsize(f)

                self._hf_client.upload_file(
                    path_or_fileobj=f,
                    path_in_repo=remote_path,
                    repo_id=self._repository_id,
                    repo_type=self._repo_type,
                    revision=self._repo_revision,
                    commit_message=f"Upload {remote_path}",
                    commit_description=None,
                    create_pr=False,
                )

                return file_size

            content = f.read()
            content_bytes = content.encode("utf-8") if isinstance(content, str) else content
            self._upload_bytes_via_temp_file(content_bytes, remote_path)
            return len(content_bytes)

        return self._translate_errors(_invoke_api, "PUT", self._repository_id, remote_path)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        """
        Downloads a file from the HuggingFace repository.

        For a path target, the file is first downloaded into a temporary directory created next to
        the destination and then moved into place with ``os.replace``. No other files in the
        destination directory are touched.

        :param remote_path: The remote path of the file to download from the repository.
        :param f: Local file path or file object to write to. Text file objects receive UTF-8 decoded content.
        :param metadata: Optional object metadata (not used in this implementation).
        :return: Data size in bytes.
        :raises RuntimeError: If HuggingFace client is not initialized or API errors occur.
        :raises FileNotFoundError: If the file doesn't exist in the repository.
        """
        if not self._hf_client:
            raise RuntimeError("HuggingFace client not initialized")

        remote_path = self._normalize_path(remote_path)

        def _invoke_api() -> int:
            if isinstance(f, str):
                parent_dir = os.path.dirname(f)
                if parent_dir:
                    os.makedirs(parent_dir, exist_ok=True)

                # Stage beside the destination (same filesystem) so the final move is a single rename.
                # hf_hub_download recreates the repo sub-path under local_dir, so it must not be the
                # destination's own directory or it could overwrite unrelated local files.
                with tempfile.TemporaryDirectory(dir=parent_dir or ".", prefix=".hf-download-") as staging_dir:
                    downloaded_path = self._download_to_dir(remote_path, staging_dir)
                    os.replace(downloaded_path, f)

                return os.path.getsize(f)

            with tempfile.TemporaryDirectory() as temp_dir:
                downloaded_path = self._download_to_dir(remote_path, temp_dir)

                with open(downloaded_path, "rb") as src:
                    data = src.read()

            if isinstance(f, io.TextIOBase):
                f.write(data.decode("utf-8"))
            else:
                f.write(data)

            return len(data)

        return self._translate_errors(_invoke_api, "GET", self._repository_id, remote_path)

    def _is_dir(self, path: str) -> bool:
        """
        Checks whether ``path`` refers to a folder in the repository.

        :param path: The path to check (trailing slashes are ignored).
        :return: True if the path is the repository root or a folder; False if it is a file or missing.
        :raises FileNotFoundError: If the repository or revision cannot be found.
        :raises RetryableError: For transient API failures.
        :raises RuntimeError: For other API failures.
        """
        path = path.rstrip("/")
        if not path:
            # The root of the repo is always a directory
            return True

        def _invoke_api():
            try:
                return self._hf_client.get_paths_info(
                    repo_id=self._repository_id,
                    paths=[path],
                    repo_type=self._repo_type,
                    revision=self._repo_revision,
                )
            except EntryNotFoundError:
                # A missing entry simply means "not a directory".
                return []

        path_info = self._translate_errors(_invoke_api, "HEAD", self._repository_id, path)
        return bool(path_info) and isinstance(path_info[0], RepoFolder)

    def _normalize_path(self, path: str) -> str:
        """
        Normalize path for HuggingFace API by removing leading slashes.

        HuggingFace expects relative paths within repositories.
        """
        return path.lstrip("/")