Source code for multistorageclient.contrib.zarr

  1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  2# SPDX-License-Identifier: Apache-2.0
  3#
  4# Licensed under the Apache License, Version 2.0 (the "License");
  5# you may not use this file except in compliance with the License.
  6# You may obtain a copy of the License at
  7#
  8# http://www.apache.org/licenses/LICENSE-2.0
  9#
 10# Unless required by applicable law or agreed to in writing, software
 11# distributed under the License is distributed on an "AS IS" BASIS,
 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13# See the License for the specific language governing permissions and
 14# limitations under the License.
 15
 16import os
 17from concurrent.futures import ThreadPoolExecutor, as_completed
 18from typing import TYPE_CHECKING, Any, Iterator, Mapping, Optional, Sequence, Tuple
 19
 20import zarr as _zarr
 21from zarr.storage import BaseStore
 22
 23from ..shortcuts import resolve_storage_client
 24from ..types import MSC_PROTOCOL
 25
 26if TYPE_CHECKING:
 27    from ..client import StorageClient
 28
 29
def open_consolidated(*args: Any, **kwargs: Any) -> _zarr.Group:
    """
    Adapt ``zarr.open_consolidated`` to use :py:class:`LazyZarrStore` when path matches the ``msc`` protocol.

    The store is taken from the first positional argument if there is one, otherwise from
    the ``store`` keyword argument. If it is a string starting with the MSC protocol, it is
    replaced by a :py:class:`LazyZarrStore` built from the resolved storage client and prefix.
    Otherwise the call is forwarded to ``zarr.open_consolidated`` as is.

    :param args: Positional arguments for ``zarr.open_consolidated``.
    :param kwargs: Keyword arguments for ``zarr.open_consolidated``. ``msc_max_workers`` is
        always removed before forwarding and is passed to :py:class:`LazyZarrStore` instead.
    :return: Whatever ``zarr.open_consolidated`` returns.
    """
    # Always strip the MSC-only option so zarr never sees an unknown keyword.
    msc_max_workers = kwargs.pop("msc_max_workers", None)

    path = args[0] if args else kwargs.get("store")
    if not (isinstance(path, str) and path.startswith(MSC_PROTOCOL)):
        return _zarr.open_consolidated(*args, **kwargs)  # pyright: ignore [reportReturnType]

    storage_client, prefix = resolve_storage_client(path)
    zarr_store = LazyZarrStore(storage_client, prefix=prefix, msc_max_workers=msc_max_workers)

    # Put the store back where the caller supplied the path. Branching on ``args``
    # (rather than indexing it) keeps the keyword-only call ``store="msc://..."``
    # from failing with IndexError on an empty positional list.
    if args:
        return _zarr.open_consolidated(zarr_store, *args[1:], **kwargs)  # pyright: ignore [reportReturnType]
    kwargs["store"] = zarr_store
    return _zarr.open_consolidated(**kwargs)  # pyright: ignore [reportReturnType]
50 51 52# pyright: reportIncompatibleMethodOverride=false
class LazyZarrStore(BaseStore):
    """
    Zarr store that reads and writes through a :py:class:`StorageClient`.

    Every store key is mapped to a backend key by plain string concatenation
    ``prefix + key``; no path separator is inserted.
    """

    def __init__(
        self, storage_client: "StorageClient", prefix: str = "", msc_max_workers: Optional[int] = None
    ) -> None:
        """
        :param storage_client: Client used for all reads, writes, deletes and listings.
        :param prefix: String prepended verbatim to every key.
        :param msc_max_workers: Thread count for :py:meth:`getitems`. When falsy, the
            ``MSC_MAX_WORKERS`` environment variable is used, defaulting to 8.
        """
        self.storage_client = storage_client
        self.prefix = prefix
        self.max_workers = msc_max_workers or int(os.getenv("MSC_MAX_WORKERS", "8"))

    def __getitem__(self, key: str) -> Any:
        """
        Return the value the storage client reads for ``key``.

        :raises KeyError: If the backend reports the object as missing.
        """
        full_key = self.prefix + key
        try:
            return self.storage_client.read(full_key)
        except FileNotFoundError as err:
            # Zarr treats KeyError as "chunk not initialised" and substitutes the
            # fill value; any other exception type aborts the read.
            # NOTE(review): assumes the storage client signals a missing object
            # with FileNotFoundError -- confirm against the providers in use.
            raise KeyError(key) from err

    def getitems(self, keys: Sequence[str], *, contexts: Any) -> Mapping[str, Any]:
        """
        Read several keys concurrently on a thread pool.

        :param keys: Keys to fetch.
        :param contexts: Accepted for interface compatibility; not used.
        :return: Mapping from each key that was found to its value. Keys that raise
            ``KeyError`` are left out, matching ``BaseStore.getitems``.
        """

        def fetch(key: str) -> Tuple[str, Any]:
            return key, self[key]

        found = {}
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            pending = [executor.submit(fetch, key) for key in keys]
            for future in as_completed(pending):
                try:
                    key, value = future.result()
                except KeyError:
                    # Missing key: omit it so one absent chunk does not fail the batch.
                    continue
                found[key] = value
        return found

    def __setitem__(self, key: str, value: Any) -> None:
        """Write ``value`` under ``key`` via the storage client."""
        full_key = self.prefix + key
        self.storage_client.write(full_key, value)

    def __delitem__(self, key: str) -> None:
        """Delete ``key`` via the storage client."""
        full_key = self.prefix + key
        self.storage_client.delete(full_key)

    def __contains__(self, key: str) -> bool:
        """Return ``True`` if the storage client's ``info`` call for ``key`` succeeds."""
        full_key = self.prefix + key
        try:
            self.storage_client.info(full_key)
            return True
        except Exception:
            # Deliberately broad: any failure of the metadata lookup (not only
            # "not found") is reported as the key being absent.
            return False

    def keys(self) -> Iterator[str]:
        """Yield every key the storage client lists under the prefix, with the prefix removed."""
        for obj in self.storage_client.list(self.prefix):
            yield obj.key.removeprefix(self.prefix)

    def __iter__(self) -> Iterator[str]:
        """Iterate over the store's keys."""
        return iter(self.keys())

    def __len__(self) -> int:
        """Return the number of keys; this walks the full listing."""
        return sum(1 for _ in self.keys())