Source code for multistorageclient.contrib.zarr

  1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  2# SPDX-License-Identifier: Apache-2.0
  3#
  4# Licensed under the Apache License, Version 2.0 (the "License");
  5# you may not use this file except in compliance with the License.
  6# You may obtain a copy of the License at
  7#
  8# http://www.apache.org/licenses/LICENSE-2.0
  9#
 10# Unless required by applicable law or agreed to in writing, software
 11# distributed under the License is distributed on an "AS IS" BASIS,
 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13# See the License for the specific language governing permissions and
 14# limitations under the License.
 15
 16import os
 17from concurrent.futures import ThreadPoolExecutor, as_completed
 18from typing import TYPE_CHECKING, Any, Iterator, Mapping, Optional, Sequence, Tuple
 19
 20import zarr as _zarr
 21from zarr.storage import BaseStore
 22
 23from ..shortcuts import resolve_storage_client
 24from ..types import MSC_PROTOCOL
 25import numpy as np
 26
 27if TYPE_CHECKING:
 28    from ..client import StorageClient
 29
 30
def open_consolidated(*args: Any, **kwargs: Any) -> _zarr.Group:
    """
    Adapt ``zarr.open_consolidated`` to use :py:class:`LazyZarrStore` when path matches the ``msc`` protocol.

    If the path starts with the MSC protocol, it uses :py:class:`LazyZarrStore` with a resolved
    storage client and prefix, passing ``msc_max_workers`` if provided. Otherwise, it
    directly calls ``zarr.open_consolidated``.

    The path may be supplied either as the first positional argument or as the ``store``
    keyword argument. The :py:class:`LazyZarrStore` is put back in whichever position the
    caller used.

    :param args: Positional arguments forwarded to ``zarr.open_consolidated``.
    :param kwargs: Keyword arguments forwarded to ``zarr.open_consolidated``. The MSC-only
        ``msc_max_workers`` option is removed before forwarding.
    :return: The object returned by ``zarr.open_consolidated``.
    """
    args_list = list(args)
    path = args_list[0] if args_list else kwargs.get("store")
    # Always strip the MSC-only option so it is never forwarded to zarr, even for non-MSC paths.
    msc_max_workers = kwargs.pop("msc_max_workers", None)
    if isinstance(path, str) and path.startswith(MSC_PROTOCOL):
        storage_client, prefix = resolve_storage_client(path)
        zarr_store = LazyZarrStore(storage_client, prefix=prefix, msc_max_workers=msc_max_workers)
        # Replace the path where it came from. Branching on ``args_list`` (instead of comparing
        # with ``args_list[0]``) avoids an IndexError when the path was passed as ``store=``.
        if args_list:
            args_list[0] = zarr_store
        else:
            kwargs["store"] = zarr_store
        return _zarr.open_consolidated(*args_list, **kwargs)  # pyright: ignore [reportReturnType]
    return _zarr.open_consolidated(*args, **kwargs)  # pyright: ignore [reportReturnType]
# pyright: reportIncompatibleMethodOverride=false
class LazyZarrStore(BaseStore):
    """
    A zarr ``BaseStore`` that delegates all I/O to a :py:class:`StorageClient`.

    Each store key is turned into an object key by prepending :py:attr:`prefix` verbatim.
    Reads, writes, deletes, existence checks and listings are forwarded to the storage
    client. :py:meth:`getitems` reads multiple keys concurrently on a thread pool.
    """

    def __init__(
        self, storage_client: "StorageClient", prefix: str = "", msc_max_workers: Optional[int] = None
    ) -> None:
        """
        :param storage_client: Client used for all object I/O.
        :param prefix: String prepended to every store key to form the object key.
        :param msc_max_workers: Thread-pool size used by :py:meth:`getitems`. When falsy, the
            ``MSC_MAX_WORKERS`` environment variable is used, defaulting to ``8``.
        """
        self.storage_client = storage_client
        self.prefix = prefix
        self.max_workers = msc_max_workers or int(os.getenv("MSC_MAX_WORKERS", "8"))

    def __getitem__(self, key: str) -> Any:
        """
        Read the object stored under ``key``.

        :raises KeyError: If the storage client reports the object as missing.
        """
        full_key = self.prefix + key
        try:
            return self.storage_client.read(full_key)
        except FileNotFoundError as err:
            # Mapping contract: zarr catches KeyError to detect absent chunks/metadata and
            # fall back to fill values or defaults; other exception types would escape.
            # NOTE(review): assumes the storage client signals a missing object with
            # FileNotFoundError — confirm against the provider implementations.
            raise KeyError(key) from err

    def getitems(self, keys: Sequence[str], *, contexts: Any) -> Mapping[str, Any]:
        """
        Read several keys concurrently.

        Keys that are missing from the store are omitted from the result rather than raising.
        This matches the ``BaseStore.getitems`` contract, so zarr can substitute the fill
        value for chunks that were never written.

        :param keys: Store keys to read.
        :param contexts: Per-key context supplied by zarr; not used by this store.
        :return: Mapping from each present key to its value.
        """
        missing = object()  # sentinel distinguishing "absent" from any stored value

        def get_item(key: str) -> Tuple[str, Any]:
            try:
                return key, self[key]
            except KeyError:
                return key, missing

        results = {}
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            futures = [executor.submit(get_item, key) for key in keys]
            for future in as_completed(futures):
                key, value = future.result()
                if value is not missing:
                    results[key] = value
        return results

    def __setitem__(self, key: str, value: Any) -> None:
        """Write ``value`` under ``key``; NumPy arrays are serialized with ``ndarray.tobytes()``."""
        full_key = self.prefix + key
        if isinstance(value, np.ndarray):
            value = value.tobytes()
        self.storage_client.write(full_key, value)

    def __delitem__(self, key: str) -> None:
        """
        Delete the object stored under ``key``.

        :raises KeyError: If the storage client reports the object as missing.
        """
        full_key = self.prefix + key
        try:
            self.storage_client.delete(full_key)
        except FileNotFoundError as err:
            # Same translation as __getitem__ so callers can rely on the mapping contract.
            # NOTE(review): assumes FileNotFoundError signals a missing object — confirm.
            raise KeyError(key) from err

    def __contains__(self, key: str) -> bool:
        """
        Return whether an object exists for ``key``.

        Any exception raised by ``storage_client.info`` is treated as "not present".
        """
        full_key = self.prefix + key
        try:
            self.storage_client.info(full_key)
        except Exception:
            # Deliberately broad best-effort check: any failure to stat means "absent".
            return False
        return True

    def keys(self) -> Iterator[str]:
        """Yield every store key listed under :py:attr:`prefix`, with the prefix stripped."""
        for obj in self.storage_client.list(self.prefix):
            yield obj.key.removeprefix(self.prefix)

    def __iter__(self) -> Iterator[str]:
        return iter(self.keys())

    def __len__(self) -> int:
        """Count keys by iterating the full listing under :py:attr:`prefix`."""
        return sum(1 for _ in self.keys())