1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2# SPDX-License-Identifier: Apache-2.0
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16import importlib.util
17import io
18import os
19import tempfile
20from collections.abc import Callable, Iterator
21from typing import IO, Any, Optional, Union
22
23from huggingface_hub import CommitOperationCopy, HfApi
24from huggingface_hub.errors import EntryNotFoundError, HfHubHTTPError, RepositoryNotFoundError, RevisionNotFoundError
25from huggingface_hub.hf_api import RepoFile, RepoFolder
26
27from ..telemetry import Telemetry
28from ..types import AWARE_DATETIME_MIN, Credentials, CredentialsProvider, ObjectMetadata, Range
29from .base import BaseStorageProvider
30
# Provider name reported to the base class / telemetry for every HuggingFace operation.
PROVIDER = "huggingface"

# Raised by HuggingFaceStorageProvider when HF_HUB_ENABLE_HF_TRANSFER is truthy but the
# optional 'hf_transfer' package cannot be imported.
HF_TRANSFER_UNAVAILABLE_ERROR_MESSAGE = (
    "Fast transfer using 'hf_transfer' is enabled (HF_HUB_ENABLE_HF_TRANSFER=1) "
    "but 'hf_transfer' package is not available in your environment. "
    "Either install hf_transfer with 'pip install hf_transfer' or "
    "disable it by setting HF_HUB_ENABLE_HF_TRANSFER=0"
)
39
40
class HuggingFaceCredentialsProvider(CredentialsProvider):
    """
    A :py:class:`multistorageclient.types.CredentialsProvider` that hands out a fixed HuggingFace access token.
    """

    def __init__(self, access_token: str):
        """
        Remembers the HuggingFace access token to hand out later.

        :param access_token: The HuggingFace access token for authentication.
        """
        self.token = access_token

    def get_credentials(self) -> Credentials:
        """
        Wraps the stored token in a :py:class:`multistorageclient.types.Credentials` object.

        Only ``token`` carries information; the access/secret key fields are empty strings
        and no expiration is attached.

        :return: Credentials holding the HuggingFace access token.
        """
        credentials = Credentials(
            token=self.token,
            access_key="",
            secret_key="",
            expiration=None,
        )
        return credentials

    def refresh_credentials(self) -> None:
        """
        Does nothing: this provider only holds the token it was constructed with and has nothing to refresh.
        """
        return None
74
75
class HuggingFaceStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with HuggingFace Hub repositories.

    Every mutating operation (upload, copy, delete) becomes its own commit on the configured revision.
    Conditional requests, custom attributes, byte-range reads and paginated listing are rejected with ``ValueError``.
    """

    def __init__(
        self,
        repository_id: str,
        repo_type: str = "model",
        base_path: str = "",
        repo_revision: str = "main",
        credentials_provider: Optional[CredentialsProvider] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
    ):
        """
        Initializes the :py:class:`HuggingFaceStorageProvider` with repository information and optional credentials provider.

        :param repository_id: The HuggingFace repository ID (e.g., 'username/repo-name').
        :param repo_type: The type of repository ('dataset', 'model', 'space'). Defaults to 'model'.
        :param base_path: The root prefix path within the repository where all operations will be scoped.
        :param repo_revision: The git revision (branch, tag, or commit) to use. Defaults to 'main'.
        :param credentials_provider: The provider to retrieve HuggingFace credentials.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        :raises ValueError: If ``repo_type`` or ``repository_id`` is invalid, or hf_transfer is enabled but missing.
        """

        # Validate repo_type
        allowed_repo_types = {"dataset", "model", "space"}
        if repo_type not in allowed_repo_types:
            raise ValueError(f"Invalid repo_type '{repo_type}'. Must be one of: {allowed_repo_types}")

        # Validate repository_id format
        if not repository_id or "/" not in repository_id:
            raise ValueError(f"Invalid repository_id '{repository_id}'. Expected format: 'username/repo-name'")

        # Fail at construction time rather than on the first transfer.
        self._validate_hf_transfer_availability()

        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        self._repository_id = repository_id
        self._repo_type = repo_type
        self._repo_revision = repo_revision
        self._credentials_provider = credentials_provider

        self._hf_client: HfApi = self._create_hf_api_client()

    def _create_hf_api_client(self) -> HfApi:
        """
        Creates and configures the HuggingFace API client.

        Initializes the HfApi client with authentication token if credentials are provided,
        otherwise creates an unauthenticated client for public repositories.

        :return: Configured HfApi client instance.
        """
        token = None
        if self._credentials_provider:
            creds = self._credentials_provider.get_credentials()
            token = creds.token

        return HfApi(token=token)

    def _validate_hf_transfer_availability(self) -> None:
        """
        Validates that hf_transfer is available if it's enabled via environment variables.

        :raises ValueError: If HF_HUB_ENABLE_HF_TRANSFER is truthy ("1", "on", "true", "yes",
            case-insensitive) but the ``hf_transfer`` package cannot be found.
        """
        hf_transfer_enabled = os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "").lower() in ("1", "on", "true", "yes")

        if hf_transfer_enabled and importlib.util.find_spec("hf_transfer") is None:
            raise ValueError(HF_TRANSFER_UNAVAILABLE_ERROR_MESSAGE)

    # ------------------------------------------------------------------ #
    # Transfer helpers (callers are responsible for error translation).   #
    # ------------------------------------------------------------------ #

    def _upload_local_file(self, local_path: str, path_in_repo: str) -> None:
        """
        Commits the local file ``local_path`` to ``path_in_repo`` on the configured revision.

        :param local_path: Path of the local file to upload.
        :param path_in_repo: Normalized destination path inside the repository.
        """
        self._hf_client.upload_file(
            path_or_fileobj=local_path,
            path_in_repo=path_in_repo,
            repo_id=self._repository_id,
            repo_type=self._repo_type,
            revision=self._repo_revision,
            commit_message=f"Upload {path_in_repo}",
            commit_description=None,
            create_pr=False,
        )

    def _upload_bytes(self, data: bytes, path_in_repo: str) -> None:
        """
        Commits in-memory ``data`` to ``path_in_repo`` by staging it in a temporary local file.

        The bytes are uploaded by path rather than as a generic IO object. The staging directory is
        removed by ``TemporaryDirectory`` even when the write or the upload fails, so no temp file leaks.

        :param data: Bytes to store.
        :param path_in_repo: Normalized destination path inside the repository.
        """
        with tempfile.TemporaryDirectory() as temp_dir:
            staged_path = os.path.join(temp_dir, "upload")
            with open(staged_path, "wb") as staged_file:
                staged_file.write(data)
            self._upload_local_file(staged_path, path_in_repo)

    def _download_to_dir(self, path_in_repo: str, local_dir: str) -> str:
        """
        Downloads ``path_in_repo`` from the configured revision into ``local_dir``.

        :param path_in_repo: Normalized path of the file inside the repository.
        :param local_dir: Local directory handed to ``hf_hub_download`` as ``local_dir``.
        :return: The local path returned by ``hf_hub_download``.
        """
        return self._hf_client.hf_hub_download(
            repo_id=self._repository_id,
            filename=path_in_repo,
            repo_type=self._repo_type,
            revision=self._repo_revision,
            local_dir=local_dir,
        )

    def _read_remote_bytes(self, path_in_repo: str) -> bytes:
        """
        Returns the full content of ``path_in_repo``, downloaded through a throwaway temporary directory.

        :param path_in_repo: Normalized path of the file inside the repository.
        :return: The file content.
        """
        with tempfile.TemporaryDirectory() as temp_dir:
            downloaded_path = self._download_to_dir(path_in_repo, temp_dir)
            with open(downloaded_path, "rb") as src:
                return src.read()

    # ------------------------------------------------------------------ #
    # StorageProvider operations                                          #
    # ------------------------------------------------------------------ #

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads an object to the HuggingFace repository.

        :param path: The path where the object will be stored in the repository.
        :param body: The content of the object to store.
        :param if_match: Optional ETag for conditional uploads (not supported by HuggingFace).
        :param if_none_match: Optional ETag for conditional uploads (not supported by HuggingFace).
        :param attributes: Optional attributes for the object (not supported by HuggingFace).
        :return: Data size in bytes.
        :raises RuntimeError: If HuggingFace client is not initialized or API errors occur.
        :raises ValueError: If client attempts to create a directory.
        :raises ValueError: If conditional upload parameters are provided (not supported).
        :raises FileNotFoundError: If the repository or revision does not exist.
        """
        if not self._hf_client:
            raise RuntimeError("HuggingFace client not initialized")

        if if_match is not None or if_none_match is not None:
            raise ValueError(
                "HuggingFace provider does not support conditional uploads. "
                "if_match and if_none_match parameters are not supported."
            )

        if attributes is not None:
            raise ValueError(
                "HuggingFace provider does not support custom object attributes. "
                "Use commit messages or repository metadata instead."
            )

        if path.endswith("/"):
            raise ValueError(
                "HuggingFace Storage Provider does not support explicit directory creation. "
                "Directories are created implicitly when files are uploaded to paths within them."
            )

        path = self._normalize_path(path)

        try:
            self._upload_bytes(body, path)
        except (RepositoryNotFoundError, RevisionNotFoundError) as e:
            raise FileNotFoundError(
                f"Repository or revision not found: {self._repository_id}@{self._repo_revision}"
            ) from e
        except HfHubHTTPError as e:
            raise RuntimeError(f"HuggingFace API error during upload of {path}: {e}") from e
        except Exception as e:
            raise RuntimeError(f"Unexpected error during upload of {path}: {e}") from e

        return len(body)

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        """
        Retrieves an object from the HuggingFace repository.

        :param path: The path of the object to retrieve from the repository.
        :param byte_range: Optional byte range for partial content (not supported by HuggingFace).
        :return: The content of the retrieved object.
        :raises RuntimeError: If HuggingFace client is not initialized or API errors occur.
        :raises ValueError: If a byte range is requested (HuggingFace doesn't support range reads).
        :raises FileNotFoundError: If the file, repository or revision doesn't exist.
        """
        if not self._hf_client:
            raise RuntimeError("HuggingFace client not initialized")

        if byte_range is not None:
            raise ValueError(
                "HuggingFace provider does not support partial range reads. "
                f"Requested range: offset={byte_range.offset}, size={byte_range.size}. "
                "To read the entire file, call get_object() without the byte_range parameter."
            )

        path = self._normalize_path(path)

        try:
            return self._read_remote_bytes(path)
        except (RepositoryNotFoundError, RevisionNotFoundError) as e:
            raise FileNotFoundError(
                f"Repository or revision not found: {self._repository_id}@{self._repo_revision}"
            ) from e
        except EntryNotFoundError as e:
            # Checked before HfHubHTTPError: depending on the huggingface_hub version the
            # "entry not found" error may subclass it, and a missing file must not become RuntimeError.
            raise FileNotFoundError(f"File not found in HuggingFace repository: {path}") from e
        except HfHubHTTPError as e:
            raise RuntimeError(f"HuggingFace API error during download of {path}: {e}") from e
        except Exception as e:
            raise RuntimeError(f"Unexpected error during download of {path}: {e}") from e

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        """
        Copies an object within the HuggingFace repository using server-side copy.

        .. note::
            Copy behavior is size-dependent: files ≥10MB are copied remotely via
            metadata (LFS), while files <10MB are downloaded and re-uploaded.

        :param src_path: The source path of the object to copy.
        :param dest_path: The destination path for the copied object.
        :return: Data size in bytes.
        :raises RuntimeError: If HuggingFace client is not initialized or API errors occur.
        :raises FileNotFoundError: If the source file, repository or revision doesn't exist.
        """
        if not self._hf_client:
            raise RuntimeError("HuggingFace client not initialized")

        src_path = self._normalize_path(src_path)
        dest_path = self._normalize_path(dest_path)

        # Resolve the source up front: a missing source raises FileNotFoundError here,
        # and its metadata supplies the size returned to the caller.
        src_object = self._get_object_metadata(src_path)

        try:
            self._hf_client.create_commit(
                repo_id=self._repository_id,
                operations=[CommitOperationCopy(src_path_in_repo=src_path, path_in_repo=dest_path)],
                commit_message=f"Copy {src_path} to {dest_path}",
                repo_type=self._repo_type,
                revision=self._repo_revision,
            )
        except (RepositoryNotFoundError, RevisionNotFoundError) as e:
            raise FileNotFoundError(
                f"Repository or revision not found: {self._repository_id}@{self._repo_revision}"
            ) from e
        except EntryNotFoundError as e:
            # The source can disappear between the metadata lookup and the commit.
            raise FileNotFoundError(f"File not found in HuggingFace repository: {src_path}") from e
        except HfHubHTTPError as e:
            raise RuntimeError(f"HuggingFace API error during copy from {src_path} to {dest_path}: {e}") from e
        except Exception as e:
            raise RuntimeError(f"Unexpected error during copy from {src_path} to {dest_path}: {e}") from e

        return src_object.content_length

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        """
        Deletes an object from the HuggingFace repository.

        :param path: The path of the object to delete from the repository.
        :param if_match: Optional ETag for conditional deletion (not supported by HuggingFace).
        :raises RuntimeError: If HuggingFace client is not initialized or API errors occur.
        :raises ValueError: If conditional deletion parameters are provided (not supported).
        :raises FileNotFoundError: If the file, repository or revision doesn't exist.
        """
        if not self._hf_client:
            raise RuntimeError("HuggingFace client not initialized")

        if if_match is not None:
            raise ValueError(
                "HuggingFace provider does not support conditional deletion. if_match parameter is not supported."
            )

        path = self._normalize_path(path)

        try:
            self._hf_client.delete_file(
                path_in_repo=path,
                repo_id=self._repository_id,
                repo_type=self._repo_type,
                revision=self._repo_revision,
                commit_message=f"Delete {path}",
            )
        except (RepositoryNotFoundError, RevisionNotFoundError) as e:
            raise FileNotFoundError(
                f"Repository or revision not found: {self._repository_id}@{self._repo_revision}"
            ) from e
        except EntryNotFoundError as e:
            # Honor the documented contract: deleting a missing file is FileNotFoundError, not RuntimeError.
            raise FileNotFoundError(f"File not found in HuggingFace repository: {path}") from e
        except HfHubHTTPError as e:
            raise RuntimeError(f"HuggingFace API error during deletion of {path}: {e}") from e
        except Exception as e:
            raise RuntimeError(f"Unexpected error during deletion of {path}: {e}") from e

    def _item_to_metadata(self, item: Union[RepoFile, RepoFolder]) -> ObjectMetadata:
        """
        Convert a RepoFile or RepoFolder into ObjectMetadata.

        Files use their ``blob_id`` as ETag and report their ``size``; folders use their ``tree_id``
        and report a size of 0. ``last_modified`` is always ``AWARE_DATETIME_MIN``.

        :param item: The RepoFile or RepoFolder item from HuggingFace API.
        :return: ObjectMetadata representing the item.
        """
        if isinstance(item, RepoFile):
            object_type, content_length, etag = "file", item.size, item.blob_id
        else:
            object_type, content_length, etag = "directory", 0, item.tree_id

        return ObjectMetadata(
            key=item.path,
            type=object_type,
            content_length=content_length,
            last_modified=AWARE_DATETIME_MIN,
            etag=etag,
            content_type=None,
            storage_class=None,
            metadata=None,
        )

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        """
        Retrieves metadata for an object in the HuggingFace repository.

        :param path: The path of the object to get metadata for.
        :param strict: When True and no entry matches ``path`` exactly, also check whether
            ``path`` names a directory (trailing-slash form) before giving up.
        :return: Metadata about the object.
        :raises RuntimeError: If HuggingFace client is not initialized or API errors occur.
        :raises FileNotFoundError: If the path, repository or revision doesn't exist.
        """
        if not self._hf_client:
            raise RuntimeError("HuggingFace client not initialized")

        path = self._normalize_path(path)

        try:
            items = self._hf_client.get_paths_info(
                repo_id=self._repository_id,
                paths=[path],
                repo_type=self._repo_type,
                revision=self._repo_revision,
                expand=True,
            )

            # An empty result means nothing exists at this exact path; route it through
            # the FileNotFoundError handler below so the directory fallback can run.
            if not items:
                raise FileNotFoundError(f"File not found in HuggingFace repository: {path}")

            return self._item_to_metadata(items[0])
        except FileNotFoundError as error:
            if strict:
                dir_path = path.rstrip("/") + "/"
                if self._is_dir(dir_path):
                    return ObjectMetadata(
                        key=dir_path,
                        type="directory",
                        content_length=0,
                        last_modified=AWARE_DATETIME_MIN,
                        etag=None,
                        content_type=None,
                        storage_class=None,
                        metadata=None,
                    )
            raise error
        except (RepositoryNotFoundError, RevisionNotFoundError) as e:
            raise FileNotFoundError(
                f"Repository or revision not found: {self._repository_id}@{self._repo_revision}"
            ) from e
        except Exception as e:
            raise RuntimeError(f"Unexpected error getting metadata for {path}: {e}") from e

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        """
        Lists objects in the HuggingFace repository under the specified path.

        If ``path`` names a single file, only that file is yielded. Otherwise, with
        ``include_directories=False`` the listing is recursive and yields files only; with
        ``include_directories=True`` only immediate children are listed and both files and
        folders are yielded. A directory that does not exist yields nothing.

        :param path: The path to list objects under.
        :param start_after: The key to start listing after (not supported by HuggingFace).
        :param end_at: The key to end listing at (not supported by HuggingFace).
        :param include_directories: Whether to include directories in the listing.
        :param follow_symlinks: Ignored; accepted for interface compatibility.
        :return: An iterator over object metadata for objects under the specified path.
        :raises RuntimeError: If HuggingFace client is not initialized or API errors occur.
        :raises ValueError: If start_after or end_at parameters are provided (not supported).
        :raises FileNotFoundError: If the repository or revision does not exist.
        """
        if not self._hf_client:
            raise RuntimeError("HuggingFace client not initialized")

        if start_after is not None or end_at is not None:
            raise ValueError(
                "HuggingFace provider does not support pagination with start_after or end_at parameters. "
                "These parameters are not supported by the HuggingFace Hub API."
            )

        path = self._normalize_path(path)
        dir_path = path.rstrip("/")

        # A non-root path may name a single file; the repository root never can,
        # so skip the extra metadata round-trip for it.
        if dir_path:
            try:
                metadata = self._get_object_metadata(dir_path, strict=False)
            except FileNotFoundError:
                metadata = None
            if metadata is not None and metadata.type == "file":
                yield metadata
                return

        try:
            repo_items = self._hf_client.list_repo_tree(
                repo_id=self._repository_id,
                path_in_repo=dir_path + "/" if dir_path else None,
                repo_type=self._repo_type,
                revision=self._repo_revision,
                expand=True,
                recursive=not include_directories,
            )

            for item in repo_items:
                if isinstance(item, RepoFile) or (include_directories and isinstance(item, RepoFolder)):
                    yield self._item_to_metadata(item)

        except (RepositoryNotFoundError, RevisionNotFoundError) as e:
            raise FileNotFoundError(
                f"Repository or revision not found: {self._repository_id}@{self._repo_revision}"
            ) from e
        except EntryNotFoundError:
            # Directory doesn't exist - return empty (matches POSIX behavior)
            pass
        except HfHubHTTPError as e:
            raise RuntimeError(f"HuggingFace API error during listing of {path}: {e}") from e
        except Exception as e:
            raise RuntimeError(f"Unexpected error during listing of {path}: {e}") from e

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        """
        Uploads a file to the HuggingFace repository.

        :param remote_path: The remote path where the file will be stored in the repository.
        :param f: File path or file object to upload. Text read from a file object is encoded as UTF-8.
        :param attributes: Optional attributes for the file (not supported by HuggingFace).
        :return: Data size in bytes.
        :raises RuntimeError: If HuggingFace client is not initialized or API errors occur.
        :raises ValueError: If client attempts to create a directory.
        :raises ValueError: If custom attributes are provided (not supported).
        :raises FileNotFoundError: If the repository or revision does not exist.
        """
        if not self._hf_client:
            raise RuntimeError("HuggingFace client not initialized")

        if attributes is not None:
            raise ValueError(
                "HuggingFace provider does not support custom file attributes. "
                "Use commit messages or repository metadata instead."
            )

        if remote_path.endswith("/"):
            raise ValueError(
                "HuggingFace Storage Provider does not support explicit directory creation. "
                "Directories are created implicitly when files are uploaded to paths within them."
            )

        remote_path = self._normalize_path(remote_path)

        try:
            if isinstance(f, str):
                file_size = os.path.getsize(f)
                self._upload_local_file(f, remote_path)
                return file_size

            content = f.read()
            # Text-mode file objects return str; store it as UTF-8 bytes.
            content_bytes = content.encode("utf-8") if isinstance(content, str) else content
            self._upload_bytes(content_bytes, remote_path)
            return len(content_bytes)

        except (RepositoryNotFoundError, RevisionNotFoundError) as e:
            raise FileNotFoundError(
                f"Repository or revision not found: {self._repository_id}@{self._repo_revision}"
            ) from e
        except HfHubHTTPError as e:
            raise RuntimeError(f"HuggingFace API error during upload of {remote_path}: {e}") from e
        except Exception as e:
            raise RuntimeError(f"Unexpected error during upload of {remote_path}: {e}") from e

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        """
        Downloads a file from the HuggingFace repository.

        :param remote_path: The remote path of the file to download from the repository.
        :param f: Local file path or file object to write to. Text-mode file objects
            receive the content decoded as UTF-8.
        :param metadata: Optional object metadata (not used in this implementation).
        :return: Data size in bytes.
        :raises RuntimeError: If HuggingFace client is not initialized or API errors occur.
        :raises FileNotFoundError: If the file, repository or revision doesn't exist.
        """
        if not self._hf_client:
            raise RuntimeError("HuggingFace client not initialized")

        remote_path = self._normalize_path(remote_path)

        try:
            if isinstance(f, str):
                parent_dir = os.path.dirname(f)
                if parent_dir:
                    os.makedirs(parent_dir, exist_ok=True)

                # hf_hub_download writes to <local_dir>/<remote_path> and leaves helper files in
                # local_dir. Staging it in a private directory next to the destination:
                #   - avoids clobbering unrelated files in the user's directory;
                #   - keeps os.replace on the same filesystem;
                #   - removes every leftover when the context exits.
                with tempfile.TemporaryDirectory(dir=parent_dir or ".") as staging_dir:
                    downloaded_path = self._download_to_dir(remote_path, staging_dir)
                    os.replace(downloaded_path, f)

                return os.path.getsize(f)

            data = self._read_remote_bytes(remote_path)
            if isinstance(f, io.TextIOBase):
                f.write(data.decode("utf-8"))
            else:
                f.write(data)
            return len(data)

        except (RepositoryNotFoundError, RevisionNotFoundError) as e:
            raise FileNotFoundError(
                f"Repository or revision not found: {self._repository_id}@{self._repo_revision}"
            ) from e
        except EntryNotFoundError as e:
            # Checked before HfHubHTTPError so a missing file is reported as FileNotFoundError.
            raise FileNotFoundError(f"File not found in HuggingFace repository: {remote_path}") from e
        except HfHubHTTPError as e:
            raise RuntimeError(f"HuggingFace API error during download: {e}") from e
        except Exception as e:
            raise RuntimeError(f"Unexpected error during download: {e}") from e

    def _is_dir(self, path: str) -> bool:
        """
        Helper method to check if a path is a directory.

        :param path: The path to check; trailing slashes are ignored.
        :return: True for the repository root, or when the Hub reports a folder at ``path``.
        :raises FileNotFoundError: If the repository or revision does not exist.
        :raises RuntimeError: On any other error while querying the Hub.
        """
        path = path.rstrip("/")
        if not path:
            # The root of the repo is always a directory
            return True

        try:
            path_infos = self._hf_client.get_paths_info(
                repo_id=self._repository_id,
                paths=[path],
                repo_type=self._repo_type,
                revision=self._repo_revision,
            )
        except (RepositoryNotFoundError, RevisionNotFoundError) as e:
            raise FileNotFoundError(
                f"Repository or revision not found: {self._repository_id}@{self._repo_revision}"
            ) from e
        except Exception as e:
            raise RuntimeError(f"Unexpected error: {e}") from e

        # An empty result means nothing exists at this path.
        return bool(path_infos) and isinstance(path_infos[0], RepoFolder)

    def _normalize_path(self, path: str) -> str:
        """
        Normalize path for HuggingFace API by removing leading slashes.
        HuggingFace expects relative paths within repositories.
        """
        return path.lstrip("/")