1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2# SPDX-License-Identifier: Apache-2.0
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16import importlib.util
17import io
18import os
19import tempfile
20from collections.abc import Callable, Iterator
21from typing import IO, Any, Optional, Union
22
23from huggingface_hub import CommitOperationCopy, HfApi
24from huggingface_hub.errors import EntryNotFoundError, HfHubHTTPError, RepositoryNotFoundError, RevisionNotFoundError
25from huggingface_hub.hf_api import RepoFile, RepoFolder
26
27from ..telemetry import Telemetry
28from ..types import AWARE_DATETIME_MIN, Credentials, CredentialsProvider, ObjectMetadata, Range
29from .base import BaseStorageProvider
30
31PROVIDER = "huggingface"
32
33HF_TRANSFER_UNAVAILABLE_ERROR_MESSAGE = (
34 "Fast transfer using 'hf_transfer' is enabled (HF_HUB_ENABLE_HF_TRANSFER=1) "
35 "but 'hf_transfer' package is not available in your environment. "
36 "Either install hf_transfer with 'pip install hf_transfer' or "
37 "disable it by setting HF_HUB_ENABLE_HF_TRANSFER=0"
38)
39
40
[docs]
class HuggingFaceCredentialsProvider(CredentialsProvider):
    """
    :py:class:`multistorageclient.types.CredentialsProvider` implementation that serves a fixed HuggingFace access token.
    """

    def __init__(self, access_token: str):
        """
        Stores the access token that :py:meth:`get_credentials` hands out.

        :param access_token: The HuggingFace access token for authentication.
        """
        self.token = access_token

    def get_credentials(self) -> Credentials:
        """
        Builds the credentials object used for HuggingFace authentication.

        Only the ``token`` field carries data: both key fields are empty strings and no expiration is set.

        :return: The current credentials used for HuggingFace authentication.
        """
        credential_fields = {
            "access_key": "",
            "secret_key": "",
            "token": self.token,
            "expiration": None,
        }
        return Credentials(**credential_fields)

    def refresh_credentials(self) -> None:
        """
        Refresh hook required by the credentials provider interface.

        Note: HuggingFace tokens typically don't expire, so nothing is refreshed here.
        """
        return None
74
75
[docs]
76class HuggingFaceStorageProvider(BaseStorageProvider):
77 """
78 A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with HuggingFace Hub repositories.
79 """
80
81 def __init__(
82 self,
83 repository_id: str,
84 repo_type: str = "model",
85 base_path: str = "",
86 repo_revision: str = "main",
87 credentials_provider: Optional[CredentialsProvider] = None,
88 config_dict: Optional[dict[str, Any]] = None,
89 telemetry_provider: Optional[Callable[[], Telemetry]] = None,
90 ):
91 """
92 Initializes the :py:class:`HuggingFaceStorageProvider` with repository information and optional credentials provider.
93
94 :param repository_id: The HuggingFace repository ID (e.g., 'username/repo-name').
95 :param repo_type: The type of repository ('dataset', 'model', 'space'). Defaults to 'model'.
96 :param base_path: The root prefix path within the repository where all operations will be scoped.
97 :param repo_revision: The git revision (branch, tag, or commit) to use. Defaults to 'main'.
98 :param credentials_provider: The provider to retrieve HuggingFace credentials.
99 :param config_dict: Resolved MSC config.
100 :param telemetry_provider: A function that provides a telemetry instance.
101 """
102
103 # Validate repo_type
104 allowed_repo_types = {"dataset", "model", "space"}
105 if repo_type not in allowed_repo_types:
106 raise ValueError(f"Invalid repo_type '{repo_type}'. Must be one of: {allowed_repo_types}")
107
108 # Validate repository_id format
109 if not repository_id or "/" not in repository_id:
110 raise ValueError(f"Invalid repository_id '{repository_id}'. Expected format: 'username/repo-name'")
111
112 self._validate_hf_transfer_availability()
113
114 super().__init__(
115 base_path=base_path,
116 provider_name=PROVIDER,
117 config_dict=config_dict,
118 telemetry_provider=telemetry_provider,
119 )
120
121 self._repository_id = repository_id
122 self._repo_type = repo_type
123 self._repo_revision = repo_revision
124 self._credentials_provider = credentials_provider
125
126 self._hf_client: HfApi = self._create_hf_api_client()
127
128 def _create_hf_api_client(self) -> HfApi:
129 """
130 Creates and configures the HuggingFace API client.
131
132 Initializes the HfApi client with authentication token if credentials are provided,
133 otherwise creates an unauthenticated client for public repositories.
134
135 :return: Configured HfApi client instance.
136 """
137
138 token = None
139 if self._credentials_provider:
140 creds = self._credentials_provider.get_credentials()
141 token = creds.token
142
143 return HfApi(token=token)
144
145 def _validate_hf_transfer_availability(self) -> None:
146 """
147 Validates that hf_transfer is available if it's enabled via environment variables.
148
149 Raises:
150 ValueError: If hf_transfer is enabled but not available.
151 """
152 # Check if hf_transfer is enabled via environment variable
153 hf_transfer_enabled = os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "").lower() in ("1", "on", "true", "yes")
154
155 if hf_transfer_enabled and importlib.util.find_spec("hf_transfer") is None:
156 raise ValueError(HF_TRANSFER_UNAVAILABLE_ERROR_MESSAGE)
157
158 def _put_object(
159 self,
160 path: str,
161 body: bytes,
162 if_match: Optional[str] = None,
163 if_none_match: Optional[str] = None,
164 attributes: Optional[dict[str, str]] = None,
165 ) -> int:
166 """
167 Uploads an object to the HuggingFace repository.
168
169 :param path: The path where the object will be stored in the repository.
170 :param body: The content of the object to store.
171 :param if_match: Optional ETag for conditional uploads (not supported by HuggingFace).
172 :param if_none_match: Optional ETag for conditional uploads (not supported by HuggingFace).
173 :param attributes: Optional attributes for the object (not supported by HuggingFace).
174 :return: Data size in bytes.
175 :raises RuntimeError: If HuggingFace client is not initialized or API errors occur.
176 :raises ValueError: If client attempts to create a directory.
177 :raises ValueError: If conditional upload parameters are provided (not supported).
178 """
179 if not self._hf_client:
180 raise RuntimeError("HuggingFace client not initialized")
181
182 if if_match is not None or if_none_match is not None:
183 raise ValueError(
184 "HuggingFace provider does not support conditional uploads. "
185 "if_match and if_none_match parameters are not supported."
186 )
187
188 if attributes is not None:
189 raise ValueError(
190 "HuggingFace provider does not support custom object attributes. "
191 "Use commit messages or repository metadata instead."
192 )
193
194 if path.endswith("/"):
195 raise ValueError(
196 "HuggingFace Storage Provider does not support explicit directory creation. "
197 "Directories are created implicitly when files are uploaded to paths within them."
198 )
199
200 path = self._normalize_path(path)
201
202 try:
203 with tempfile.NamedTemporaryFile(delete=False) as temp_file:
204 temp_file.write(body)
205 temp_file_path = temp_file.name
206
207 try:
208 self._hf_client.upload_file(
209 path_or_fileobj=temp_file_path,
210 path_in_repo=path,
211 repo_id=self._repository_id,
212 repo_type=self._repo_type,
213 revision=self._repo_revision,
214 commit_message=f"Upload {path}",
215 commit_description=None,
216 create_pr=False,
217 )
218
219 return len(body)
220
221 finally:
222 os.unlink(temp_file_path)
223
224 except (RepositoryNotFoundError, RevisionNotFoundError) as e:
225 raise FileNotFoundError(
226 f"Repository or revision not found: {self._repository_id}@{self._repo_revision}"
227 ) from e
228 except HfHubHTTPError as e:
229 raise RuntimeError(f"HuggingFace API error during upload of {path}: {e}") from e
230 except Exception as e:
231 raise RuntimeError(f"Unexpected error during upload of {path}: {e}") from e
232
233 def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
234 """
235 Retrieves an object from the HuggingFace repository.
236
237 :param path: The path of the object to retrieve from the repository.
238 :param byte_range: Optional byte range for partial content (not supported by HuggingFace).
239 :return: The content of the retrieved object.
240 :raises RuntimeError: If HuggingFace client is not initialized or API errors occur.
241 :raises ValueError: If a byte range is requested (HuggingFace doesn't support range reads).
242 :raises FileNotFoundError: If the file doesn't exist in the repository.
243 """
244
245 if not self._hf_client:
246 raise RuntimeError("HuggingFace client not initialized")
247
248 if byte_range is not None:
249 raise ValueError(
250 "HuggingFace provider does not support partial range reads. "
251 f"Requested range: offset={byte_range.offset}, size={byte_range.size}. "
252 "To read the entire file, call get_object() without the byte_range parameter."
253 )
254
255 path = self._normalize_path(path)
256
257 try:
258 with tempfile.TemporaryDirectory() as temp_dir:
259 downloaded_path = self._hf_client.hf_hub_download(
260 repo_id=self._repository_id,
261 filename=path,
262 repo_type=self._repo_type,
263 revision=self._repo_revision,
264 local_dir=temp_dir,
265 )
266
267 with open(downloaded_path, "rb") as f:
268 data = f.read()
269
270 return data
271
272 except (RepositoryNotFoundError, RevisionNotFoundError) as e:
273 raise FileNotFoundError(f"File not found in HuggingFace repository: {path}") from e
274 except HfHubHTTPError as e:
275 raise RuntimeError(f"HuggingFace API error during download of {path}: {e}") from e
276 except Exception as e:
277 raise RuntimeError(f"Unexpected error during download of {path}: {e}") from e
278
279 def _copy_object(self, src_path: str, dest_path: str) -> int:
280 """
281 Copies an object within the HuggingFace repository using server-side copy.
282
283 .. note::
284 Copy behavior is size-dependent: files ≥10MB are copied remotely via
285 metadata (LFS), while files <10MB are downloaded and re-uploaded.
286
287 :param src_path: The source path of the object to copy.
288 :param dest_path: The destination path for the copied object.
289 :return: Data size in bytes.
290 :raises RuntimeError: If HuggingFace client is not initialized or API errors occur.
291 :raises FileNotFoundError: If the source file doesn't exist.
292 """
293 if not self._hf_client:
294 raise RuntimeError("HuggingFace client not initialized")
295
296 src_path = self._normalize_path(src_path)
297 dest_path = self._normalize_path(dest_path)
298
299 src_object = self._get_object_metadata(src_path)
300
301 try:
302 operations = [
303 CommitOperationCopy(
304 src_path_in_repo=src_path,
305 path_in_repo=dest_path,
306 )
307 ]
308
309 self._hf_client.create_commit(
310 repo_id=self._repository_id,
311 operations=operations,
312 commit_message=f"Copy {src_path} to {dest_path}",
313 repo_type=self._repo_type,
314 revision=self._repo_revision,
315 )
316
317 return src_object.content_length
318
319 except (RepositoryNotFoundError, RevisionNotFoundError) as e:
320 raise FileNotFoundError(
321 f"Repository or revision not found: {self._repository_id}@{self._repo_revision}"
322 ) from e
323 except HfHubHTTPError as e:
324 raise RuntimeError(f"HuggingFace API error during copy from {src_path} to {dest_path}: {e}") from e
325 except Exception as e:
326 raise RuntimeError(f"Unexpected error during copy from {src_path} to {dest_path}: {e}") from e
327
328 def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
329 """
330 Deletes an object from the HuggingFace repository.
331
332 :param path: The path of the object to delete from the repository.
333 :param if_match: Optional ETag for conditional deletion (not supported by HuggingFace).
334 :raises RuntimeError: If HuggingFace client is not initialized or API errors occur.
335 :raises ValueError: If conditional deletion parameters are provided (not supported).
336 :raises FileNotFoundError: If the file doesn't exist in the repository.
337 """
338 if not self._hf_client:
339 raise RuntimeError("HuggingFace client not initialized")
340
341 if if_match is not None:
342 raise ValueError(
343 "HuggingFace provider does not support conditional deletion. if_match parameter is not supported."
344 )
345
346 path = self._normalize_path(path)
347
348 try:
349 self._hf_client.delete_file(
350 path_in_repo=path,
351 repo_id=self._repository_id,
352 repo_type=self._repo_type,
353 revision=self._repo_revision,
354 commit_message=f"Delete {path}",
355 )
356
357 except (RepositoryNotFoundError, RevisionNotFoundError) as e:
358 raise FileNotFoundError(
359 f"Repository or revision not found: {self._repository_id}@{self._repo_revision}"
360 ) from e
361 except Exception as e:
362 raise RuntimeError(f"Unexpected error during deletion of {path}: {e}") from e
363
364 def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
365 """
366 Retrieves metadata for an object in the HuggingFace repository.
367
368 :param path: The path of the object to get metadata for.
369 :param strict: Whether to raise an error if the object doesn't exist.
370 :return: Metadata about the object.
371 :raises RuntimeError: If HuggingFace client is not initialized or API errors occur.
372 :raises FileNotFoundError: If the file doesn't exist and strict=True.
373 """
374 if not self._hf_client:
375 raise RuntimeError("HuggingFace client not initialized")
376
377 path = self._normalize_path(path)
378
379 try:
380 items = self._hf_client.get_paths_info(
381 repo_id=self._repository_id,
382 paths=[path],
383 repo_type=self._repo_type,
384 revision=self._repo_revision,
385 expand=True,
386 )
387
388 if not items:
389 raise FileNotFoundError(f"File not found in HuggingFace repository: {path}")
390
391 item = items[0]
392
393 last_modified = AWARE_DATETIME_MIN
394 if hasattr(item, "last_commit") and item.last_commit:
395 last_modified = item.last_commit.date
396
397 return ObjectMetadata(
398 key=item.path,
399 type="file" if isinstance(item, RepoFile) else "directory",
400 content_length=item.size if isinstance(item, RepoFile) else 0,
401 last_modified=last_modified,
402 etag=None, # Can be obtained via a separate call to get_hf_file_metadata
403 content_type=None,
404 storage_class=None,
405 metadata=None,
406 )
407 except FileNotFoundError as error:
408 if strict:
409 dir_path = path.rstrip("/") + "/"
410 if self._is_dir(dir_path):
411 return ObjectMetadata(
412 key=dir_path,
413 type="directory",
414 content_length=0,
415 last_modified=AWARE_DATETIME_MIN,
416 etag=None,
417 content_type=None,
418 storage_class=None,
419 metadata=None,
420 )
421 raise error
422 except (RepositoryNotFoundError, RevisionNotFoundError) as e:
423 raise FileNotFoundError(
424 f"Repository or revision not found: {self._repository_id}@{self._repo_revision}"
425 ) from e
426 except Exception as e:
427 raise RuntimeError(f"Unexpected error getting metadata for {path}: {e}") from e
428
429 def _list_objects(
430 self,
431 path: str,
432 start_after: Optional[str] = None,
433 end_at: Optional[str] = None,
434 include_directories: bool = False,
435 ) -> Iterator[ObjectMetadata]:
436 """
437 Lists objects in the HuggingFace repository under the specified path.
438
439 :param path: The path to list objects under.
440 :param start_after: The key to start listing after (not supported by HuggingFace).
441 :param end_at: The key to end listing at (not supported by HuggingFace).
442 :param include_directories: Whether to include directories in the listing.
443 :return: An iterator over object metadata for objects under the specified path.
444 :raises RuntimeError: If HuggingFace client is not initialized or API errors occur.
445 :raises ValueError: If start_after or end_at parameters are provided (not supported).
446 """
447 if not self._hf_client:
448 raise RuntimeError("HuggingFace client not initialized")
449
450 if start_after is not None or end_at is not None:
451 raise ValueError(
452 "HuggingFace provider does not support pagination with start_after or end_at parameters. "
453 "These parameters are not supported by the HuggingFace Hub API."
454 )
455
456 path = self._normalize_path(path)
457
458 try:
459 repo_items = self._hf_client.list_repo_tree(
460 repo_id=self._repository_id,
461 path_in_repo=os.path.dirname(path),
462 repo_type=self._repo_type,
463 revision=self._repo_revision,
464 expand=True,
465 recursive=True,
466 )
467
468 for item in repo_items:
469 if not item.path.startswith(os.path.dirname(path)):
470 continue
471
472 if include_directories and isinstance(item, RepoFolder):
473 last_modified = AWARE_DATETIME_MIN
474 if hasattr(item, "last_commit") and item.last_commit:
475 last_modified = item.last_commit.date
476
477 yield ObjectMetadata(
478 key=item.path,
479 type="directory",
480 content_length=0,
481 last_modified=last_modified,
482 etag=None,
483 content_type=None,
484 storage_class=None,
485 metadata=None,
486 )
487
488 elif isinstance(item, RepoFile):
489 last_modified = AWARE_DATETIME_MIN
490 if hasattr(item, "last_commit") and item.last_commit:
491 last_modified = item.last_commit.date
492
493 yield ObjectMetadata(
494 key=item.path,
495 type="file",
496 content_length=item.size,
497 last_modified=last_modified,
498 etag=None,
499 content_type=None,
500 storage_class=None,
501 metadata=None,
502 )
503 except (RepositoryNotFoundError, RevisionNotFoundError) as e:
504 raise FileNotFoundError(
505 f"Repository or revision not found: {self._repository_id}@{self._repo_revision}"
506 ) from e
507 # if an entry wasn't found, we can effectively treat this as an empty directory since HF creates/deletes directories implicitly.
508 except EntryNotFoundError:
509 pass
510 except HfHubHTTPError as e:
511 raise RuntimeError(f"HuggingFace API error during listing of {path}: {e}") from e
512 except Exception as e:
513 raise RuntimeError(f"Unexpected error during listing of {path}: {e}") from e
514
515 def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
516 """
517 Uploads a file to the HuggingFace repository.
518
519 :param remote_path: The remote path where the file will be stored in the repository.
520 :param f: File path or file object to upload.
521 :param attributes: Optional attributes for the file (not supported by HuggingFace).
522 :return: Data size in bytes.
523 :raises RuntimeError: If HuggingFace client is not initialized or API errors occur.
524 :raises ValueError: If client attempts to create a directory.
525 :raises ValueError: If custom attributes are provided (not supported).
526 """
527 if not self._hf_client:
528 raise RuntimeError("HuggingFace client not initialized")
529
530 if attributes is not None:
531 raise ValueError(
532 "HuggingFace provider does not support custom file attributes. "
533 "Use commit messages or repository metadata instead."
534 )
535
536 if remote_path.endswith("/"):
537 raise ValueError(
538 "HuggingFace Storage Provider does not support explicit directory creation. "
539 "Directories are created implicitly when files are uploaded to paths within them."
540 )
541
542 remote_path = self._normalize_path(remote_path)
543
544 try:
545 if isinstance(f, str):
546 file_size = os.path.getsize(f)
547
548 self._hf_client.upload_file(
549 path_or_fileobj=f,
550 path_in_repo=remote_path,
551 repo_id=self._repository_id,
552 repo_type=self._repo_type,
553 revision=self._repo_revision,
554 commit_message=f"Upload {remote_path}",
555 commit_description=None,
556 create_pr=False,
557 )
558
559 return file_size
560
561 else:
562 content = f.read()
563
564 if isinstance(content, str):
565 content_bytes = content.encode("utf-8")
566 else:
567 content_bytes = content
568
569 # Create temporary file since HfAPI.upload_file requires BinaryIO, not generic IO
570 with tempfile.NamedTemporaryFile(delete=False) as temp_file:
571 temp_file.write(content_bytes)
572 temp_file_path = temp_file.name
573
574 try:
575 self._hf_client.upload_file(
576 path_or_fileobj=temp_file_path,
577 path_in_repo=remote_path,
578 repo_id=self._repository_id,
579 repo_type=self._repo_type,
580 revision=self._repo_revision,
581 commit_message=f"Upload {remote_path}",
582 create_pr=False,
583 )
584
585 return len(content_bytes)
586
587 finally:
588 os.unlink(temp_file_path)
589
590 except (RepositoryNotFoundError, RevisionNotFoundError) as e:
591 raise FileNotFoundError(
592 f"Repository or revision not found: {self._repository_id}@{self._repo_revision}"
593 ) from e
594 except HfHubHTTPError as e:
595 raise RuntimeError(f"HuggingFace API error during upload of {remote_path}: {e}") from e
596 except Exception as e:
597 raise RuntimeError(f"Unexpected error during upload of {remote_path}: {e}") from e
598
599 def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
600 """
601 Downloads a file from the HuggingFace repository.
602
603 :param remote_path: The remote path of the file to download from the repository.
604 :param f: Local file path or file object to write to.
605 :param metadata: Optional object metadata (not used in this implementation).
606 :return: Data size in bytes.
607 """
608 if not self._hf_client:
609 raise RuntimeError("HuggingFace client not initialized")
610
611 remote_path = self._normalize_path(remote_path)
612
613 try:
614 if isinstance(f, str):
615 parent_dir = os.path.dirname(f)
616 if parent_dir:
617 os.makedirs(parent_dir, exist_ok=True)
618
619 target_dir = parent_dir if parent_dir else "."
620 downloaded_path = self._hf_client.hf_hub_download(
621 repo_id=self._repository_id,
622 filename=remote_path,
623 repo_type=self._repo_type,
624 revision=self._repo_revision,
625 local_dir=target_dir,
626 )
627
628 if os.path.abspath(downloaded_path) != os.path.abspath(f):
629 os.rename(downloaded_path, f)
630
631 return os.path.getsize(f)
632
633 else:
634 with tempfile.TemporaryDirectory() as temp_dir:
635 downloaded_path = self._hf_client.hf_hub_download(
636 repo_id=self._repository_id,
637 filename=remote_path,
638 repo_type=self._repo_type,
639 revision=self._repo_revision,
640 local_dir=temp_dir,
641 )
642
643 with open(downloaded_path, "rb") as src:
644 data = src.read()
645 if isinstance(f, io.TextIOBase):
646 f.write(data.decode("utf-8"))
647 else:
648 f.write(data)
649
650 return len(data)
651
652 except (RepositoryNotFoundError, RevisionNotFoundError) as e:
653 raise FileNotFoundError(f"File not found in HuggingFace repository: {remote_path}") from e
654 except HfHubHTTPError as e:
655 raise RuntimeError(f"HuggingFace API error during download: {e}") from e
656 except Exception as e:
657 raise RuntimeError(f"Unexpected error during download: {e}") from e
658
659 def _is_dir(self, path: str) -> bool:
660 """
661 Helper method to check if a path is a directory.
662
663 :param path: The path to check.
664 :return: True if the path appears to be a directory (has files under it).
665 """
666 path = path.rstrip("/")
667 if not path:
668 # The root of the repo is always a directory
669 return True
670
671 try:
672 path_info = self._hf_client.get_paths_info(
673 repo_id=self._repository_id,
674 paths=[path],
675 repo_type=self._repo_type,
676 revision=self._repo_revision,
677 )[0]
678
679 return isinstance(path_info, RepoFolder)
680
681 except (RepositoryNotFoundError, RevisionNotFoundError) as e:
682 raise FileNotFoundError(
683 f"Repository or revision not found: {self._repository_id}@{self._repo_revision}"
684 ) from e
685 except IndexError:
686 return False
687 except Exception as e:
688 raise Exception(f"Unexpected error: {e}")
689
690 def _normalize_path(self, path: str) -> str:
691 """
692 Normalize path for HuggingFace API by removing leading slashes.
693 HuggingFace expects relative paths within repositories.
694 """
695 return path.lstrip("/")