diff --git a/changes/3644.feature.md b/changes/3644.feature.md new file mode 100644 index 0000000000..196676d7aa --- /dev/null +++ b/changes/3644.feature.md @@ -0,0 +1,5 @@ +The `Store.get` and `Store.get_partial_values` methods now accept `None` as the `prototype` argument. When `prototype` is `None`, stores will use their default buffer class (typically `zarr.core.buffer.cpu.Buffer`). This simplifies the API for common use cases where the default buffer is sufficient. + +A new type alias `BufferClassLike` has been added, which matches either a `Buffer` class or a `BufferPrototype` instance. + +**Breaking change for third-party store implementations:** If you have implemented a custom `Store` subclass, you must update your `get` and `get_partial_values` methods to handle `prototype=None`. To do this, override the `_get_default_buffer_class` method to return an appropriate default `Buffer` class, and update your method signatures to accept `BufferClassLike | None` instead of `BufferPrototype`. When `prototype` is `None`, call `self._get_default_buffer_class()` to obtain the buffer class. If `prototype` is a `BufferPrototype` instance, extract the buffer class via `prototype.buffer`. diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index 87df89a683..d58d36e0f4 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -7,16 +7,19 @@ from itertools import starmap from typing import TYPE_CHECKING, Literal, Protocol, runtime_checkable +from zarr.core.buffer import Buffer, BufferPrototype from zarr.core.sync import sync +from zarr.registry import get_buffer_class if TYPE_CHECKING: from collections.abc import AsyncGenerator, AsyncIterator, Iterable from types import TracebackType from typing import Any, Self, TypeAlias - from zarr.core.buffer import Buffer, BufferPrototype +__all__ = ["BufferClassLike", "ByteGetter", "ByteSetter", "Store", "set_or_delete"] -__all__ = ["ByteGetter", "ByteSetter", "Store", "set_or_delete"] +BufferClassLike = type[Buffer] | BufferPrototype +"""An object that is, or contains, a `Buffer` class.""" @dataclass @@ -183,11 +186,17 @@ def __eq__(self, value: object) -> bool: """Equality comparison.""" ... + def _get_default_buffer_class(self) -> type[Buffer]: + """ + Get the default buffer class. + """ + return get_buffer_class() + @abstractmethod async def get( self, key: str, - prototype: BufferPrototype, + prototype: BufferClassLike | None = None, byte_range: ByteRequest | None = None, ) -> Buffer | None: """Retrieve the value associated with a given key. @@ -195,8 +204,12 @@ async def get( Parameters ---------- key : str - prototype : BufferPrototype - The prototype of the output buffer. Stores may support a default buffer prototype. + prototype : BufferClassLike | None, optional + The prototype of the output buffer. + Can be either a Buffer class or an instance of `BufferPrototype`, in which case the + `buffer` attribute will be used. + If `None`, the default buffer class for this store will be retrieved via the + ``_get_default_buffer_class`` method. byte_range : ByteRequest, optional ByteRequest may be one of the following. If not provided, all data associated with the key is retrieved. - RangeByteRequest(int, int): Request a specific range of bytes in the form (start, end). The end is exclusive. If the given range is zero-length or starts after the end of the object, an error will be returned. Additionally, if the range ends after the end of the object, the entire remainder of the object will be returned. Otherwise, the exact requested range will be returned.
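For third-party store authors, the migration described in the changelog amounts to resolving the optional `prototype` into a concrete `Buffer` class before constructing any buffers. A minimal sketch of the pattern follows; `MyCustomStore` and its `_read_raw` I/O helper are hypothetical, the remaining abstract `Store` methods are omitted, and the resolution logic mirrors the three cases the changelog lists:

```python
from zarr.abc.store import BufferClassLike, ByteRequest, Store
from zarr.core.buffer import Buffer, BufferPrototype


class MyCustomStore(Store):
    """Sketch of a third-party store migration; only ``get`` is shown."""

    async def get(
        self,
        key: str,
        prototype: BufferClassLike | None = None,
        byte_range: ByteRequest | None = None,
    ) -> Buffer | None:
        # Resolve the optional prototype into a concrete Buffer class.
        if prototype is None:
            buffer_cls = self._get_default_buffer_class()
        elif isinstance(prototype, BufferPrototype):
            buffer_cls = prototype.buffer
        else:
            buffer_cls = prototype
        raw = await self._read_raw(key, byte_range)  # hypothetical backend I/O helper
        return None if raw is None else buffer_cls.from_bytes(raw)
```

`get_partial_values` would apply the same resolution once, before iterating over its key ranges.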
@@ -210,7 +223,11 @@ async def get( ... async def _get_bytes( - self, key: str, *, prototype: BufferPrototype, byte_range: ByteRequest | None = None + self, + key: str, + *, + prototype: BufferClassLike | None = None, + byte_range: ByteRequest | None = None, ) -> bytes: """ Retrieve raw bytes from the store asynchronously. @@ -222,8 +239,12 @@ async def _get_bytes( ---------- key : str The key identifying the data to retrieve. - prototype : BufferPrototype - The buffer prototype to use for reading the data. + prototype : BufferClassLike | None, optional + The prototype of the output buffer. + Can be either a Buffer class or an instance of `BufferPrototype`, in which case the + `buffer` attribute will be used. + If `None`, the default buffer class for this store will be retrieved via the + ``_get_default_buffer_class`` method. byte_range : ByteRequest, optional If specified, only retrieve a portion of the stored data. Can be a ``RangeByteRequest``, ``OffsetByteRequest``, or ``SuffixByteRequest``. @@ -248,7 +269,7 @@ async def _get_bytes( -------- >>> store = await MemoryStore.open() >>> await store.set("data", Buffer.from_bytes(b"hello world")) - >>> data = await store.get_bytes("data", prototype=default_buffer_prototype()) + >>> data = await store._get_bytes("data") >>> print(data) b'hello world' """ @@ -258,7 +279,11 @@ async def _get_bytes( return buffer.to_bytes() def _get_bytes_sync( - self, key: str = "", *, prototype: BufferPrototype, byte_range: ByteRequest | None = None + self, + key: str = "", + *, + prototype: BufferClassLike | None = None, + byte_range: ByteRequest | None = None, ) -> bytes: """ Retrieve raw bytes from the store synchronously. @@ -271,8 +296,12 @@ def _get_bytes_sync( ---------- key : str, optional The key identifying the data to retrieve. Defaults to an empty string. - prototype : BufferPrototype - The buffer prototype to use for reading the data. + prototype : BufferClassLike | None, optional + The prototype of the output buffer. + Can be either a Buffer class or an instance of `BufferPrototype`, in which case the + `buffer` attribute will be used. + If `None`, the default buffer class for this store will be retrieved via the + ``_get_default_buffer_class`` method. byte_range : ByteRequest, optional If specified, only retrieve a portion of the stored data. Can be a ``RangeByteRequest``, ``OffsetByteRequest``, or ``SuffixByteRequest``. @@ -301,7 +330,7 @@ def _get_bytes_sync( -------- >>> store = MemoryStore() >>> await store.set("data", Buffer.from_bytes(b"hello world")) - >>> data = store.get_bytes_sync("data", prototype=default_buffer_prototype()) + >>> data = store._get_bytes_sync("data") >>> print(data) b'hello world' """ @@ -309,7 +338,11 @@ def _get_bytes_sync( return sync(self._get_bytes(key, prototype=prototype, byte_range=byte_range)) async def _get_json( - self, key: str, *, prototype: BufferPrototype, byte_range: ByteRequest | None = None + self, + key: str, + *, + prototype: BufferClassLike | None = None, + byte_range: ByteRequest | None = None, ) -> Any: """ Retrieve and parse JSON data from the store asynchronously. @@ -321,8 +354,12 @@ async def _get_json( ---------- key : str The key identifying the JSON data to retrieve. - prototype : BufferPrototype - The buffer prototype to use for reading the data. + prototype : BufferClassLike | None, optional + The prototype of the output buffer. + Can be either a Buffer class or an instance of `BufferPrototype`, in which case the + `buffer` attribute will be used.
+ If `None`, the default buffer class for this store will be retrieved via the + ``_get_default_buffer_class`` method. byte_range : ByteRequest, optional If specified, only retrieve a portion of the stored data. Can be a ``RangeByteRequest``, ``OffsetByteRequest``, or ``SuffixByteRequest``. @@ -351,7 +388,7 @@ async def _get_json( >>> store = await MemoryStore.open() >>> metadata = {"zarr_format": 3, "node_type": "array"} >>> await store.set("zarr.json", Buffer.from_bytes(json.dumps(metadata).encode())) - >>> data = await store.get_json("zarr.json", prototype=default_buffer_prototype()) + >>> data = await store._get_json("zarr.json") >>> print(data) {'zarr_format': 3, 'node_type': 'array'} """ @@ -359,7 +396,11 @@ async def _get_json( return json.loads(await self._get_bytes(key, prototype=prototype, byte_range=byte_range)) def _get_json_sync( - self, key: str = "", *, prototype: BufferPrototype, byte_range: ByteRequest | None = None + self, + key: str = "", + *, + prototype: BufferClassLike | None = None, + byte_range: ByteRequest | None = None, ) -> Any: """ Retrieve and parse JSON data from the store synchronously. @@ -372,8 +413,12 @@ def _get_json_sync( ---------- key : str, optional The key identifying the JSON data to retrieve. Defaults to an empty string. - prototype : BufferPrototype - The buffer prototype to use for reading the data. + prototype : BufferClassLike | None, optional + The prototype of the output buffer. + Can be either a Buffer class or an instance of `BufferPrototype`, in which case the + `buffer` attribute will be used. + If `None`, the default buffer class for this store will be retrieved via the + ``_get_default_buffer_class`` method. byte_range : ByteRequest, optional If specified, only retrieve a portion of the stored data. Can be a ``RangeByteRequest``, ``OffsetByteRequest``, or ``SuffixByteRequest``. @@ -407,7 +452,7 @@ def _get_json_sync( >>> store = MemoryStore() >>> metadata = {"zarr_format": 3, "node_type": "array"} >>> store.set("zarr.json", Buffer.from_bytes(json.dumps(metadata).encode())) - >>> data = store.get_json_sync("zarr.json", prototype=default_buffer_prototype()) + >>> data = store._get_json_sync("zarr.json") >>> print(data) {'zarr_format': 3, 'node_type': 'array'} """ @@ -417,15 +462,19 @@ def _get_json_sync( @abstractmethod async def get_partial_values( self, - prototype: BufferPrototype, + prototype: BufferClassLike | None, key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: """Retrieve possibly partial values from given key_ranges. Parameters ---------- - prototype : BufferPrototype - The prototype of the output buffer. Stores may support a default buffer prototype. + prototype : BufferClassLike | None + The prototype of the output buffer. + Can be either a Buffer class or an instance of `BufferPrototype`, in which case the + `buffer` attribute will be used. + If `None`, the default buffer class for this store will be retrieved via the + ``_get_default_buffer_class`` method. key_ranges : Iterable[tuple[str, tuple[int | None, int | None]]] Ordered set of key, range pairs, a key may occur multiple times with different ranges @@ -597,7 +646,7 @@ def close(self) -> None: self._is_open = False async def _get_many( - self, requests: Iterable[tuple[str, BufferPrototype, ByteRequest | None]] + self, requests: Iterable[tuple[str, BufferClassLike | None, ByteRequest | None]] ) -> AsyncGenerator[tuple[str, Buffer | None], None]: """ Retrieve a collection of objects from storage.
In general this method does not guarantee @@ -628,10 +677,8 @@ async def getsize(self, key: str) -> int: # Note to implementers: this default implementation is very inefficient since # it requires reading the entire object. Many systems will have ways to get the # size of an object without reading it. - # avoid circular import - from zarr.core.buffer.core import default_buffer_prototype - value = await self.get(key, prototype=default_buffer_prototype()) + value = await self.get(key) if value is None: raise FileNotFoundError(key) return len(value) diff --git a/src/zarr/core/_info.py b/src/zarr/core/_info.py index fef424346a..0d8ea8fe70 100644 --- a/src/zarr/core/_info.py +++ b/src/zarr/core/_info.py @@ -69,7 +69,7 @@ def byte_info(size: int) -> str: @dataclasses.dataclass(kw_only=True, frozen=True, slots=True) -class ArrayInfo: +class ArrayInfo: # type: ignore[misc] """ Visual summary for an Array. diff --git a/src/zarr/core/common.py b/src/zarr/core/common.py index d38949657e..9a2bc2d2f3 100644 --- a/src/zarr/core/common.py +++ b/src/zarr/core/common.py @@ -23,12 +23,15 @@ from typing_extensions import ReadOnly +from zarr.core.buffer import Buffer, BufferPrototype from zarr.core.config import config as zarr_config from zarr.errors import ZarrRuntimeWarning if TYPE_CHECKING: from collections.abc import Awaitable, Callable, Iterator + from zarr.abc.store import BufferClassLike + ZARR_JSON = "zarr.json" ZARRAY_JSON = ".zarray" @@ -246,3 +249,17 @@ def _warn_order_kwarg() -> None: def _default_zarr_format() -> ZarrFormat: """Return the default zarr_version""" return cast("ZarrFormat", int(zarr_config.get("default_zarr_format", 3))) + + +def parse_bufferclasslike(obj: BufferClassLike | None) -> type[Buffer]: + """ + Take an optional BufferClassLike and return a Buffer class + """ + # Avoid a circular import. Temporary fix until we re-organize modules appropriately. + from zarr.registry import get_buffer_class + + if obj is None: + return get_buffer_class() + if isinstance(obj, BufferPrototype): + return obj.buffer + return obj diff --git a/src/zarr/core/dtype/common.py b/src/zarr/core/dtype/common.py index 6b70f595ba..c1590e7c46 100644 --- a/src/zarr/core/dtype/common.py +++ b/src/zarr/core/dtype/common.py @@ -125,7 +125,7 @@ def check_dtype_spec_v2(data: object) -> TypeGuard[DTypeSpec_V2]: DTypeSpec_V3 = str | NamedConfig[str, Mapping[str, object]] -def check_dtype_spec_v3(data: object) -> TypeGuard[DTypeSpec_V3]: +def check_dtype_spec_v3(data: object) -> TypeGuard[DTypeSpec_V3]: # type: ignore[valid-type] """ Type guard for narrowing the type of a python object to an instance of DTypeSpec_V3, i.e either a string or a dict with a "name" field that's a string and a @@ -141,7 +141,7 @@ def check_dtype_spec_v3(data: object) -> TypeGuard[DTypeSpec_V3]: return False -def unpack_dtype_json(data: DTypeSpec_V2 | DTypeSpec_V3) -> DTypeJSON: +def unpack_dtype_json(data: DTypeSpec_V2 | DTypeSpec_V3) -> DTypeJSON: # type: ignore[valid-type] """ Return the array metadata form of the dtype JSON representation. For the Zarr V3 form of dtype metadata, this is a no-op. For the Zarr V2 form of dtype metadata, this unpacks the dtype name. 
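The new `parse_bufferclasslike` helper above is the single place where the three accepted spellings of `prototype` are normalized. A quick sketch of its expected behavior, assuming the buffer registry is left at its default (CPU) configuration:

```python
from zarr.core.buffer import cpu, default_buffer_prototype
from zarr.core.common import parse_bufferclasslike

# None falls back to the registry default (cpu.Buffer unless reconfigured).
assert parse_bufferclasslike(None) is cpu.Buffer
# A BufferPrototype is unwrapped via its `buffer` attribute.
assert parse_bufferclasslike(default_buffer_prototype()) is cpu.Buffer
# A Buffer class passes through unchanged.
assert parse_bufferclasslike(cpu.Buffer) is cpu.Buffer
```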
diff --git a/src/zarr/core/dtype/npy/structured.py b/src/zarr/core/dtype/npy/structured.py index 8bedee07ef..09c9c5b135 100644 --- a/src/zarr/core/dtype/npy/structured.py +++ b/src/zarr/core/dtype/npy/structured.py @@ -317,7 +317,7 @@ def to_json(self, zarr_format: ZarrFormat) -> StructuredJSON_V2 | StructuredJSON elif zarr_format == 3: v3_unstable_dtype_warning(self) fields = [ - [f_name, f_dtype.to_json(zarr_format=zarr_format)] # type: ignore[list-item] + [f_name, f_dtype.to_json(zarr_format=zarr_format)] for f_name, f_dtype in self.fields ] base_dict = { diff --git a/src/zarr/core/dtype/wrapper.py b/src/zarr/core/dtype/wrapper.py index fdc5f747f0..9fe41f8119 100644 --- a/src/zarr/core/dtype/wrapper.py +++ b/src/zarr/core/dtype/wrapper.py @@ -57,7 +57,7 @@ @dataclass(frozen=True, kw_only=True, slots=True) -class ZDType(ABC, Generic[TDType_co, TScalar_co]): +class ZDType(ABC, Generic[TDType_co, TScalar_co]): # type: ignore[misc] """ Abstract base class for wrapping native array data types, e.g. numpy dtypes @@ -169,10 +169,10 @@ def from_json(cls: type[Self], data: DTypeJSON, *, zarr_format: ZarrFormat) -> S def to_json(self, zarr_format: Literal[2]) -> DTypeSpec_V2: ... @overload - def to_json(self, zarr_format: Literal[3]) -> DTypeSpec_V3: ... + def to_json(self, zarr_format: Literal[3]) -> DTypeSpec_V3: ... # type: ignore[valid-type] @abstractmethod - def to_json(self, zarr_format: ZarrFormat) -> DTypeSpec_V2 | DTypeSpec_V3: + def to_json(self, zarr_format: ZarrFormat) -> DTypeSpec_V2 | DTypeSpec_V3: # type: ignore[valid-type] """ Serialize this ZDType to JSON. diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 9b5fee275b..fadb700fd8 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -17,7 +17,7 @@ import zarr.api.asynchronous as async_api from zarr.abc.metadata import Metadata -from zarr.abc.store import Store, set_or_delete +from zarr.abc.store import BufferClassLike, Store, set_or_delete from zarr.core._info import GroupInfo from zarr.core.array import ( DEFAULT_FILL_VALUE, @@ -32,7 +32,6 @@ create_array, ) from zarr.core.attributes import Attributes -from zarr.core.buffer import default_buffer_prototype from zarr.core.common import ( JSON, ZARR_JSON, @@ -44,6 +43,7 @@ NodeType, ShapeLike, ZarrFormat, + parse_bufferclasslike, parse_shapelike, ) from zarr.core.config import config @@ -75,7 +75,7 @@ from typing import Any from zarr.core.array_spec import ArrayConfigLike - from zarr.core.buffer import Buffer, BufferPrototype + from zarr.core.buffer import Buffer from zarr.core.chunk_key_encodings import ChunkKeyEncodingLike from zarr.core.common import MemoryOrder from zarr.core.dtype import ZDTypeLike @@ -356,20 +356,25 @@ class GroupMetadata(Metadata): consolidated_metadata: ConsolidatedMetadata | None = None node_type: Literal["group"] = field(default="group", init=False) - def to_buffer_dict(self, prototype: BufferPrototype) -> dict[str, Buffer]: + def to_buffer_dict(self, prototype: BufferClassLike | None = None) -> dict[str, Buffer]: + """ + Convert the metadata document to a dict with string keys and `Buffer` values. 
+ """ + buffer_cls = parse_bufferclasslike(prototype) + json_indent = config.get("json_indent") if self.zarr_format == 3: return { - ZARR_JSON: prototype.buffer.from_bytes( + ZARR_JSON: buffer_cls.from_bytes( json.dumps(self.to_dict(), indent=json_indent, allow_nan=True).encode() ) } else: items = { - ZGROUP_JSON: prototype.buffer.from_bytes( + ZGROUP_JSON: buffer_cls.from_bytes( json.dumps({"zarr_format": self.zarr_format}, indent=json_indent).encode() ), - ZATTRS_JSON: prototype.buffer.from_bytes( + ZATTRS_JSON: buffer_cls.from_bytes( json.dumps(self.attributes, indent=json_indent, allow_nan=True).encode() ), } @@ -396,7 +401,7 @@ def to_buffer_dict(self, prototype: BufferPrototype) -> dict[str, Buffer]: }, } - items[ZMETADATA_V2_JSON] = prototype.buffer.from_bytes( + items[ZMETADATA_V2_JSON] = buffer_cls.from_bytes( json.dumps( {"metadata": d, "zarr_consolidated_format": 1}, allow_nan=True ).encode() @@ -2029,7 +2034,7 @@ async def update_attributes_async(self, new_attributes: dict[str, Any]) -> Group new_metadata = replace(self.metadata, attributes=new_attributes) # Write new metadata - to_save = new_metadata.to_buffer_dict(default_buffer_prototype()) + to_save = new_metadata.to_buffer_dict() awaitables = [set_or_delete(self.store_path / key, value) for key, value in to_save.items()] await asyncio.gather(*awaitables) @@ -3615,9 +3620,7 @@ async def _read_metadata_v3(store: Store, path: str) -> ArrayV3Metadata | GroupM document stored at store_path.path / zarr.json. If no such document is found, raise a FileNotFoundError. """ - zarr_json_bytes = await store.get( - _join_paths([path, ZARR_JSON]), prototype=default_buffer_prototype() - ) + zarr_json_bytes = await store.get(_join_paths([path, ZARR_JSON])) if zarr_json_bytes is None: raise FileNotFoundError(path) else: @@ -3634,9 +3637,9 @@ async def _read_metadata_v2(store: Store, path: str) -> ArrayV2Metadata | GroupM # TODO: consider first fetching array metadata, and only fetching group metadata when we don't # find an array zarray_bytes, zgroup_bytes, zattrs_bytes = await asyncio.gather( - store.get(_join_paths([path, ZARRAY_JSON]), prototype=default_buffer_prototype()), - store.get(_join_paths([path, ZGROUP_JSON]), prototype=default_buffer_prototype()), - store.get(_join_paths([path, ZATTRS_JSON]), prototype=default_buffer_prototype()), + store.get(_join_paths([path, ZARRAY_JSON])), + store.get(_join_paths([path, ZGROUP_JSON])), + store.get(_join_paths([path, ZATTRS_JSON])), ) if zattrs_bytes is None: @@ -3850,7 +3853,7 @@ def _persist_metadata( Prepare to save a metadata document to storage, returning a tuple of coroutines that must be awaited. 
""" - to_save = metadata.to_buffer_dict(default_buffer_prototype()) + to_save = metadata.to_buffer_dict() return tuple( _set_return_key(store=store, key=_join_paths([path, key]), value=value, semaphore=semaphore) for key, value in to_save.items() diff --git a/src/zarr/core/metadata/io.py b/src/zarr/core/metadata/io.py index 7b63f5493b..0a04063cd7 100644 --- a/src/zarr/core/metadata/io.py +++ b/src/zarr/core/metadata/io.py @@ -4,7 +4,6 @@ from typing import TYPE_CHECKING from zarr.abc.store import set_or_delete -from zarr.core.buffer.core import default_buffer_prototype from zarr.errors import ContainsArrayError from zarr.storage._common import StorePath, ensure_no_existing_node @@ -51,7 +50,7 @@ async def save_metadata( ------ ValueError """ - to_save = metadata.to_buffer_dict(default_buffer_prototype()) + to_save = metadata.to_buffer_dict() set_awaitables = [set_or_delete(store_path / key, value) for key, value in to_save.items()] if ensure_parents: @@ -71,9 +70,7 @@ async def save_metadata( set_awaitables.extend( [ (parent_store_path / key).set_if_not_exists(value) - for key, value in parent_metadata.to_buffer_dict( - default_buffer_prototype() - ).items() + for key, value in parent_metadata.to_buffer_dict().items() ] ) diff --git a/src/zarr/core/metadata/v2.py b/src/zarr/core/metadata/v2.py index 3204543426..153305df32 100644 --- a/src/zarr/core/metadata/v2.py +++ b/src/zarr/core/metadata/v2.py @@ -18,6 +18,7 @@ import numpy.typing as npt + from zarr.abc.store import BufferClassLike from zarr.core.buffer import Buffer, BufferPrototype from zarr.core.dtype.wrapper import ( TBaseDType, @@ -39,6 +40,7 @@ ZARRAY_JSON, ZATTRS_JSON, MemoryOrder, + parse_bufferclasslike, parse_shapelike, ) from zarr.core.config import config, parse_indexing_order @@ -125,15 +127,16 @@ def chunk_grid(self) -> RegularChunkGrid: def shards(self) -> tuple[int, ...] 
| None: return None - def to_buffer_dict(self, prototype: BufferPrototype) -> dict[str, Buffer]: + def to_buffer_dict(self, prototype: BufferClassLike | None = None) -> dict[str, Buffer]: + buffer_cls = parse_bufferclasslike(prototype) zarray_dict = self.to_dict() zattrs_dict = zarray_dict.pop("attributes", {}) json_indent = config.get("json_indent") return { - ZARRAY_JSON: prototype.buffer.from_bytes( + ZARRAY_JSON: buffer_cls.from_bytes( json.dumps(zarray_dict, indent=json_indent, allow_nan=True).encode() ), - ZATTRS_JSON: prototype.buffer.from_bytes( + ZATTRS_JSON: buffer_cls.from_bytes( json.dumps(zattrs_dict, indent=json_indent, allow_nan=True).encode() ), } diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py index 5ce155bd9a..3d4957c17f 100644 --- a/src/zarr/core/metadata/v3.py +++ b/src/zarr/core/metadata/v3.py @@ -11,6 +11,7 @@ if TYPE_CHECKING: from typing import Self + from zarr.abc.store import BufferClassLike from zarr.core.buffer import Buffer, BufferPrototype from zarr.core.chunk_grids import ChunkGrid from zarr.core.common import JSON @@ -35,6 +36,7 @@ ZARR_JSON, DimensionNames, NamedConfig, + parse_bufferclasslike, parse_named_configuration, parse_shapelike, ) @@ -345,11 +347,12 @@ def get_chunk_spec( def encode_chunk_key(self, chunk_coords: tuple[int, ...]) -> str: return self.chunk_key_encoding.encode_chunk_key(chunk_coords) - def to_buffer_dict(self, prototype: BufferPrototype) -> dict[str, Buffer]: + def to_buffer_dict(self, prototype: BufferClassLike | None = None) -> dict[str, Buffer]: + buffer_cls = parse_bufferclasslike(prototype) json_indent = config.get("json_indent") d = self.to_dict() return { - ZARR_JSON: prototype.buffer.from_bytes( + ZARR_JSON: buffer_cls.from_bytes( json.dumps(d, allow_nan=True, indent=json_indent).encode() ) } diff --git a/src/zarr/experimental/cache_store.py b/src/zarr/experimental/cache_store.py index 3456c94320..e55cef16b2 100644 --- a/src/zarr/experimental/cache_store.py +++ b/src/zarr/experimental/cache_store.py @@ -6,13 +6,13 @@ from collections import OrderedDict from typing import TYPE_CHECKING, Any, Literal -from zarr.abc.store import ByteRequest, Store +from zarr.abc.store import BufferClassLike, ByteRequest, Store from zarr.storage._wrapper import WrapperStore logger = logging.getLogger(__name__) if TYPE_CHECKING: - from zarr.core.buffer.core import Buffer, BufferPrototype + from zarr.core.buffer.core import Buffer class CacheStore(WrapperStore[Store]): @@ -218,7 +218,7 @@ def _remove_from_tracking(self, key: str) -> None: self._key_sizes.pop(key, None) async def _get_try_cache( - self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None + self, key: str, prototype: BufferClassLike | None, byte_range: ByteRequest | None = None ) -> Buffer | None: """Try to get data from cache first, falling back to source store.""" maybe_cached_result = await self._cache.get(key, prototype, byte_range) @@ -246,7 +246,10 @@ async def _get_try_cache( return maybe_fresh_result async def _get_no_cache( - self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None + self, + key: str, + prototype: BufferClassLike | None = None, + byte_range: ByteRequest | None = None, ) -> Buffer | None: """Get data directly from source store and update cache.""" self._misses += 1 @@ -265,7 +268,7 @@ async def _get_no_cache( async def get( self, key: str, - prototype: BufferPrototype, + prototype: BufferClassLike | None = None, byte_range: ByteRequest | None = None, ) -> Buffer | None: """ @@ -275,8 
+278,12 @@ async def get( ---------- key : str The key to retrieve - prototype : BufferPrototype - Buffer prototype for creating the result buffer + prototype : BufferClassLike | None, optional + The prototype of the output buffer. + Can be either a Buffer class or an instance of `BufferPrototype`, in which case the + `buffer` attribute will be used. + If `None`, the default buffer class for this store will be retrieved via the + ``_get_default_buffer_class`` method. byte_range : ByteRequest, optional Byte range to retrieve diff --git a/src/zarr/storage/_common.py b/src/zarr/storage/_common.py index 4bea04f024..ddb6a38bc7 100644 --- a/src/zarr/storage/_common.py +++ b/src/zarr/storage/_common.py @@ -3,10 +3,10 @@ import importlib.util import json from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal, Self, TypeAlias +from typing import Any, Literal, Self, TypeAlias -from zarr.abc.store import ByteRequest, Store -from zarr.core.buffer import Buffer, default_buffer_prototype +from zarr.abc.store import BufferClassLike, ByteRequest, Store +from zarr.core.buffer import Buffer from zarr.core.common import ( ANY_ACCESS_MODE, ZARR_JSON, @@ -26,9 +26,6 @@ else: FSMap = None -if TYPE_CHECKING: - from zarr.core.buffer import BufferPrototype def _dereference_path(root: str, path: str) -> str: if not isinstance(root, str): @@ -145,7 +142,7 @@ async def open(cls, store: Store, path: str, mode: AccessModeLiteral | None = No async def get( self, - prototype: BufferPrototype | None = None, + prototype: BufferClassLike | None = None, byte_range: ByteRequest | None = None, ) -> Buffer | None: """ @@ -153,8 +150,12 @@ async def get( Parameters ---------- - prototype : BufferPrototype, optional - The buffer prototype to use when reading the bytes. + prototype : BufferClassLike | None, optional + The prototype of the output buffer. + Can be either a Buffer class or an instance of `BufferPrototype`, in which case the + `buffer` attribute will be used. + If `None`, the default buffer class for this store will be retrieved via the + store's ``_get_default_buffer_class`` method. byte_range : ByteRequest, optional The range of bytes to read. @@ -164,7 +165,7 @@ async def get( The read bytes, or None if the key does not exist. """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self.store._get_default_buffer_class() return await self.store.get(self.path, prototype=prototype, byte_range=byte_range) async def set(self, value: Buffer) -> None: diff --git a/src/zarr/storage/_fsspec.py b/src/zarr/storage/_fsspec.py index f9e4ed375d..faeacacd97 100644 --- a/src/zarr/storage/_fsspec.py +++ b/src/zarr/storage/_fsspec.py @@ -8,6 +8,7 @@ from packaging.version import parse as parse_version from zarr.abc.store import ( + BufferClassLike, ByteRequest, OffsetByteRequest, RangeByteRequest, @@ -15,6 +16,7 @@ SuffixByteRequest, ) from zarr.core.buffer import Buffer +from zarr.core.common import parse_bufferclasslike from zarr.errors import ZarrUserWarning from zarr.storage._common import _dereference_path @@ -25,8 +27,6 @@ from fsspec.asyn import AsyncFileSystem from fsspec.mapping import FSMap - from zarr.core.buffer import BufferPrototype ALLOWED_EXCEPTIONS: tuple[type[Exception], ...]
= ( FileNotFoundError, @@ -276,19 +276,24 @@ def __eq__(self, other: object) -> bool: async def get( self, key: str, - prototype: BufferPrototype, + prototype: BufferClassLike | None = None, byte_range: ByteRequest | None = None, ) -> Buffer | None: # docstring inherited if not self._is_open: await self._open() + if prototype is None: + buffer_cls = self._get_default_buffer_class() + else: + buffer_cls = parse_bufferclasslike(prototype) + path = _dereference_path(self.path, key) try: if byte_range is None: - value = prototype.buffer.from_bytes(await self.fs._cat_file(path)) + value = buffer_cls.from_bytes(await self.fs._cat_file(path)) elif isinstance(byte_range, RangeByteRequest): - value = prototype.buffer.from_bytes( + value = buffer_cls.from_bytes( await self.fs._cat_file( path, start=byte_range.start, @@ -296,11 +301,11 @@ async def get( ) ) elif isinstance(byte_range, OffsetByteRequest): - value = prototype.buffer.from_bytes( + value = buffer_cls.from_bytes( await self.fs._cat_file(path, start=byte_range.offset, end=None) ) elif isinstance(byte_range, SuffixByteRequest): - value = prototype.buffer.from_bytes( + value = buffer_cls.from_bytes( await self.fs._cat_file(path, start=-byte_range.suffix, end=None) ) else: @@ -310,7 +315,7 @@ async def get( except OSError as e: if "not satisfiable" in str(e): # this is an s3-specific condition we probably don't want to leak - return prototype.buffer.from_bytes(b"") + return buffer_cls.from_bytes(b"") raise else: return value @@ -367,10 +372,15 @@ async def exists(self, key: str) -> bool: async def get_partial_values( self, - prototype: BufferPrototype, + prototype: BufferClassLike | None, key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited + if prototype is None: + buffer_cls = self._get_default_buffer_class() + else: + buffer_cls = parse_bufferclasslike(prototype) + if key_ranges: # _cat_ranges expects a list of paths, start, and end ranges, so we need to reformat each ByteRequest. 
key_ranges = list(key_ranges) @@ -403,7 +413,7 @@ async def get_partial_values( if isinstance(r, Exception) and not isinstance(r, self.allowed_exceptions): raise r - return [None if isinstance(r, Exception) else prototype.buffer.from_bytes(r) for r in res] + return [None if isinstance(r, Exception) else buffer_cls.from_bytes(r) for r in res] async def list(self) -> AsyncIterator[str]: # docstring inherited diff --git a/src/zarr/storage/_local.py b/src/zarr/storage/_local.py index 80233a112d..b26f00f8ed 100644 --- a/src/zarr/storage/_local.py +++ b/src/zarr/storage/_local.py @@ -11,37 +11,41 @@ from typing import TYPE_CHECKING, Any, BinaryIO, Literal, Self from zarr.abc.store import ( + BufferClassLike, ByteRequest, OffsetByteRequest, RangeByteRequest, Store, SuffixByteRequest, ) -from zarr.core.buffer import Buffer -from zarr.core.buffer.core import default_buffer_prototype +from zarr.core.buffer import Buffer, BufferPrototype from zarr.core.common import AccessModeLiteral, concurrent_map if TYPE_CHECKING: from collections.abc import AsyncIterator, Iterable, Iterator - from zarr.core.buffer import BufferPrototype +def _get(path: Path, prototype: BufferClassLike, byte_range: ByteRequest | None) -> Buffer: + # Extract the buffer class from the BufferClassLike + if isinstance(prototype, BufferPrototype): + buffer_cls = prototype.buffer + else: + buffer_cls = prototype -def _get(path: Path, prototype: BufferPrototype, byte_range: ByteRequest | None) -> Buffer: if byte_range is None: - return prototype.buffer.from_bytes(path.read_bytes()) + return buffer_cls.from_bytes(path.read_bytes()) with path.open("rb") as f: size = f.seek(0, io.SEEK_END) if isinstance(byte_range, RangeByteRequest): f.seek(byte_range.start) - return prototype.buffer.from_bytes(f.read(byte_range.end - f.tell())) + return buffer_cls.from_bytes(f.read(byte_range.end - f.tell())) elif isinstance(byte_range, OffsetByteRequest): f.seek(byte_range.offset) elif isinstance(byte_range, SuffixByteRequest): f.seek(max(0, size - byte_range.suffix)) else: raise TypeError(f"Unexpected byte_range, got {byte_range}.") - return prototype.buffer.from_bytes(f.read()) + return buffer_cls.from_bytes(f.read()) if sys.platform == "win32": @@ -190,12 +194,12 @@ def __eq__(self, other: object) -> bool: async def get( self, key: str, - prototype: BufferPrototype | None = None, + prototype: BufferClassLike | None = None, byte_range: ByteRequest | None = None, ) -> Buffer | None: # docstring inherited if prototype is None: - prototype = default_buffer_prototype() + prototype = self._get_default_buffer_class() if not self._is_open: await self._open() assert isinstance(key, str) @@ -208,10 +212,12 @@ async def get( async def get_partial_values( self, - prototype: BufferPrototype, + prototype: BufferClassLike | None, key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited + if prototype is None: + prototype = self._get_default_buffer_class() args = [] for key, byte_range in key_ranges: assert isinstance(key, str) @@ -310,7 +316,7 @@ async def _get_bytes( self, key: str = "", *, - prototype: BufferPrototype | None = None, + prototype: BufferClassLike | None = None, byte_range: ByteRequest | None = None, ) -> bytes: """ @@ -324,8 +330,8 @@ async def _get_bytes( ---------- key : str, optional The key identifying the data to retrieve. Defaults to an empty string. - prototype : BufferPrototype, optional - The buffer prototype to use for reading the data. If None, uses - ``default_buffer_prototype()``. + prototype : BufferClassLike | None, optional + A specification of the buffer class to use for reading the data. If None, uses + the store's default buffer class (see ``_get_default_buffer_class``). byte_range : ByteRequest, optional If specified, only retrieve a portion of the stored data. @@ -350,19 +356,19 @@ async def _get_bytes( >>> store = await LocalStore.open("data") >>> await store.set("data", Buffer.from_bytes(b"hello")) >>> # No need to specify prototype for LocalStore - >>> data = await store.get_bytes("data") + >>> data = await store._get_bytes("data") >>> print(data) b'hello' """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self._get_default_buffer_class() return await super()._get_bytes(key, prototype=prototype, byte_range=byte_range) def _get_bytes_sync( self, key: str = "", *, - prototype: BufferPrototype | None = None, + prototype: BufferClassLike | None = None, byte_range: ByteRequest | None = None, ) -> bytes: """ @@ -376,8 +382,8 @@ def _get_bytes_sync( ---------- key : str, optional The key identifying the data to retrieve. Defaults to an empty string. - prototype : BufferPrototype, optional - The buffer prototype to use for reading the data. If None, uses - ``default_buffer_prototype()``. + prototype : BufferClassLike | None, optional + A specification of the buffer class to use for reading the data. If None, uses + the store's default buffer class (see ``_get_default_buffer_class``). byte_range : ByteRequest, optional If specified, only retrieve a portion of the stored data. @@ -406,19 +412,19 @@ def _get_bytes_sync( >>> store = LocalStore("data") >>> store.set("data", Buffer.from_bytes(b"hello")) >>> # No need to specify prototype for LocalStore - >>> data = store.get_bytes("data") + >>> data = store._get_bytes_sync("data") >>> print(data) b'hello' """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self._get_default_buffer_class() return super()._get_bytes_sync(key, prototype=prototype, byte_range=byte_range) async def _get_json( self, key: str = "", *, - prototype: BufferPrototype | None = None, + prototype: BufferClassLike | None = None, byte_range: ByteRequest | None = None, ) -> Any: """ @@ -432,8 +438,8 @@ async def _get_json( ---------- key : str, optional The key identifying the JSON data to retrieve. Defaults to an empty string. - prototype : BufferPrototype, optional - The buffer prototype to use for reading the data. If None, uses - ``default_buffer_prototype()``. + prototype : BufferClassLike | None, optional + A specification of the buffer class to use for reading the data. If None, uses + the store's default buffer class (see ``_get_default_buffer_class``). byte_range : ByteRequest, optional If specified, only retrieve a portion of the stored data. @@ -465,19 +471,19 @@ async def _get_json( >>> metadata = {"zarr_format": 3, "node_type": "array"} >>> await store.set("zarr.json", Buffer.from_bytes(json.dumps(metadata).encode())) >>> # No need to specify prototype for LocalStore - >>> data = await store.get_json("zarr.json") + >>> data = await store._get_json("zarr.json") >>> print(data) {'zarr_format': 3, 'node_type': 'array'} """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self._get_default_buffer_class() return await super()._get_json(key, prototype=prototype, byte_range=byte_range) def _get_json_sync( self, key: str = "", *, - prototype: BufferPrototype | None = None, + prototype: BufferClassLike | None = None, byte_range: ByteRequest | None = None, ) -> Any: """ @@ -491,8 +497,8 @@ def _get_json_sync( ---------- key : str, optional The key identifying the JSON data to retrieve. Defaults to an empty string.
- prototype : BufferPrototype, optional - The buffer prototype to use for reading the data. If None, uses - ``default_buffer_prototype()``. + prototype : BufferClassLike | None, optional + A specification of the buffer class to use for reading the data. If None, uses + the store's default buffer class (see ``_get_default_buffer_class``). byte_range : ByteRequest, optional If specified, only retrieve a portion of the stored data. @@ -528,12 +534,12 @@ def _get_json_sync( >>> metadata = {"zarr_format": 3, "node_type": "array"} >>> store.set("zarr.json", Buffer.from_bytes(json.dumps(metadata).encode())) >>> # No need to specify prototype for LocalStore - >>> data = store.get_json("zarr.json") + >>> data = store._get_json_sync("zarr.json") >>> print(data) {'zarr_format': 3, 'node_type': 'array'} """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self._get_default_buffer_class() return super()._get_json_sync(key, prototype=prototype, byte_range=byte_range) async def move(self, dest_root: Path | str) -> None: diff --git a/src/zarr/storage/_logging.py b/src/zarr/storage/_logging.py index dd20d49ae5..b3b5030cd0 100644 --- a/src/zarr/storage/_logging.py +++ b/src/zarr/storage/_logging.py @@ -8,14 +8,14 @@ from contextlib import contextmanager from typing import TYPE_CHECKING, Any, Self, TypeVar -from zarr.abc.store import Store +from zarr.abc.store import BufferClassLike, Store from zarr.storage._wrapper import WrapperStore if TYPE_CHECKING: from collections.abc import AsyncGenerator, Generator, Iterable from zarr.abc.store import ByteRequest - from zarr.core.buffer import Buffer, BufferPrototype + from zarr.core.buffer import Buffer counter: defaultdict[str, int] @@ -165,7 +165,7 @@ def __eq__(self, other: object) -> bool: async def get( self, key: str, - prototype: BufferPrototype, + prototype: BufferClassLike | None = None, byte_range: ByteRequest | None = None, ) -> Buffer | None: # docstring inherited @@ -174,7 +174,7 @@ async def get( async def get_partial_values( self, - prototype: BufferPrototype, + prototype: BufferClassLike | None, key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited diff --git a/src/zarr/storage/_memory.py b/src/zarr/storage/_memory.py index e6f9b7a512..e21c9d7c42 100644 --- a/src/zarr/storage/_memory.py +++ b/src/zarr/storage/_memory.py @@ -3,17 +3,14 @@ from logging import getLogger from typing import TYPE_CHECKING, Any, Self -from zarr.abc.store import ByteRequest, Store -from zarr.core.buffer import Buffer, gpu -from zarr.core.buffer.core import default_buffer_prototype +from zarr.abc.store import BufferClassLike, ByteRequest, Store +from zarr.core.buffer import Buffer, BufferPrototype, gpu from zarr.core.common import concurrent_map from zarr.storage._utils import _normalize_byte_range_index if TYPE_CHECKING: from collections.abc import AsyncIterator, Iterable, MutableMapping - from zarr.core.buffer import BufferPrototype logger = getLogger(__name__) @@ -80,25 +77,30 @@ def __eq__(self, other: object) -> bool: async def get( self, key: str, - prototype: BufferPrototype | None = None, + prototype: BufferClassLike | None = None, byte_range: ByteRequest | None = None, ) -> Buffer | None: # docstring inherited if prototype is None: - prototype = default_buffer_prototype() + prototype = self._get_default_buffer_class() + # Extract the buffer class from the BufferClassLike + if isinstance(prototype, BufferPrototype): + buffer_cls = prototype.buffer + else: + buffer_cls = prototype if not self._is_open: await self._open() assert isinstance(key, str) try: value = self._store_dict[key]
start, stop = _normalize_byte_range_index(value, byte_range) - return prototype.buffer.from_buffer(value[start:stop]) + return buffer_cls.from_buffer(value[start:stop]) except KeyError: return None async def get_partial_values( self, - prototype: BufferPrototype, + prototype: BufferClassLike | None, key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited @@ -179,7 +181,7 @@ async def _get_bytes( self, key: str = "", *, - prototype: BufferPrototype | None = None, + prototype: BufferClassLike | None = None, byte_range: ByteRequest | None = None, ) -> bytes: """ @@ -193,8 +195,8 @@ async def _get_bytes( ---------- key : str, optional The key identifying the data to retrieve. Defaults to an empty string. - prototype : BufferPrototype, optional - The buffer prototype to use for reading the data. If None, uses - ``default_buffer_prototype()``. + prototype : BufferClassLike | None, optional + A specification of the buffer class to use for reading the data. If None, uses + the store's default buffer class (see ``_get_default_buffer_class``). byte_range : ByteRequest, optional If specified, only retrieve a portion of the stored data. @@ -219,19 +221,19 @@ async def _get_bytes( >>> store = await MemoryStore.open() >>> await store.set("data", Buffer.from_bytes(b"hello")) >>> # No need to specify prototype for MemoryStore - >>> data = await store.get_bytes("data") + >>> data = await store._get_bytes("data") >>> print(data) b'hello' """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self._get_default_buffer_class() return await super()._get_bytes(key, prototype=prototype, byte_range=byte_range) def _get_bytes_sync( self, key: str = "", *, - prototype: BufferPrototype | None = None, + prototype: BufferClassLike | None = None, byte_range: ByteRequest | None = None, ) -> bytes: """ @@ -245,8 +247,8 @@ def _get_bytes_sync( ---------- key : str, optional The key identifying the data to retrieve. Defaults to an empty string. - prototype : BufferPrototype, optional - The buffer prototype to use for reading the data. If None, uses - ``default_buffer_prototype()``. + prototype : BufferClassLike | None, optional + A specification of the buffer class to use for reading the data. If None, uses + the store's default buffer class (see ``_get_default_buffer_class``). byte_range : ByteRequest, optional If specified, only retrieve a portion of the stored data. @@ -275,19 +277,19 @@ def _get_bytes_sync( >>> store = MemoryStore() >>> store.set("data", Buffer.from_bytes(b"hello")) >>> # No need to specify prototype for MemoryStore - >>> data = store.get_bytes("data") + >>> data = store._get_bytes_sync("data") >>> print(data) b'hello' """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self._get_default_buffer_class() return super()._get_bytes_sync(key, prototype=prototype, byte_range=byte_range) async def _get_json( self, key: str = "", *, - prototype: BufferPrototype | None = None, + prototype: BufferClassLike | None = None, byte_range: ByteRequest | None = None, ) -> Any: """ @@ -301,8 +303,8 @@ async def _get_json( ---------- key : str, optional The key identifying the JSON data to retrieve. Defaults to an empty string. - prototype : BufferPrototype, optional - The buffer prototype to use for reading the data. If None, uses - ``default_buffer_prototype()``. + prototype : BufferClassLike | None, optional + A specification of the buffer class to use for reading the data. If None, uses + the store's default buffer class (see ``_get_default_buffer_class``). byte_range : ByteRequest, optional If specified, only retrieve a portion of the stored data.
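As the `MemoryStore.get` changes above illustrate, each concrete store now normalizes `prototype` itself, so callers may pass any of the three spellings, or omit the argument entirely. A small usage sketch, assuming the default CPU buffer configuration:

```python
import asyncio

from zarr.core.buffer import cpu, default_buffer_prototype
from zarr.storage import MemoryStore


async def main() -> None:
    store = await MemoryStore.open()
    await store.set("key", cpu.Buffer.from_bytes(b"abc"))
    a = await store.get("key")  # None: falls back to the store's default buffer class
    b = await store.get("key", prototype=default_buffer_prototype())  # BufferPrototype
    c = await store.get("key", prototype=cpu.Buffer)  # raw Buffer class
    assert a is not None and b is not None and c is not None
    assert a.to_bytes() == b.to_bytes() == c.to_bytes() == b"abc"


asyncio.run(main())
```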
@@ -334,19 +336,19 @@ async def _get_json( >>> metadata = {"zarr_format": 3, "node_type": "array"} >>> await store.set("zarr.json", Buffer.from_bytes(json.dumps(metadata).encode())) >>> # No need to specify prototype for MemoryStore - >>> data = await store.get_json("zarr.json") + >>> data = await store._get_json("zarr.json") >>> print(data) {'zarr_format': 3, 'node_type': 'array'} """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self._get_default_buffer_class() return await super()._get_json(key, prototype=prototype, byte_range=byte_range) def _get_json_sync( self, key: str = "", *, - prototype: BufferPrototype | None = None, + prototype: BufferClassLike | None = None, byte_range: ByteRequest | None = None, ) -> Any: """ @@ -360,8 +362,8 @@ def _get_json_sync( ---------- key : str, optional The key identifying the JSON data to retrieve. Defaults to an empty string. - prototype : BufferPrototype, optional - The buffer prototype to use for reading the data. If None, uses - ``default_buffer_prototype()``. + prototype : BufferClassLike | None, optional + A specification of the buffer class to use for reading the data. If None, uses + the store's default buffer class (see ``_get_default_buffer_class``). byte_range : ByteRequest, optional If specified, only retrieve a portion of the stored data. @@ -397,12 +399,12 @@ def _get_json_sync( >>> metadata = {"zarr_format": 3, "node_type": "array"} >>> store.set("zarr.json", Buffer.from_bytes(json.dumps(metadata).encode())) >>> # No need to specify prototype for MemoryStore - >>> data = store.get_json("zarr.json") + >>> data = store._get_json_sync("zarr.json") >>> print(data) {'zarr_format': 3, 'node_type': 'array'} """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self._get_default_buffer_class() return super()._get_json_sync(key, prototype=prototype, byte_range=byte_range) diff --git a/src/zarr/storage/_obstore.py b/src/zarr/storage/_obstore.py index 5c2197ecf6..f836bb58a1 100644 --- a/src/zarr/storage/_obstore.py +++ b/src/zarr/storage/_obstore.py @@ -7,12 +7,15 @@ from typing import TYPE_CHECKING, Generic, Self, TypedDict, TypeVar from zarr.abc.store import ( + BufferClassLike, ByteRequest, OffsetByteRequest, RangeByteRequest, Store, SuffixByteRequest, ) +from zarr.core.buffer import BufferPrototype +from zarr.core.buffer.core import default_buffer_prototype from zarr.core.common import concurrent_map from zarr.core.config import config @@ -23,7 +26,7 @@ from obstore import ListResult, ListStream, ObjectMeta, OffsetRange, SuffixRange from obstore.store import ObjectStore as _UpstreamObjectStore - from zarr.core.buffer import Buffer, BufferPrototype + from zarr.core.buffer import Buffer __all__ = ["ObjectStore"] @@ -95,25 +98,36 @@ def __setstate__(self, state: dict[Any, Any]) -> None: self.__dict__.update(state) async def get( - self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None + self, + key: str, + prototype: BufferClassLike | None = None, + byte_range: ByteRequest | None = None, ) -> Buffer | None: # docstring inherited import obstore as obs + if prototype is None: + prototype = self._get_default_buffer_class() + # Extract the buffer class from the BufferClassLike + if isinstance(prototype, BufferPrototype): + buffer_cls = prototype.buffer + else: + buffer_cls = prototype + try: if byte_range is None: resp = await obs.get_async(self.store, key) - return prototype.buffer.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] + return buffer_cls.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] elif isinstance(byte_range,
RangeByteRequest): bytes = await obs.get_range_async( self.store, key, start=byte_range.start, end=byte_range.end ) - return prototype.buffer.from_bytes(bytes) # type: ignore[arg-type] + return buffer_cls.from_bytes(bytes) # type: ignore[arg-type] elif isinstance(byte_range, OffsetByteRequest): resp = await obs.get_async( self.store, key, options={"range": {"offset": byte_range.offset}} ) - return prototype.buffer.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] + return buffer_cls.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] elif isinstance(byte_range, SuffixByteRequest): # some object stores (Azure) don't support suffix requests. In this # case, our workaround is to first get the length of the object and then @@ -122,7 +136,7 @@ async def get( resp = await obs.get_async( self.store, key, options={"range": {"suffix": byte_range.suffix}} ) - return prototype.buffer.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] + return buffer_cls.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] except obs.exceptions.NotSupportedError: head_resp = await obs.head_async(self.store, key) file_size = head_resp["size"] @@ -133,7 +147,7 @@ async def get( start=file_size - suffix_len, length=suffix_len, ) - return prototype.buffer.from_bytes(buffer) # type: ignore[arg-type] + return buffer_cls.from_bytes(buffer) # type: ignore[arg-type] else: raise ValueError(f"Unexpected byte_range, got {byte_range}") except _ALLOWED_EXCEPTIONS: @@ -141,10 +155,16 @@ async def get( async def get_partial_values( self, - prototype: BufferPrototype, + prototype: BufferClassLike | None, key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited + if prototype is None: + prototype = self._get_default_buffer_class() + # The _get_partial_values helper expects a full BufferPrototype; a bare Buffer + # class cannot be promoted to one, so fall back to the default prototype. + if not isinstance(prototype, BufferPrototype): + prototype = default_buffer_prototype() return await _get_partial_values(self.store, prototype=prototype, key_ranges=key_ranges) async def exists(self, key: str) -> bool: diff --git a/src/zarr/storage/_wrapper.py b/src/zarr/storage/_wrapper.py index 64a5b2d83c..13b75edc2f 100644 --- a/src/zarr/storage/_wrapper.py +++ b/src/zarr/storage/_wrapper.py @@ -9,9 +9,8 @@ from zarr.abc.buffer import Buffer from zarr.abc.store import ByteRequest - from zarr.core.buffer import BufferPrototype -from zarr.abc.store import Store +from zarr.abc.store import BufferClassLike, Store T_Store = TypeVar("T_Store", bound=Store) @@ -85,13 +84,16 @@ def __repr__(self) -> str: return f"WrapperStore({self._store.__class__.__name__}, '{self._store}')" async def get( - self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None + self, + key: str, + prototype: BufferClassLike | None = None, + byte_range: ByteRequest | None = None, ) -> Buffer | None: return await self._store.get(key, prototype, byte_range) async def get_partial_values( self, - prototype: BufferPrototype, + prototype: BufferClassLike | None, key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: return await self._store.get_partial_values(prototype, key_ranges) @@ -139,7 +141,7 @@ def close(self) -> None: self._store.close() async def _get_many( - self, requests: Iterable[tuple[str, BufferPrototype, ByteRequest | None]] + self, requests: Iterable[tuple[str, BufferClassLike | None, ByteRequest | None]] ) -> AsyncGenerator[tuple[str, Buffer | None], None]:
async for req in self._store._get_many(requests): yield req diff --git a/src/zarr/storage/_zip.py b/src/zarr/storage/_zip.py index 72bf9e335a..eb3b33f514 100644 --- a/src/zarr/storage/_zip.py +++ b/src/zarr/storage/_zip.py @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Literal from zarr.abc.store import ( + BufferClassLike, ByteRequest, OffsetByteRequest, RangeByteRequest, @@ -146,19 +147,24 @@ def __eq__(self, other: object) -> bool: def _get( self, key: str, - prototype: BufferPrototype, + prototype: BufferClassLike, byte_range: ByteRequest | None = None, ) -> Buffer | None: if not self._is_open: self._sync_open() + # Extract the buffer class from the BufferClassLike + if isinstance(prototype, BufferPrototype): + buffer_cls = prototype.buffer + else: + buffer_cls = prototype # docstring inherited try: with self._zf.open(key) as f: # will raise KeyError if byte_range is None: - return prototype.buffer.from_bytes(f.read()) + return buffer_cls.from_bytes(f.read()) elif isinstance(byte_range, RangeByteRequest): f.seek(byte_range.start) - return prototype.buffer.from_bytes(f.read(byte_range.end - f.tell())) + return buffer_cls.from_bytes(f.read(byte_range.end - f.tell())) size = f.seek(0, os.SEEK_END) if isinstance(byte_range, OffsetByteRequest): f.seek(byte_range.offset) @@ -166,17 +172,19 @@ def _get( f.seek(max(0, size - byte_range.suffix)) else: raise TypeError(f"Unexpected byte_range, got {byte_range}.") - return prototype.buffer.from_bytes(f.read()) + return buffer_cls.from_bytes(f.read()) except KeyError: return None async def get( self, key: str, - prototype: BufferPrototype, + prototype: BufferClassLike | None = None, byte_range: ByteRequest | None = None, ) -> Buffer | None: # docstring inherited + if prototype is None: + prototype = self._get_default_buffer_class() assert isinstance(key, str) with self._lock: @@ -184,10 +192,12 @@ async def get( async def get_partial_values( self, - prototype: BufferPrototype, + prototype: BufferClassLike | None, key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited + if prototype is None: + prototype = self._get_default_buffer_class() out = [] with self._lock: for key, byte_range in key_ranges: diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index 5daf8284eb..bc6a8e09e6 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -12,18 +12,18 @@ from typing import Any from zarr.abc.store import ByteRequest - from zarr.core.buffer.core import BufferPrototype import pytest from zarr.abc.store import ( + BufferClassLike, ByteRequest, OffsetByteRequest, RangeByteRequest, Store, SuffixByteRequest, ) -from zarr.core.buffer import Buffer, default_buffer_prototype +from zarr.core.buffer import Buffer, cpu, default_buffer_prototype from zarr.core.sync import _collect_aiterator, sync from zarr.storage._utils import _normalize_byte_range_index from zarr.testing.utils import assert_bytes_equal @@ -108,7 +108,7 @@ async def test_serializable_store(self, store: S) -> None: data_buf = self.buffer_cls.from_bytes(b"\x01\x02\x03\x04") key = "foo" await store.set(key, data_buf) - observed = await store.get(key, prototype=default_buffer_prototype()) + observed = await store.get(key) assert_bytes_equal(observed, data_buf) def test_store_read_only(self, store: S) -> None: @@ -202,6 +202,15 @@ async def test_with_read_only_store(self, open_kwargs: dict[str, Any]) -> None: ): await reader.delete("foo") + @pytest.mark.parametrize( + "prototype", + [ + None, # Should use store's default buffer
class + default_buffer_prototype(), # BufferPrototype instance + default_buffer_prototype().buffer, # Raw Buffer class (cpu.Buffer) + ], + ids=["prototype=None", "prototype=BufferPrototype", "prototype=Buffer"], + ) @pytest.mark.parametrize("key", ["c/0", "foo/c/0.0", "foo/0/0"]) @pytest.mark.parametrize( ("data", "byte_range"), @@ -213,13 +222,20 @@ async def test_with_read_only_store(self, open_kwargs: dict[str, Any]) -> None: (b"", None), ], ) - async def test_get(self, store: S, key: str, data: bytes, byte_range: ByteRequest) -> None: + async def test_get( + self, + store: S, + key: str, + data: bytes, + byte_range: ByteRequest, + prototype: BufferClassLike | None, + ) -> None: """ Ensure that data can be read from the store using the store.get method. """ data_buf = self.buffer_cls.from_bytes(data) await self.set(store, key, data_buf) - observed = await store.get(key, prototype=default_buffer_prototype(), byte_range=byte_range) + observed = await store.get(key, prototype=prototype, byte_range=byte_range) start, stop = _normalize_byte_range_index(data_buf, byte_range=byte_range) expected = data_buf[start:stop] assert_bytes_equal(observed, expected) @@ -232,7 +248,7 @@ async def test_get_not_open(self, store_not_open: S) -> None: data_buf = self.buffer_cls.from_bytes(b"\x01\x02\x03\x04") key = "c/0" await self.set(store_not_open, key, data_buf) - observed = await store_not_open.get(key, prototype=default_buffer_prototype()) + observed = await store_not_open.get(key) assert_bytes_equal(observed, data_buf) async def test_get_raises(self, store: S) -> None: @@ -256,7 +272,7 @@ async def test_get_many(self, store: S) -> None: store._get_many( zip( keys, - (default_buffer_prototype(),) * len(keys), + (None,) * len(keys), (None,) * len(keys), strict=False, ) @@ -332,6 +348,15 @@ async def test_set_many(self, store: S) -> None: for k, v in store_dict.items(): assert (await self.get(store, k)).to_bytes() == v.to_bytes() + @pytest.mark.parametrize( + "prototype", + [ + None, # Should use store's default buffer class + default_buffer_prototype(), # BufferPrototype instance + default_buffer_prototype().buffer, # Raw Buffer class (cpu.Buffer) + ], + ids=["prototype=None", "prototype=BufferPrototype", "prototype=Buffer"], + ) @pytest.mark.parametrize( "key_ranges", [ @@ -346,16 +371,14 @@ async def test_set_many(self, store: S) -> None: ], ) async def test_get_partial_values( - self, store: S, key_ranges: list[tuple[str, ByteRequest]] + self, store: S, key_ranges: list[tuple[str, ByteRequest]], prototype: BufferClassLike | None ) -> None: # put all of the data for key, _ in key_ranges: await self.set(store, key, self.buffer_cls.from_bytes(bytes(key, encoding="utf-8"))) # read back just part of it - observed_maybe = await store.get_partial_values( - prototype=default_buffer_prototype(), key_ranges=key_ranges - ) + observed_maybe = await store.get_partial_values(prototype=prototype, key_ranges=key_ranges) observed: list[Buffer] = [] expected: list[Buffer] = [] @@ -366,9 +389,7 @@ async def test_get_partial_values( for idx in range(len(observed)): key, byte_range = key_ranges[idx] - result = await store.get( - key, prototype=default_buffer_prototype(), byte_range=byte_range - ) + result = await store.get(key, prototype=cpu.Buffer, byte_range=byte_range) assert result is not None expected.append(result) @@ -519,53 +540,53 @@ async def test_set_if_not_exists(self, store: S) -> None: new = self.buffer_cls.from_bytes(b"1111") await store.set_if_not_exists("k", new) # no error - result = await 
store.get(key, default_buffer_prototype()) + result = await store.get(key) assert result == data_buf await store.set_if_not_exists("k2", new) # no error - result = await store.get("k2", default_buffer_prototype()) + result = await store.get("k2") assert result == new async def test_get_bytes(self, store: S) -> None: """ - Test that the get_bytes method reads bytes. + Test that the _get_bytes method reads bytes. """ data = b"hello world" key = "zarr.json" await self.set(store, key, self.buffer_cls.from_bytes(data)) - assert await store._get_bytes(key, prototype=default_buffer_prototype()) == data + assert await store._get_bytes(key) == data with pytest.raises(FileNotFoundError): - await store._get_bytes("nonexistent_key", prototype=default_buffer_prototype()) + await store._get_bytes("nonexistent_key") def test_get_bytes_sync(self, store: S) -> None: """ - Test that the get_bytes_sync method reads bytes. + Test that the _get_bytes_sync method reads bytes. """ data = b"hello world" key = "zarr.json" sync(self.set(store, key, self.buffer_cls.from_bytes(data))) - assert store._get_bytes_sync(key, prototype=default_buffer_prototype()) == data + assert store._get_bytes_sync(key) == data async def test_get_json(self, store: S) -> None: """ - Test that the get_json method reads json. + Test that the _get_json method reads json. """ data = {"foo": "bar"} data_bytes = json.dumps(data).encode("utf-8") key = "zarr.json" await self.set(store, key, self.buffer_cls.from_bytes(data_bytes)) - assert await store._get_json(key, prototype=default_buffer_prototype()) == data + assert await store._get_json(key) == data def test_get_json_sync(self, store: S) -> None: """ - Test that the get_json method reads json. + Test that the _get_json_sync method reads json. """ data = {"foo": "bar"} data_bytes = json.dumps(data).encode("utf-8") key = "zarr.json" sync(self.set(store, key, self.buffer_cls.from_bytes(data_bytes))) - assert store._get_json_sync(key, prototype=default_buffer_prototype()) == data + assert store._get_json_sync(key) == data class LatencyStore(WrapperStore[Store]): @@ -604,7 +625,10 @@ async def set(self, key: str, value: Buffer) -> None: await self._store.set(key, value) async def get( - self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None + self, + key: str, + prototype: BufferClassLike | None = None, + byte_range: ByteRequest | None = None, ) -> Buffer | None: """ Add latency to the ``get`` method. @@ -615,8 +639,12 @@ async def get( ---------- key : str The key to get - prototype : BufferPrototype - The BufferPrototype to use. + prototype : BufferClassLike | None, optional + The prototype of the output buffer. + Can be either a Buffer class or an instance of `BufferPrototype`, in which case its + `buffer` attribute will be used. + If `None`, the default buffer class for this store will be retrieved via the + ``_get_default_buffer_class`` method. byte_range : ByteRequest, optional An optional byte range.
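The ZipStore changes above resolve the new `prototype` argument in two steps: `get` substitutes the store's default buffer class when `prototype` is `None`, and `_get` unwraps a `BufferPrototype` to its `buffer` attribute. For third-party store authors, that resolution collapses into one small helper; the sketch below is illustrative only (the name `_resolve_buffer_class` is not part of this diff):

```python
from zarr.abc.store import BufferClassLike, Store
from zarr.core.buffer import Buffer, BufferPrototype


def _resolve_buffer_class(store: Store, prototype: BufferClassLike | None) -> type[Buffer]:
    # None: defer to the store's default buffer class.
    if prototype is None:
        return store._get_default_buffer_class()
    # A BufferPrototype carries its buffer class on the `buffer` attribute.
    if isinstance(prototype, BufferPrototype):
        return prototype.buffer
    # Otherwise the value is already a Buffer subclass.
    return prototype
```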
diff --git a/tests/test_api/test_asynchronous.py b/tests/test_api/test_asynchronous.py index 362195e858..e8035493ae 100644 --- a/tests/test_api/test_asynchronous.py +++ b/tests/test_api/test_asynchronous.py @@ -9,7 +9,7 @@ from zarr import create_array from zarr.api.asynchronous import _get_shape_chunks, _like_args, group, open -from zarr.core.buffer.core import default_buffer_prototype +from zarr.buffer import cpu from zarr.core.group import AsyncGroup if TYPE_CHECKING: @@ -101,7 +101,7 @@ async def test_open_no_array() -> None: This behavior makes no sense but we should still test it. """ store = { - "zarr.json": default_buffer_prototype().buffer.from_bytes( + "zarr.json": cpu.Buffer.from_bytes( json.dumps({"zarr_format": 3, "node_type": "group"}).encode("utf-8") ) } diff --git a/tests/test_array.py b/tests/test_array.py index 67be294827..7b17e63610 100644 --- a/tests/test_array.py +++ b/tests/test_array.py @@ -941,7 +941,7 @@ def test_write_empty_chunks_negative_zero( async def test_special_complex_fill_values_roundtrip(fill_value: Any, expected: list[Any]) -> None: store = MemoryStore() zarr.create_array(store=store, shape=(1,), dtype=np.complex64, fill_value=fill_value) - content = await store.get("zarr.json", prototype=default_buffer_prototype()) + content = await store.get("zarr.json") assert content is not None actual = json.loads(content.to_bytes()) assert actual["fill_value"] == expected diff --git a/tests/test_codecs/test_blosc.py b/tests/test_codecs/test_blosc.py index 6f4821f8b1..639e5c2ca5 100644 --- a/tests/test_codecs/test_blosc.py +++ b/tests/test_codecs/test_blosc.py @@ -28,7 +28,7 @@ async def test_blosc_evolve(dtype: str) -> None: fill_value=0, compressors=BloscCodec(), ) - buf = await store.get(f"{path}/zarr.json", prototype=default_buffer_prototype()) + buf = await store.get(f"{path}/zarr.json") assert buf is not None zarr_json = json.loads(buf.to_bytes()) blosc_configuration_json = zarr_json["codecs"][1]["configuration"] @@ -49,7 +49,7 @@ async def test_blosc_evolve(dtype: str) -> None: fill_value=0, compressors=BloscCodec(), ) - buf = await store.get(f"{path2}/zarr.json", prototype=default_buffer_prototype()) + buf = await store.get(f"{path2}/zarr.json") assert buf is not None zarr_json = json.loads(buf.to_bytes()) blosc_configuration_json = zarr_json["codecs"][0]["configuration"]["codecs"][1]["configuration"] @@ -99,7 +99,7 @@ async def test_typesize() -> None: a = np.arange(1000000, dtype=np.uint64) codecs = [zarr.codecs.BytesCodec(), zarr.codecs.BloscCodec()] z = zarr.array(a, chunks=(10000), codecs=codecs) - data = await z.store.get("c/0", prototype=default_buffer_prototype()) + data = await z.store.get("c/0") assert data is not None bytes = data.to_bytes() size = len(bytes) diff --git a/tests/test_codecs/test_codecs.py b/tests/test_codecs/test_codecs.py index eae7168d49..3994a8a55e 100644 --- a/tests/test_codecs/test_codecs.py +++ b/tests/test_codecs/test_codecs.py @@ -17,7 +17,6 @@ ShardingCodec, TransposeCodec, ) -from zarr.core.buffer import default_buffer_prototype from zarr.core.indexing import BasicSelection, morton_order_iter from zarr.core.metadata.v3 import ArrayV3Metadata from zarr.dtype import UInt8 @@ -253,7 +252,7 @@ async def test_delete_empty_chunks(store: Store) -> None: await _AsyncArrayProxy(a)[:16, :16].set(np.zeros((16, 16))) await _AsyncArrayProxy(a)[:16, :16].set(data) assert np.array_equal(await _AsyncArrayProxy(a)[:16, :16].get(), data) - assert await store.get(f"{path}/c0/0", prototype=default_buffer_prototype()) is None + assert await 
store.get(f"{path}/c0/0") is None @pytest.mark.parametrize("store", ["local", "memory"], indirect=["store"]) @@ -289,7 +288,7 @@ async def test_dimension_names(store: Store) -> None: assert isinstance(meta := (await AsyncArray.open(spath2)).metadata, ArrayV3Metadata) assert meta.dimension_names is None - zarr_json_buffer = await store.get(f"{path2}/zarr.json", prototype=default_buffer_prototype()) + zarr_json_buffer = await store.get(f"{path2}/zarr.json") assert zarr_json_buffer is not None assert "dimension_names" not in json.loads(zarr_json_buffer.to_bytes()) @@ -351,15 +350,15 @@ async def test_resize(store: Store) -> None: ) await _AsyncArrayProxy(a)[:16, :18].set(data) - assert await store.get(f"{path}/1.1", prototype=default_buffer_prototype()) is not None - assert await store.get(f"{path}/0.0", prototype=default_buffer_prototype()) is not None - assert await store.get(f"{path}/0.1", prototype=default_buffer_prototype()) is not None - assert await store.get(f"{path}/1.0", prototype=default_buffer_prototype()) is not None + assert await store.get(f"{path}/1.1") is not None + assert await store.get(f"{path}/0.0") is not None + assert await store.get(f"{path}/0.1") is not None + assert await store.get(f"{path}/1.0") is not None await a.resize((10, 12)) assert a.metadata.shape == (10, 12) assert a.shape == (10, 12) - assert await store.get(f"{path}/0.0", prototype=default_buffer_prototype()) is not None - assert await store.get(f"{path}/0.1", prototype=default_buffer_prototype()) is not None - assert await store.get(f"{path}/1.0", prototype=default_buffer_prototype()) is None - assert await store.get(f"{path}/1.1", prototype=default_buffer_prototype()) is None + assert await store.get(f"{path}/0.0") is not None + assert await store.get(f"{path}/0.1") is not None + assert await store.get(f"{path}/1.0") is None + assert await store.get(f"{path}/1.1") is None diff --git a/tests/test_codecs/test_sharding.py b/tests/test_codecs/test_sharding.py index 7eb4deccbf..11fae5fbbc 100644 --- a/tests/test_codecs/test_sharding.py +++ b/tests/test_codecs/test_sharding.py @@ -17,7 +17,7 @@ ShardingCodecIndexLocation, TransposeCodec, ) -from zarr.core.buffer import NDArrayLike, default_buffer_prototype +from zarr.core.buffer import NDArrayLike from zarr.errors import ZarrUserWarning from zarr.storage import StorePath, ZipStore @@ -398,8 +398,8 @@ async def test_delete_empty_shards(store: Store) -> None: data = np.ones((16, 16), dtype="uint16") data[:8, :8] = 0 assert np.array_equal(data, await _AsyncArrayProxy(a)[:, :].get()) - assert await store.get(f"{path}/c/1/0", prototype=default_buffer_prototype()) is None - chunk_bytes = await store.get(f"{path}/c/0/0", prototype=default_buffer_prototype()) + assert await store.get(f"{path}/c/1/0") is None + chunk_bytes = await store.get(f"{path}/c/0/0") assert chunk_bytes is not None assert len(chunk_bytes) == 16 * 2 + 8 * 8 * 2 + 4 diff --git a/tests/test_dtype/test_wrapper.py b/tests/test_dtype/test_wrapper.py index cc365e86d4..7e33920eec 100644 --- a/tests/test_dtype/test_wrapper.py +++ b/tests/test_dtype/test_wrapper.py @@ -80,7 +80,7 @@ class BaseTestZDType: valid_json_v2: ClassVar[tuple[DTypeSpec_V2, ...]] = () invalid_json_v2: ClassVar[tuple[str | dict[str, object] | list[object], ...]] = () - valid_json_v3: ClassVar[tuple[DTypeSpec_V3, ...]] = () + valid_json_v3: ClassVar[tuple[DTypeSpec_V3, ...]] = () # type: ignore[valid-type] invalid_json_v3: ClassVar[tuple[str | dict[str, object], ...]] = () # for testing scalar round-trip serialization, we need a 
tuple of (data type json, scalar json) @@ -120,9 +120,9 @@ def test_from_json_roundtrip_v2(self, valid_json_v2: DTypeSpec_V2) -> None: assert zdtype.to_json(zarr_format=2) == valid_json_v2 @pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning") - def test_from_json_roundtrip_v3(self, valid_json_v3: DTypeSpec_V3) -> None: + def test_from_json_roundtrip_v3(self, valid_json_v3: DTypeSpec_V3) -> None: # type: ignore[valid-type] zdtype = self.test_cls.from_json(valid_json_v3, zarr_format=3) - assert zdtype.to_json(zarr_format=3) == valid_json_v3 + assert zdtype.to_json(zarr_format=3) == valid_json_v3 # type: ignore[operator] def test_scalar_roundtrip_v2(self, scalar_v2_params: tuple[ZDType[Any, Any], Any]) -> None: zdtype, scalar_json = scalar_v2_params diff --git a/tests/test_experimental/test_cache_store.py b/tests/test_experimental/test_cache_store.py index d4a45f78f1..26a0a9931e 100644 --- a/tests/test_experimental/test_cache_store.py +++ b/tests/test_experimental/test_cache_store.py @@ -8,7 +8,6 @@ import pytest from zarr.abc.store import Store -from zarr.core.buffer.core import default_buffer_prototype from zarr.core.buffer.cpu import Buffer as CPUBuffer from zarr.experimental.cache_store import CacheStore from zarr.storage import MemoryStore @@ -43,7 +42,7 @@ async def test_basic_caching(self, cached_store: CacheStore, source_store: Store assert await cached_store._cache.exists("test_key") # Retrieve and verify caching works - result = await cached_store.get("test_key", default_buffer_prototype()) + result = await cached_store.get("test_key") assert result is not None assert result.to_bytes() == b"test data" @@ -56,7 +55,7 @@ async def test_cache_miss_and_population( await source_store.set("source_key", test_data) # First access should miss cache but populate it - result = await cached_store.get("source_key", default_buffer_prototype()) + result = await cached_store.get("source_key") assert result is not None assert result.to_bytes() == b"source data" @@ -91,7 +90,7 @@ async def test_cache_expiration(self) -> None: # Skip freshness check if method doesn't exist await asyncio.sleep(1.1) # Just verify the data is still accessible - result = await cached_store.get("expire_key", default_buffer_prototype()) + result = await cached_store.get("expire_key") assert result is not None async def test_cache_set_data_false(self, source_store: Store, cache_store: Store) -> None: @@ -170,7 +169,7 @@ async def test_stale_cache_refresh(self) -> None: await source_store.set("refresh_key", new_data) # Access should refresh from source when cache is stale - result = await cached_store.get("refresh_key", default_buffer_prototype()) + result = await cached_store.get("refresh_key") assert result is not None assert result.to_bytes() == b"new data" @@ -204,7 +203,7 @@ async def test_cache_returns_cached_data_for_performance( cached_store.key_insert_times["orphan_key"] = time.monotonic() # Cache should return data for performance (no source verification) - result = await cached_store.get("orphan_key", default_buffer_prototype()) + result = await cached_store.get("orphan_key") assert result is not None assert result.to_bytes() == b"orphaned data" @@ -230,7 +229,7 @@ async def test_cache_coherency_through_expiration(self) -> None: await source_store.delete("coherency_key") # Cache should still return cached data (performance optimization) - result = await cached_store.get("coherency_key", default_buffer_prototype()) + result = await cached_store.get("coherency_key") assert result 
is not None assert result.to_bytes() == b"original data" @@ -238,7 +237,7 @@ async def test_cache_coherency_through_expiration(self) -> None: await asyncio.sleep(1.1) # Now stale cache should be refreshed from source - result = await cached_store.get("coherency_key", default_buffer_prototype()) + result = await cached_store.get("coherency_key") assert result is None # Key no longer exists in source async def test_cache_info(self, cached_store: CacheStore) -> None: @@ -447,7 +446,7 @@ async def test_get_nonexistent_key(self) -> None: cached_store = CacheStore(source_store, cache_store=cache_store) # Try to get nonexistent key - result = await cached_store.get("nonexistent", default_buffer_prototype()) + result = await cached_store.get("nonexistent") assert result is None # Should not create any cache entries @@ -544,7 +543,7 @@ async def test_get_no_cache_delete_tracking(self) -> None: assert "phantom_key" in cached_store.key_insert_times # Now try to get it - since it's not in source, should clean up tracking - result = await cached_store._get_no_cache("phantom_key", default_buffer_prototype()) + result = await cached_store._get_no_cache("phantom_key") assert result is None # Should have cleaned up tracking @@ -626,7 +625,7 @@ async def test_concurrent_get_and_evict(self) -> None: # Concurrent: read key1 while adding key3 (triggers eviction) async def read_key() -> None: for _ in range(100): - await cached_store.get("key1", default_buffer_prototype()) + await cached_store.get("key1") async def write_key() -> None: for i in range(10): @@ -802,11 +801,11 @@ async def test_cache_stats_method(self) -> None: await source_store.set("key1", buffer) # First get is a miss (not in cache yet) - result1 = await cached_store.get("key1", default_buffer_prototype()) + result1 = await cached_store.get("key1") assert result1 is not None # Second get is a hit (now in cache) - result2 = await cached_store.get("key1", default_buffer_prototype()) + result2 = await cached_store.get("key1") assert result2 is not None stats = cached_store.cache_stats() diff --git a/tests/test_group.py b/tests/test_group.py index 6f1f4e68fa..471aa91336 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -20,9 +20,9 @@ import zarr.storage from zarr import Array, AsyncArray, AsyncGroup, Group from zarr.abc.store import Store +from zarr.buffer import cpu from zarr.core import sync_group from zarr.core._info import GroupInfo -from zarr.core.buffer import default_buffer_prototype from zarr.core.config import config as zarr_config from zarr.core.dtype.common import unpack_dtype_json from zarr.core.dtype.npy.int import UInt8 @@ -196,7 +196,7 @@ def test_group_members(store: Store, zarr_format: ZarrFormat, consolidated_metad sync( store.set( f"{path}/extra_object-1", - default_buffer_prototype().buffer.from_bytes(b"000000"), + cpu.Buffer.from_bytes(b"000000"), ) ) # add an extra object under a directory-like prefix in the domain of the group. 
@@ -205,7 +205,7 @@ def test_group_members(store: Store, zarr_format: ZarrFormat, consolidated_metad sync( store.set( f"{path}/extra_directory/extra_object-2", - default_buffer_prototype().buffer.from_bytes(b"000000"), + cpu.Buffer.from_bytes(b"000000"), ) ) @@ -1481,12 +1481,10 @@ def test_open_mutable_mapping_sync(): async def test_open_ambiguous_node(): - zarr_json_bytes = default_buffer_prototype().buffer.from_bytes( + zarr_json_bytes = cpu.Buffer.from_bytes( json.dumps({"zarr_format": 3, "node_type": "group"}).encode("utf-8") ) - zgroup_bytes = default_buffer_prototype().buffer.from_bytes( - json.dumps({"zarr_format": 2}).encode("utf-8") - ) + zgroup_bytes = cpu.Buffer.from_bytes(json.dumps({"zarr_format": 2}).encode("utf-8")) store: dict[str, Buffer] = {"zarr.json": zarr_json_bytes, ".zgroup": zgroup_bytes} with pytest.warns( ZarrUserWarning, diff --git a/tests/test_indexing.py b/tests/test_indexing.py index c0bf7dd270..631f364e32 100644 --- a/tests/test_indexing.py +++ b/tests/test_indexing.py @@ -12,7 +12,6 @@ import zarr from zarr import Array -from zarr.core.buffer import default_buffer_prototype from zarr.core.indexing import ( BasicSelection, CoordinateSelection, @@ -155,7 +154,7 @@ def test_get_basic_selection_0d(store: StorePath, use_out: bool, value: Any, dty if use_out: # test out param - b = default_buffer_prototype().nd_buffer.from_numpy_array(np.zeros_like(arr_np)) + b = get_ndbuffer_class().from_numpy_array(np.zeros_like(arr_np)) arr_z.get_basic_selection(Ellipsis, out=b) assert_array_equal(arr_np, b.as_ndarray_like()) @@ -268,9 +267,7 @@ def _test_get_basic_selection( assert_array_equal(expect, actual) # test out param - b = default_buffer_prototype().nd_buffer.from_numpy_array( - np.empty(shape=expect.shape, dtype=expect.dtype) - ) + b = get_ndbuffer_class().from_numpy_array(np.empty(shape=expect.shape, dtype=expect.dtype)) z.get_basic_selection(selection, out=b) assert_array_equal(expect, b.as_numpy_array()) diff --git a/tests/test_metadata/test_consolidated.py b/tests/test_metadata/test_consolidated.py index 9e8b763ef7..b375314521 100644 --- a/tests/test_metadata/test_consolidated.py +++ b/tests/test_metadata/test_consolidated.py @@ -17,7 +17,7 @@ open, open_consolidated, ) -from zarr.core.buffer import cpu, default_buffer_prototype +from zarr.core.buffer import cpu from zarr.core.dtype import parse_dtype from zarr.core.group import ConsolidatedMetadata, GroupMetadata from zarr.core.metadata import ArrayV3Metadata @@ -192,9 +192,7 @@ async def test_consolidated(self, memory_store_with_hierarchy: Store) -> None: group4 = await open_consolidated(store=memory_store_with_hierarchy) assert group4.metadata == expected - buf = await memory_store_with_hierarchy.get( - "zarr.json", prototype=default_buffer_prototype() - ) + buf = await memory_store_with_hierarchy.get("zarr.json") assert buf is not None result_raw = json.loads(buf.to_bytes())["consolidated_metadata"] @@ -724,7 +722,7 @@ async def test_consolidated_metadata_encodes_special_chars( await zarr.api.asynchronous.consolidate_metadata(memory_store) root = await group(store=memory_store, zarr_format=zarr_format) - root_buffer = root.metadata.to_buffer_dict(default_buffer_prototype()) + root_buffer = root.metadata.to_buffer_dict(cpu.Buffer) if zarr_format == 2: root_metadata = json.loads(root_buffer[".zmetadata"].to_bytes().decode("utf-8"))["metadata"] diff --git a/tests/test_metadata/test_v2.py b/tests/test_metadata/test_v2.py index 8c3082e924..69fb8298ef 100644 --- a/tests/test_metadata/test_v2.py +++ 
b/tests/test_metadata/test_v2.py @@ -9,7 +9,6 @@ import zarr.api.asynchronous import zarr.storage from zarr.core.buffer import cpu -from zarr.core.buffer.core import default_buffer_prototype from zarr.core.dtype.npy.float import Float32, Float64 from zarr.core.dtype.npy.int import Int16 from zarr.core.group import ConsolidatedMetadata, GroupMetadata @@ -318,9 +317,7 @@ def test_zstd_checksum() -> None: compressors=compressor_config, zarr_format=2, ) - metadata = json.loads( - arr.metadata.to_buffer_dict(default_buffer_prototype())[".zarray"].to_bytes() - ) + metadata = json.loads(arr.metadata.to_buffer_dict()[".zarray"].to_bytes()) assert "checksum" not in metadata["compressor"] diff --git a/tests/test_metadata/test_v3.py b/tests/test_metadata/test_v3.py index 01ed921053..f0fd34accb 100644 --- a/tests/test_metadata/test_v3.py +++ b/tests/test_metadata/test_v3.py @@ -9,7 +9,6 @@ from zarr import consolidate_metadata, create_group from zarr.codecs.bytes import BytesCodec -from zarr.core.buffer import default_buffer_prototype from zarr.core.chunk_key_encodings import DefaultChunkKeyEncoding, V2ChunkKeyEncoding from zarr.core.config import config from zarr.core.dtype import UInt8, get_data_type_from_native_dtype @@ -261,7 +260,7 @@ def test_metadata_to_dict( def test_json_indent(indent: int) -> None: with config.set({"json_indent": indent}): m = GroupMetadata() - d = m.to_buffer_dict(default_buffer_prototype())["zarr.json"].to_bytes() + d = m.to_buffer_dict()["zarr.json"].to_bytes() assert d == json.dumps(json.loads(d), indent=indent).encode() @@ -283,7 +282,7 @@ async def test_datetime_metadata(fill_value: int, precision: Literal["ns", "D"]) } metadata = ArrayV3Metadata.from_dict(metadata_dict) # ensure there isn't a TypeError here. - d = metadata.to_buffer_dict(default_buffer_prototype()) + d = metadata.to_buffer_dict() result = json.loads(d["zarr.json"].to_bytes()) assert result["fill_value"] == fill_value @@ -321,7 +320,7 @@ async def test_special_float_fill_values(fill_value: str) -> None: "fill_value": fill_value, # this is not a valid fill value for uint8 } m = ArrayV3Metadata.from_dict(metadata_dict) - d = json.loads(m.to_buffer_dict(default_buffer_prototype())["zarr.json"].to_bytes()) + d = json.loads(m.to_buffer_dict()["zarr.json"].to_bytes()) assert m.fill_value is not None if fill_value == "NaN": assert np.isnan(m.fill_value) diff --git a/tests/test_properties.py b/tests/test_properties.py index 705cfd1b59..ecd9520743 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -6,8 +6,6 @@ import pytest from numpy.testing import assert_array_equal -from zarr.core.buffer import default_buffer_prototype - pytest.importorskip("hypothesis") import hypothesis.extra.numpy as npst @@ -91,15 +89,9 @@ def test_array_creates_implicit_groups(array): for i in range(len(ancestry)): parent = "/".join(ancestry[: i + 1]) if array.metadata.zarr_format == 2: - assert ( - sync(array.store.get(f"{parent}/.zgroup", prototype=default_buffer_prototype())) - is not None - ) + assert sync(array.store.get(f"{parent}/.zgroup")) is not None elif array.metadata.zarr_format == 3: - assert ( - sync(array.store.get(f"{parent}/zarr.json", prototype=default_buffer_prototype())) - is not None - ) + assert sync(array.store.get(f"{parent}/zarr.json")) is not None # this decorator removes timeout; not ideal but it should avoid intermittent CI failures @@ -211,10 +203,10 @@ async def test_roundtrip_array_metadata_from_store( prefixed with "0/", and then reads them back. 
The test asserts that each retrieved buffer exactly matches the original buffer. """ - asdict = meta.to_buffer_dict(prototype=default_buffer_prototype()) + asdict = meta.to_buffer_dict() for key, expected in asdict.items(): await store.set(f"0/{key}", expected) - actual = await store.get(f"0/{key}", prototype=default_buffer_prototype()) + actual = await store.get(f"0/{key}") assert actual == expected @@ -236,7 +228,7 @@ def test_roundtrip_array_metadata_from_json(data: st.DataObject, zarr_format: in cases like NaN, Infinity, complex numbers, and datetime values). """ metadata = data.draw(array_metadata(zarr_formats=st.just(zarr_format))) - buffer_dict = metadata.to_buffer_dict(prototype=default_buffer_prototype()) + buffer_dict = metadata.to_buffer_dict() if zarr_format == 2: zarray_dict = json.loads(buffer_dict[ZARRAY_JSON].to_bytes().decode()) diff --git a/tests/test_store/test_fsspec.py b/tests/test_store/test_fsspec.py index a2c07b7ed1..ed32815459 100644 --- a/tests/test_store/test_fsspec.py +++ b/tests/test_store/test_fsspec.py @@ -12,7 +12,7 @@ import zarr.api.asynchronous from zarr import Array from zarr.abc.store import OffsetByteRequest -from zarr.core.buffer import Buffer, cpu, default_buffer_prototype +from zarr.core.buffer import Buffer, cpu from zarr.core.sync import _collect_aiterator, sync from zarr.errors import ZarrUserWarning from zarr.storage import FsspecStore @@ -119,11 +119,11 @@ async def test_basic() -> None: data = b"hello" await store.set("foo", cpu.Buffer.from_bytes(data)) assert await store.exists("foo") - buf = await store.get("foo", prototype=default_buffer_prototype()) + buf = await store.get("foo") assert buf is not None assert buf.to_bytes() == data out = await store.get_partial_values( - prototype=default_buffer_prototype(), key_ranges=[("foo", OffsetByteRequest(1))] + key_ranges=[("foo", OffsetByteRequest(1))], prototype=cpu.Buffer ) assert out[0] is not None assert out[0].to_bytes() == data[1:] diff --git a/tests/test_store/test_local.py b/tests/test_store/test_local.py index bdc9b48121..c9dda545d7 100644 --- a/tests/test_store/test_local.py +++ b/tests/test_store/test_local.py @@ -130,7 +130,7 @@ async def test_get_bytes_with_prototype_none( def test_get_bytes_sync_with_prototype_none( self, store: LocalStore, buffer_cls: None | BufferPrototype ) -> None: - """Test that get_bytes_sync works with prototype=None.""" + """Test that _get_bytes_sync works with prototype=None.""" data = b"hello world" key = "test_key" sync(self.set(store, key, self.buffer_cls.from_bytes(data))) @@ -142,7 +142,7 @@ def test_get_bytes_sync_with_prototype_none( async def test_get_json_with_prototype_none( self, store: LocalStore, buffer_cls: None | BufferPrototype ) -> None: - """Test that get_json works with prototype=None.""" + """Test that _get_json works with prototype=None.""" data = {"foo": "bar", "number": 42} key = "test.json" await self.set(store, key, self.buffer_cls.from_bytes(json.dumps(data).encode())) @@ -154,7 +154,7 @@ async def test_get_json_with_prototype_none( def test_get_json_sync_with_prototype_none( self, store: LocalStore, buffer_cls: None | BufferPrototype ) -> None: - """Test that get_json_sync works with prototype=None.""" + """Test that _get_json_sync works with prototype=None.""" data = {"foo": "bar", "number": 42} key = "test.json" sync(self.set(store, key, self.buffer_cls.from_bytes(json.dumps(data).encode()))) diff --git a/tests/test_store/test_logging.py b/tests/test_store/test_logging.py index fa566e45aa..46592f28c5 100644 --- 
a/tests/test_store/test_logging.py +++ b/tests/test_store/test_logging.py @@ -6,7 +6,7 @@ import pytest import zarr -from zarr.core.buffer import Buffer, cpu, default_buffer_prototype +from zarr.buffer import cpu from zarr.storage import LocalStore, LoggingStore from zarr.testing.store import StoreTests @@ -14,6 +14,7 @@ from pathlib import Path from zarr.abc.store import Store + from zarr.core.buffer.core import Buffer class StoreKwargs(TypedDict): @@ -68,7 +69,7 @@ async def test_default_handler( logging.getLogger().removeHandler(h) # Test logs are sent to stdout wrapped = LoggingStore(store=local_store) - buffer = default_buffer_prototype().buffer + buffer = cpu.Buffer res = await wrapped.set("foo/bar/c/0", buffer.from_bytes(b"\x01\x02\x03\x04")) # type: ignore[func-returns-value] assert res is None captured = capsys.readouterr() @@ -90,7 +91,7 @@ def test_is_open_setter_raises(self, store: LoggingStore[LocalStore]) -> None: @pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"]) async def test_logging_store(store: Store, caplog: pytest.LogCaptureFixture) -> None: wrapped = LoggingStore(store=store, log_level="DEBUG") - buffer = default_buffer_prototype().buffer + buffer = cpu.Buffer caplog.clear() res = await wrapped.set("foo/bar/c/0", buffer.from_bytes(b"\x01\x02\x03\x04")) # type: ignore[func-returns-value] diff --git a/tests/test_store/test_memory.py b/tests/test_store/test_memory.py index 03c8b24271..c424d864ee 100644 --- a/tests/test_store/test_memory.py +++ b/tests/test_store/test_memory.py @@ -95,7 +95,7 @@ async def test_get_bytes_with_prototype_none( def test_get_bytes_sync_with_prototype_none( self, store: MemoryStore, buffer_cls: None | BufferPrototype ) -> None: - """Test that get_bytes_sync works with prototype=None.""" + """Test that _get_bytes_sync works with prototype=None.""" data = b"hello world" key = "test_key" sync(self.set(store, key, self.buffer_cls.from_bytes(data))) @@ -107,7 +107,7 @@ def test_get_bytes_sync_with_prototype_none( async def test_get_json_with_prototype_none( self, store: MemoryStore, buffer_cls: None | BufferPrototype ) -> None: - """Test that get_json works with prototype=None.""" + """Test that _get_json works with prototype=None.""" data = {"foo": "bar", "number": 42} key = "test.json" await self.set(store, key, self.buffer_cls.from_bytes(json.dumps(data).encode())) @@ -119,7 +119,7 @@ async def test_get_json_with_prototype_none( def test_get_json_sync_with_prototype_none( self, store: MemoryStore, buffer_cls: None | BufferPrototype ) -> None: - """Test that get_json_sync works with prototype=None.""" + """Test that _get_json_sync works with prototype=None.""" data = {"foo": "bar", "number": 42} key = "test.json" sync(self.set(store, key, self.buffer_cls.from_bytes(json.dumps(data).encode()))) diff --git a/tests/test_store/test_wrapper.py b/tests/test_store/test_wrapper.py index b34a63d5d0..c6ccf3cbb9 100644 --- a/tests/test_store/test_wrapper.py +++ b/tests/test_store/test_wrapper.py @@ -4,7 +4,7 @@ import pytest -from zarr.abc.store import ByteRequest, Store +from zarr.abc.store import BufferClassLike, ByteRequest, Store from zarr.core.buffer import Buffer from zarr.core.buffer.cpu import Buffer as CPUBuffer from zarr.core.buffer.cpu import buffer_prototype @@ -14,8 +14,6 @@ if TYPE_CHECKING: from pathlib import Path - from zarr.core.buffer.core import BufferPrototype - class StoreKwargs(TypedDict): store: LocalStore @@ -111,10 +109,13 @@ async def test_wrapped_get(store: Store, capsys: pytest.CaptureFixture[str]) 
-> # define a class that prints when it sets class NoisyGetter(WrapperStore[Any]): async def get( - self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None - ) -> None: + self, + key: str, + prototype: BufferClassLike | None = None, + byte_range: ByteRequest | None = None, + ) -> Buffer | None: print(f"getting {key}") - await super().get(key, prototype=prototype, byte_range=byte_range) + return await super().get(key, prototype=prototype, byte_range=byte_range) key = "foo" value = CPUBuffer.from_bytes(b"bar") diff --git a/tests/test_store/test_zip.py b/tests/test_store/test_zip.py index 744ee82945..a955f5fd22 100644 --- a/tests/test_store/test_zip.py +++ b/tests/test_store/test_zip.py @@ -11,7 +11,7 @@ import zarr from zarr import create_array -from zarr.core.buffer import Buffer, cpu, default_buffer_prototype +from zarr.core.buffer import Buffer, cpu from zarr.core.group import Group from zarr.storage import ZipStore from zarr.testing.store import StoreTests @@ -42,7 +42,7 @@ def store_kwargs(self) -> dict[str, str | bool]: return {"path": temp_path, "mode": "w", "read_only": False} async def get(self, store: ZipStore, key: str) -> Buffer: - buf = store._get(key, prototype=default_buffer_prototype()) + buf = store._get(key, prototype=cpu.Buffer) assert buf is not None return buf diff --git a/tests/test_v2.py b/tests/test_v2.py index cb990f6159..f06f5ab89e 100644 --- a/tests/test_v2.py +++ b/tests/test_v2.py @@ -13,7 +13,6 @@ import zarr.storage from zarr import config from zarr.abc.store import Store -from zarr.core.buffer.core import default_buffer_prototype from zarr.core.dtype import FixedLengthUTF32, Structured, VariableLengthUTF8 from zarr.core.dtype.npy.bytes import NullTerminatedBytes from zarr.core.dtype.wrapper import ZDType @@ -78,7 +77,7 @@ async def test_v2_encode_decode( name="foo", shape=(3,), chunks=(3,), dtype=dtype, fill_value=fill_value, compressor=None ) - result = await store.get("foo/.zarray", zarr.core.buffer.default_buffer_prototype()) + result = await store.get("foo/.zarray") assert result is not None serialized = json.loads(result.to_bytes()) @@ -183,7 +182,7 @@ def test_v2_non_contiguous(numpy_order: Literal["C", "F"], zarr_order: Literal[" arr[6:9, 3:6] = a[6:9, 3:6] # The slice on the RHS is important np.testing.assert_array_equal(arr[6:9, 3:6], a[6:9, 3:6]) - buf = sync(store.get("2.1", default_buffer_prototype())) + buf = sync(store.get("2.1")) assert buf is not None np.testing.assert_array_equal( a[6:9, 3:6],
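The test updates above all exercise the same three interchangeable spellings of `prototype`. As a closing illustration, here is a small self-contained usage sketch (assuming this branch of zarr, since `prototype=None` is accepted by `Store.get` only after this change):

```python
import asyncio

from zarr.core.buffer import cpu, default_buffer_prototype
from zarr.storage import MemoryStore


async def main() -> None:
    store = MemoryStore()
    await store.set("key", cpu.Buffer.from_bytes(b"\x01\x02\x03\x04"))

    # prototype omitted: the store supplies its default buffer class
    buf = await store.get("key")
    assert buf is not None and buf.to_bytes() == b"\x01\x02\x03\x04"

    # a BufferPrototype instance: its `buffer` attribute is used
    assert await store.get("key", prototype=default_buffer_prototype()) is not None

    # a raw Buffer class
    assert await store.get("key", prototype=cpu.Buffer) is not None


asyncio.run(main())
```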