1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2# SPDX-License-Identifier: Apache-2.0
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16import importlib.util
17import io
18import os
19import tempfile
20from collections.abc import Callable, Iterator
21from typing import IO, Any, Optional, Union
22
23from huggingface_hub import CommitOperationCopy, HfApi
24from huggingface_hub.errors import EntryNotFoundError, HfHubHTTPError, RepositoryNotFoundError, RevisionNotFoundError
25from huggingface_hub.hf_api import RepoFile, RepoFolder
26
27from ..telemetry import Telemetry
28from ..types import AWARE_DATETIME_MIN, Credentials, CredentialsProvider, ObjectMetadata, Range
29from .base import BaseStorageProvider
30
# Provider name handed to the base storage provider.
PROVIDER = "huggingface"

# Message raised when HF_HUB_ENABLE_HF_TRANSFER asks for hf_transfer but the package cannot be imported.
HF_TRANSFER_UNAVAILABLE_ERROR_MESSAGE = (
    "Fast transfer using 'hf_transfer' is enabled (HF_HUB_ENABLE_HF_TRANSFER=1) but "
    "'hf_transfer' package is not available in your environment. Either install "
    "hf_transfer with 'pip install hf_transfer' or disable it by setting "
    "HF_HUB_ENABLE_HF_TRANSFER=0"
)
39
40
[docs]
class HuggingFaceCredentialsProvider(CredentialsProvider):
    """
    A :py:class:`multistorageclient.types.CredentialsProvider` that serves one static HuggingFace access token.
    """

    def __init__(self, access_token: str):
        """
        Remember the access token that will be handed out by :py:meth:`get_credentials`.

        :param access_token: The HuggingFace access token for authentication.
        """
        self.token = access_token

    def get_credentials(self) -> Credentials:
        """
        Wrap the stored token in a :py:class:`Credentials` object.

        Only ``token`` is meaningful; the access/secret key fields are empty strings and no expiration is set.

        :return: The credentials used for HuggingFace authentication.
        """
        return Credentials(
            token=self.token,
            access_key="",
            secret_key="",
            expiration=None,
        )

    def refresh_credentials(self) -> None:
        """
        Do nothing: the token supplied at construction is static and is never re-fetched.

        Note: HuggingFace tokens typically don't expire, so no refresh logic is needed here.
        """
        return None
74
75
[docs]
class HuggingFaceStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with HuggingFace Hub repositories.

    Mutating operations (upload, copy, delete) each create their own commit on the configured revision.
    """

    def __init__(
        self,
        repository_id: str,
        repo_type: str = "model",
        base_path: str = "",
        repo_revision: str = "main",
        credentials_provider: Optional[CredentialsProvider] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
    ):
        """
        Initializes the :py:class:`HuggingFaceStorageProvider` with repository information and optional credentials provider.

        :param repository_id: The HuggingFace repository ID (e.g., 'username/repo-name').
        :param repo_type: The type of repository ('dataset', 'model', 'space'). Defaults to 'model'.
        :param base_path: The root prefix path within the repository where all operations will be scoped.
        :param repo_revision: The git revision (branch, tag, or commit) to use. Defaults to 'main'.
        :param credentials_provider: The provider to retrieve HuggingFace credentials.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        :raises ValueError: If ``repo_type`` is not 'dataset', 'model' or 'space', if ``repository_id`` is empty
            or contains no '/', or if hf_transfer is enabled but not installed.
        """
        # Validate arguments before any base-class setup so misconfiguration fails fast.
        allowed_repo_types = {"dataset", "model", "space"}
        if repo_type not in allowed_repo_types:
            raise ValueError(f"Invalid repo_type '{repo_type}'. Must be one of: {allowed_repo_types}")

        if not repository_id or "/" not in repository_id:
            raise ValueError(f"Invalid repository_id '{repository_id}'. Expected format: 'username/repo-name'")

        self._validate_hf_transfer_availability()

        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        self._repository_id = repository_id
        self._repo_type = repo_type
        self._repo_revision = repo_revision
        self._credentials_provider = credentials_provider

        self._hf_client: HfApi = self._create_hf_api_client()

    def _create_hf_api_client(self) -> HfApi:
        """
        Creates and configures the HuggingFace API client.

        Uses the token from the credentials provider when one is configured; otherwise the client is
        created without a token (sufficient for public repositories).

        :return: Configured HfApi client instance.
        """
        token = None
        if self._credentials_provider:
            token = self._credentials_provider.get_credentials().token

        return HfApi(token=token)

    def _validate_hf_transfer_availability(self) -> None:
        """
        Validates that hf_transfer is available if it's enabled via the ``HF_HUB_ENABLE_HF_TRANSFER`` environment variable.

        :raises ValueError: If hf_transfer is enabled but the package cannot be found.
        """
        # "1", "on", "true" and "yes" (case-insensitive) all count as enabled.
        hf_transfer_enabled = os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "").lower() in ("1", "on", "true", "yes")

        if hf_transfer_enabled and importlib.util.find_spec("hf_transfer") is None:
            raise ValueError(HF_TRANSFER_UNAVAILABLE_ERROR_MESSAGE)

    def _repo_not_found_error(self) -> FileNotFoundError:
        """Build the error raised when the configured repository or revision does not exist."""
        return FileNotFoundError(f"Repository or revision not found: {self._repository_id}@{self._repo_revision}")

    @staticmethod
    def _to_object_metadata(item: Union[RepoFile, RepoFolder]) -> ObjectMetadata:
        """
        Convert a Hub tree entry into :py:class:`ObjectMetadata`.

        Folders are reported with ``content_length=0``. The last-commit date is used as ``last_modified``
        when the entry carries one; otherwise ``AWARE_DATETIME_MIN``. The ETag is not populated (it would
        need a separate call to ``get_hf_file_metadata``).
        """
        last_commit = getattr(item, "last_commit", None)
        is_file = isinstance(item, RepoFile)
        return ObjectMetadata(
            key=item.path,
            type="file" if is_file else "directory",
            content_length=item.size if is_file else 0,
            last_modified=last_commit.date if last_commit else AWARE_DATETIME_MIN,
            etag=None,
            content_type=None,
            storage_class=None,
            metadata=None,
        )

    def _upload_local_file(self, local_path: str, path_in_repo: str) -> None:
        """Upload the local file ``local_path`` to ``path_in_repo`` as one commit on the configured revision."""
        self._hf_client.upload_file(
            path_or_fileobj=local_path,
            path_in_repo=path_in_repo,
            repo_id=self._repository_id,
            repo_type=self._repo_type,
            revision=self._repo_revision,
            commit_message=f"Upload {path_in_repo}",
            commit_description=None,
            create_pr=False,
        )

    def _upload_bytes(self, path_in_repo: str, data: bytes) -> None:
        """Upload in-memory ``data`` by staging it in a temporary file that is always removed afterwards."""
        # The payload is uploaded by filesystem path, the same way local files are. delete=False lets us
        # close the file before upload_file() reopens it by name (an open NamedTemporaryFile cannot be
        # reopened on Windows). The finally clause also covers a failed write, so nothing is leaked.
        temp_file = tempfile.NamedTemporaryFile(delete=False)
        try:
            with temp_file:
                temp_file.write(data)
            self._upload_local_file(temp_file.name, path_in_repo)
        finally:
            os.unlink(temp_file.name)

    def _download_bytes(self, path_in_repo: str) -> bytes:
        """Download ``path_in_repo`` into a throwaway directory and return its full contents."""
        with tempfile.TemporaryDirectory() as temp_dir:
            downloaded_path = self._hf_client.hf_hub_download(
                repo_id=self._repository_id,
                filename=path_in_repo,
                repo_type=self._repo_type,
                revision=self._repo_revision,
                local_dir=temp_dir,
            )
            with open(downloaded_path, "rb") as downloaded:
                return downloaded.read()

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads an object to the HuggingFace repository.

        :param path: The path where the object will be stored in the repository.
        :param body: The content of the object to store.
        :param if_match: Optional ETag for conditional uploads (not supported by HuggingFace).
        :param if_none_match: Optional ETag for conditional uploads (not supported by HuggingFace).
        :param attributes: Optional attributes for the object (not supported by HuggingFace).
        :return: Data size in bytes.
        :raises RuntimeError: If HuggingFace client is not initialized or API errors occur.
        :raises ValueError: If client attempts to create a directory.
        :raises ValueError: If conditional upload parameters or attributes are provided (not supported).
        :raises FileNotFoundError: If the repository or revision does not exist.
        """
        if not self._hf_client:
            raise RuntimeError("HuggingFace client not initialized")

        if if_match is not None or if_none_match is not None:
            raise ValueError(
                "HuggingFace provider does not support conditional uploads. "
                "if_match and if_none_match parameters are not supported."
            )

        if attributes is not None:
            raise ValueError(
                "HuggingFace provider does not support custom object attributes. "
                "Use commit messages or repository metadata instead."
            )

        if path.endswith("/"):
            raise ValueError(
                "HuggingFace Storage Provider does not support explicit directory creation. "
                "Directories are created implicitly when files are uploaded to paths within them."
            )

        path = self._normalize_path(path)

        try:
            self._upload_bytes(path, body)
            return len(body)
        except (RepositoryNotFoundError, RevisionNotFoundError) as e:
            raise self._repo_not_found_error() from e
        except HfHubHTTPError as e:
            raise RuntimeError(f"HuggingFace API error during upload of {path}: {e}") from e
        except Exception as e:
            raise RuntimeError(f"Unexpected error during upload of {path}: {e}") from e

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        """
        Retrieves an object from the HuggingFace repository.

        :param path: The path of the object to retrieve from the repository.
        :param byte_range: Optional byte range for partial content (not supported by HuggingFace).
        :return: The content of the retrieved object.
        :raises RuntimeError: If HuggingFace client is not initialized or API errors occur.
        :raises ValueError: If a byte range is requested (HuggingFace doesn't support range reads).
        :raises FileNotFoundError: If the file, repository or revision doesn't exist.
        """
        if not self._hf_client:
            raise RuntimeError("HuggingFace client not initialized")

        if byte_range is not None:
            raise ValueError(
                "HuggingFace provider does not support partial range reads. "
                f"Requested range: offset={byte_range.offset}, size={byte_range.size}. "
                "To read the entire file, call get_object() without the byte_range parameter."
            )

        path = self._normalize_path(path)

        try:
            return self._download_bytes(path)
        # EntryNotFoundError (missing file) subclasses HfHubHTTPError, so it must be handled here first.
        except (RepositoryNotFoundError, RevisionNotFoundError, EntryNotFoundError) as e:
            raise FileNotFoundError(f"File not found in HuggingFace repository: {path}") from e
        except HfHubHTTPError as e:
            raise RuntimeError(f"HuggingFace API error during download of {path}: {e}") from e
        except Exception as e:
            raise RuntimeError(f"Unexpected error during download of {path}: {e}") from e

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        """
        Copies an object within the HuggingFace repository using server-side copy.

        .. note::
            Copy behavior is size-dependent: files ≥10MB are copied remotely via
            metadata (LFS), while files <10MB are downloaded and re-uploaded.

        :param src_path: The source path of the object to copy.
        :param dest_path: The destination path for the copied object.
        :return: Data size in bytes (taken from the source object's metadata).
        :raises RuntimeError: If HuggingFace client is not initialized or API errors occur.
        :raises FileNotFoundError: If the source file, repository or revision doesn't exist.
        """
        if not self._hf_client:
            raise RuntimeError("HuggingFace client not initialized")

        src_path = self._normalize_path(src_path)
        dest_path = self._normalize_path(dest_path)

        # Fails with FileNotFoundError before any commit is attempted if the source is missing.
        src_object = self._get_object_metadata(src_path)

        try:
            self._hf_client.create_commit(
                repo_id=self._repository_id,
                operations=[CommitOperationCopy(src_path_in_repo=src_path, path_in_repo=dest_path)],
                commit_message=f"Copy {src_path} to {dest_path}",
                repo_type=self._repo_type,
                revision=self._repo_revision,
            )
            return src_object.content_length
        except (RepositoryNotFoundError, RevisionNotFoundError) as e:
            raise self._repo_not_found_error() from e
        except HfHubHTTPError as e:
            raise RuntimeError(f"HuggingFace API error during copy from {src_path} to {dest_path}: {e}") from e
        except Exception as e:
            raise RuntimeError(f"Unexpected error during copy from {src_path} to {dest_path}: {e}") from e

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        """
        Deletes an object from the HuggingFace repository.

        :param path: The path of the object to delete from the repository.
        :param if_match: Optional ETag for conditional deletion (not supported by HuggingFace).
        :raises RuntimeError: If HuggingFace client is not initialized or API errors occur.
        :raises ValueError: If conditional deletion parameters are provided (not supported).
        :raises FileNotFoundError: If the file, repository or revision doesn't exist.
        """
        if not self._hf_client:
            raise RuntimeError("HuggingFace client not initialized")

        if if_match is not None:
            raise ValueError(
                "HuggingFace provider does not support conditional deletion. if_match parameter is not supported."
            )

        path = self._normalize_path(path)

        try:
            self._hf_client.delete_file(
                path_in_repo=path,
                repo_id=self._repository_id,
                repo_type=self._repo_type,
                revision=self._repo_revision,
                commit_message=f"Delete {path}",
            )
        except (RepositoryNotFoundError, RevisionNotFoundError) as e:
            raise self._repo_not_found_error() from e
        # A missing file must surface as FileNotFoundError, not the generic RuntimeError below.
        except EntryNotFoundError as e:
            raise FileNotFoundError(f"File not found in HuggingFace repository: {path}") from e
        except Exception as e:
            raise RuntimeError(f"Unexpected error during deletion of {path}: {e}") from e

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        """
        Retrieves metadata for an object in the HuggingFace repository.

        :param path: The path of the object to get metadata for.
        :param strict: When True and no entry is found at ``path``, additionally check whether
            ``path`` (with a trailing slash) is a directory and return directory metadata if so.
            When False, that extra lookup is skipped.
        :return: Metadata about the object.
        :raises RuntimeError: If HuggingFace client is not initialized or API errors occur.
        :raises FileNotFoundError: If nothing exists at ``path`` (and, with ``strict=True``, it is not a
            directory either), or if the repository or revision doesn't exist.
        """
        if not self._hf_client:
            raise RuntimeError("HuggingFace client not initialized")

        path = self._normalize_path(path)

        try:
            items = self._hf_client.get_paths_info(
                repo_id=self._repository_id,
                paths=[path],
                repo_type=self._repo_type,
                revision=self._repo_revision,
                expand=True,
            )

            if not items:
                raise FileNotFoundError(f"File not found in HuggingFace repository: {path}")

            return self._to_object_metadata(items[0])
        except FileNotFoundError:
            if strict:
                dir_path = path.rstrip("/") + "/"
                if self._is_dir(dir_path):
                    return ObjectMetadata(
                        key=dir_path,
                        type="directory",
                        content_length=0,
                        last_modified=AWARE_DATETIME_MIN,
                        etag=None,
                        content_type=None,
                        storage_class=None,
                        metadata=None,
                    )
            raise
        except (RepositoryNotFoundError, RevisionNotFoundError) as e:
            raise self._repo_not_found_error() from e
        except Exception as e:
            raise RuntimeError(f"Unexpected error getting metadata for {path}: {e}") from e

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        """
        Lists objects in the HuggingFace repository whose key starts with ``path``.

        The listing is recursive: every matching file (and, when requested, every matching folder) at any
        depth below the parent folder of ``path`` is yielded.

        :param path: The key prefix to list objects under.
        :param start_after: The key to start listing after (not supported by HuggingFace).
        :param end_at: The key to end listing at (not supported by HuggingFace).
        :param include_directories: Whether to include directories in the listing.
        :param follow_symlinks: Unused by this provider.
        :return: An iterator over object metadata for objects under the specified path.
        :raises RuntimeError: If HuggingFace client is not initialized or API errors occur.
        :raises ValueError: If start_after or end_at parameters are provided (not supported).
        :raises FileNotFoundError: If the repository or revision doesn't exist.
        """
        if not self._hf_client:
            raise RuntimeError("HuggingFace client not initialized")

        if start_after is not None or end_at is not None:
            raise ValueError(
                "HuggingFace provider does not support pagination with start_after or end_at parameters. "
                "These parameters are not supported by the HuggingFace Hub API."
            )

        path = self._normalize_path(path)
        # list_repo_tree() lists a folder, so list from the parent folder of the prefix and keep only the
        # entries whose key starts with the full prefix (e.g. "data/tr" keeps "data/train/x" but not "data/test/y").
        parent_dir = os.path.dirname(path)

        try:
            repo_items = self._hf_client.list_repo_tree(
                repo_id=self._repository_id,
                path_in_repo=parent_dir,
                repo_type=self._repo_type,
                revision=self._repo_revision,
                expand=True,
                recursive=True,
            )

            for item in repo_items:
                if not item.path.startswith(path):
                    continue

                if isinstance(item, RepoFile) or (include_directories and isinstance(item, RepoFolder)):
                    yield self._to_object_metadata(item)
        except (RepositoryNotFoundError, RevisionNotFoundError) as e:
            raise self._repo_not_found_error() from e
        # if an entry wasn't found, we can effectively treat this as an empty directory since HF creates/deletes directories implicitly.
        except EntryNotFoundError:
            pass
        except HfHubHTTPError as e:
            raise RuntimeError(f"HuggingFace API error during listing of {path}: {e}") from e
        except Exception as e:
            raise RuntimeError(f"Unexpected error during listing of {path}: {e}") from e

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        """
        Uploads a file to the HuggingFace repository.

        :param remote_path: The remote path where the file will be stored in the repository.
        :param f: File path or file object to upload. Text read from a file object is encoded as UTF-8.
        :param attributes: Optional attributes for the file (not supported by HuggingFace).
        :return: Data size in bytes.
        :raises RuntimeError: If HuggingFace client is not initialized or API errors occur.
        :raises ValueError: If client attempts to create a directory.
        :raises ValueError: If custom attributes are provided (not supported).
        :raises FileNotFoundError: If the repository or revision doesn't exist.
        """
        if not self._hf_client:
            raise RuntimeError("HuggingFace client not initialized")

        if attributes is not None:
            raise ValueError(
                "HuggingFace provider does not support custom file attributes. "
                "Use commit messages or repository metadata instead."
            )

        if remote_path.endswith("/"):
            raise ValueError(
                "HuggingFace Storage Provider does not support explicit directory creation. "
                "Directories are created implicitly when files are uploaded to paths within them."
            )

        remote_path = self._normalize_path(remote_path)

        try:
            if isinstance(f, str):
                file_size = os.path.getsize(f)
                self._upload_local_file(f, remote_path)
                return file_size

            # Generic IO objects are read fully and staged on disk before upload.
            content = f.read()
            content_bytes = content.encode("utf-8") if isinstance(content, str) else content
            self._upload_bytes(remote_path, content_bytes)
            return len(content_bytes)
        except (RepositoryNotFoundError, RevisionNotFoundError) as e:
            raise self._repo_not_found_error() from e
        except HfHubHTTPError as e:
            raise RuntimeError(f"HuggingFace API error during upload of {remote_path}: {e}") from e
        except Exception as e:
            raise RuntimeError(f"Unexpected error during upload of {remote_path}: {e}") from e

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        """
        Downloads a file from the HuggingFace repository.

        :param remote_path: The remote path of the file to download from the repository.
        :param f: Local file path or file object to write to. Text file objects receive UTF-8 decoded data.
        :param metadata: Optional object metadata (not used in this implementation).
        :return: Data size in bytes.
        :raises RuntimeError: If HuggingFace client is not initialized or API errors occur.
        :raises FileNotFoundError: If the file, repository or revision doesn't exist.
        """
        if not self._hf_client:
            raise RuntimeError("HuggingFace client not initialized")

        remote_path = self._normalize_path(remote_path)

        try:
            if isinstance(f, str):
                parent_dir = os.path.dirname(f)
                if parent_dir:
                    os.makedirs(parent_dir, exist_ok=True)

                # hf_hub_download(local_dir=...) places the file at <local_dir>/<remote_path>. Downloading
                # straight into the destination folder could overwrite an unrelated local file at that
                # location and leave nested folders behind, so use a private temporary directory instead.
                # Creating it next to the destination keeps it on the same filesystem for os.replace().
                with tempfile.TemporaryDirectory(dir=parent_dir or ".") as temp_dir:
                    downloaded_path = self._hf_client.hf_hub_download(
                        repo_id=self._repository_id,
                        filename=remote_path,
                        repo_type=self._repo_type,
                        revision=self._repo_revision,
                        local_dir=temp_dir,
                    )
                    # os.replace (unlike os.rename) also overwrites an existing target on Windows.
                    os.replace(downloaded_path, f)

                return os.path.getsize(f)

            data = self._download_bytes(remote_path)
            if isinstance(f, io.TextIOBase):
                f.write(data.decode("utf-8"))
            else:
                f.write(data)
            return len(data)
        # EntryNotFoundError (missing file) subclasses HfHubHTTPError, so it must be handled here first.
        except (RepositoryNotFoundError, RevisionNotFoundError, EntryNotFoundError) as e:
            raise FileNotFoundError(f"File not found in HuggingFace repository: {remote_path}") from e
        except HfHubHTTPError as e:
            raise RuntimeError(f"HuggingFace API error during download: {e}") from e
        except Exception as e:
            raise RuntimeError(f"Unexpected error during download: {e}") from e

    def _is_dir(self, path: str) -> bool:
        """
        Helper method to check if a path is a directory.

        :param path: The path to check (leading and trailing slashes are ignored).
        :return: True if the path is the repository root or the Hub reports a folder at that path;
            False if nothing exists there or it is a file.
        :raises FileNotFoundError: If the repository or revision doesn't exist.
        :raises RuntimeError: For any other error while querying the Hub.
        """
        path = self._normalize_path(path).rstrip("/")
        if not path:
            # The root of the repo is always a directory
            return True

        try:
            path_info = self._hf_client.get_paths_info(
                repo_id=self._repository_id,
                paths=[path],
                repo_type=self._repo_type,
                revision=self._repo_revision,
            )[0]
        except (RepositoryNotFoundError, RevisionNotFoundError) as e:
            raise self._repo_not_found_error() from e
        except IndexError:
            # Empty result: nothing exists at this path.
            return False
        except Exception as e:
            raise RuntimeError(f"Unexpected error: {e}") from e

        return isinstance(path_info, RepoFolder)

    def _normalize_path(self, path: str) -> str:
        """
        Normalize path for HuggingFace API by removing leading slashes.
        HuggingFace expects relative paths within repositories.
        """
        return path.lstrip("/")