Source code for multistorageclient.contrib.zarr

  1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  2# SPDX-License-Identifier: Apache-2.0
  3#
  4# Licensed under the Apache License, Version 2.0 (the "License");
  5# you may not use this file except in compliance with the License.
  6# You may obtain a copy of the License at
  7#
  8# http://www.apache.org/licenses/LICENSE-2.0
  9#
 10# Unless required by applicable law or agreed to in writing, software
 11# distributed under the License is distributed on an "AS IS" BASIS,
 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13# See the License for the specific language governing permissions and
 14# limitations under the License.
 15
 16import os
 17from collections.abc import Iterator, Mapping, Sequence
 18from concurrent.futures import ThreadPoolExecutor, as_completed
 19from typing import TYPE_CHECKING, Any, Optional
 20
 21import numpy as np
 22import zarr as _zarr
 23from zarr.storage import BaseStore
 24
 25from ..shortcuts import resolve_storage_client
 26from ..types import MSC_PROTOCOL
 27
 28if TYPE_CHECKING:
 29    from ..client import StorageClient
 30
 31
def open_consolidated(*args: Any, **kwargs: Any) -> _zarr.Group:
    """
    Adapt ``zarr.open_consolidated`` to use :py:class:`LazyZarrStore` when path matches the ``msc`` protocol.

    The store is taken from the first positional argument if there is one, otherwise from the
    ``store`` keyword. If it is a string starting with the MSC protocol, it is replaced by a
    :py:class:`LazyZarrStore` built from the resolved storage client and prefix, passing
    ``msc_max_workers`` if provided. Otherwise the call is forwarded to ``zarr.open_consolidated``
    as is.

    :param args: Positional arguments for ``zarr.open_consolidated``; the first one, when present,
        is treated as the store.
    :param kwargs: Keyword arguments for ``zarr.open_consolidated``. ``msc_max_workers`` is consumed
        here and is never forwarded.
    :return: The result of ``zarr.open_consolidated``.
    """
    # Always strip the MSC-only option so it never reaches zarr, even for non-MSC stores.
    msc_max_workers = kwargs.pop("msc_max_workers", None)

    positional = list(args)
    store = positional[0] if positional else kwargs.get("store")

    if not (isinstance(store, str) and store.startswith(MSC_PROTOCOL)):
        return _zarr.open_consolidated(*args, **kwargs)  # pyright: ignore [reportReturnType]

    storage_client, prefix = resolve_storage_client(store)
    zarr_store = LazyZarrStore(storage_client, prefix=prefix, msc_max_workers=msc_max_workers)

    # Put the store back where the caller supplied it. Testing for positional arguments
    # (rather than indexing the list) keeps the ``store=...`` keyword form from raising
    # IndexError on an empty argument list.
    if positional:
        positional[0] = zarr_store
    else:
        kwargs["store"] = zarr_store
    return _zarr.open_consolidated(*positional, **kwargs)  # pyright: ignore [reportReturnType]
# pyright: reportIncompatibleMethodOverride=false
class LazyZarrStore(BaseStore):
    """
    Zarr store backed by a storage client.

    Every store key is mapped to a storage path by appending it directly to ``prefix``;
    no separator is inserted between the two, so ``prefix`` must already end with whatever
    separator the caller wants in the resulting path.

    Reads, writes, deletes, existence checks and listings are all delegated to the
    storage client on each call.
    """

    def __init__(
        self, storage_client: "StorageClient", prefix: str = "", msc_max_workers: Optional[int] = None
    ) -> None:
        """
        :param storage_client: Client used for every read, write, delete, info and list call.
        :param prefix: String prepended verbatim to each key.
        :param msc_max_workers: Thread count for :py:meth:`getitems`. A falsy value (``None`` or ``0``)
            falls back to the ``MSC_MAX_WORKERS`` environment variable, and to ``8`` if that is unset.
        """
        self.storage_client = storage_client
        self.prefix = prefix
        self.max_workers = msc_max_workers or int(os.getenv("MSC_MAX_WORKERS", "8"))

    def _full_key(self, key: str) -> str:
        """Return the storage path for a store key (plain concatenation with the prefix)."""
        return self.prefix + key

    def __getitem__(self, key: str) -> Any:
        """Return whatever the storage client reads at ``prefix + key``."""
        # NOTE(review): errors from storage_client.read propagate unchanged. Mapping-style
        # stores conventionally signal a missing key with KeyError -- confirm what the
        # client raises for a missing object and whether zarr needs it translated.
        return self.storage_client.read(self._full_key(key))

    def getitems(self, keys: Sequence[str], *, contexts: Any = None) -> Mapping[str, Any]:
        """
        Read several keys concurrently.

        :param keys: Store keys to read.
        :param contexts: Accepted for signature compatibility; not used by this implementation.
        :return: Mapping from each requested key to its value.
        :raises Exception: Whatever :py:meth:`__getitem__` raised for a key, re-raised when
            that key's result is collected.
        """
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            pending = {executor.submit(self.__getitem__, key): key for key in keys}
            # Collect in completion order; result() re-raises any read failure.
            return {pending[done]: done.result() for done in as_completed(pending)}

    def __setitem__(self, key: str, value: Any) -> None:
        """Write ``value`` at ``prefix + key``; NumPy arrays are converted to bytes first."""
        if isinstance(value, np.ndarray):
            value = value.tobytes()
        self.storage_client.write(self._full_key(key), value)

    def __delitem__(self, key: str) -> None:
        """Delete the object at ``prefix + key``."""
        self.storage_client.delete(self._full_key(key))

    def __contains__(self, key: str) -> bool:
        """Return ``True`` if the storage client's ``info`` call succeeds for ``prefix + key``."""
        try:
            self.storage_client.info(self._full_key(key))
        except Exception:
            # Any failure of the info call is reported as "not present".
            return False
        return True

    def keys(self) -> Iterator[str]:
        """Yield the key of every object listed under the prefix, with the prefix removed."""
        for entry in self.storage_client.list(self.prefix):
            yield entry.key.removeprefix(self.prefix)

    def __iter__(self) -> Iterator[str]:
        """Iterate over the same keys as :py:meth:`keys`."""
        return iter(self.keys())

    def __len__(self) -> int:
        """Count the keys by running a full listing."""
        return sum(1 for _ in self.keys())