From 021bd440c3ee2fae13aaa255d722bb582d22b6a0 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Sun, 4 Jan 2026 17:29:56 +0100 Subject: [PATCH 01/13] add store routines for getting bytes and json --- src/zarr/abc/store.py | 215 ++++++++++++++++++++++++++++- src/zarr/storage/_common.py | 107 +++++++++++++++ src/zarr/storage/_local.py | 232 +++++++++++++++++++++++++++++++- src/zarr/storage/_memory.py | 232 +++++++++++++++++++++++++++++++- src/zarr/testing/store.py | 41 +++++- tests/test_store/test_local.py | 82 +++++++++++ tests/test_store/test_memory.py | 70 ++++++++++ 7 files changed, 974 insertions(+), 5 deletions(-) diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index 4b3edf78d1..7d0589c836 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -1,11 +1,14 @@ from __future__ import annotations +import asyncio +import json from abc import ABC, abstractmethod -from asyncio import gather from dataclasses import dataclass from itertools import starmap from typing import TYPE_CHECKING, Literal, Protocol, runtime_checkable +from zarr.core.sync import sync + if TYPE_CHECKING: from collections.abc import AsyncGenerator, AsyncIterator, Iterable from types import TracebackType @@ -206,6 +209,214 @@ async def get( """ ... + async def get_bytes_async( + self, key: str, *, prototype: BufferPrototype, byte_range: ByteRequest | None = None + ) -> bytes: + """ + Retrieve raw bytes from the store asynchronously. + + This is a convenience method that wraps ``get()`` and converts the result + to bytes. Use this when you need the raw byte content of a stored value. + + Parameters + ---------- + key : str + The key identifying the data to retrieve. + prototype : BufferPrototype + The buffer prototype to use for reading the data. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + Can be a ``RangeByteRequest``, ``OffsetByteRequest``, or ``SuffixByteRequest``. + + Returns + ------- + bytes + The raw bytes stored at the given key. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + + See Also + -------- + get : Lower-level method that returns a Buffer object. + get_bytes : Synchronous version of this method. + get_json_async : Asynchronous method for retrieving and parsing JSON data. + + Examples + -------- + >>> store = await MemoryStore.open() + >>> await store.set("data", Buffer.from_bytes(b"hello world")) + >>> data = await store.get_bytes_async("data", prototype=default_buffer_prototype()) + >>> print(data) + b'hello world' + """ + buffer = await self.get(key, prototype, byte_range) + if buffer is None: + raise FileNotFoundError(key) + return buffer.to_bytes() + + def get_bytes( + self, key: str = "", *, prototype: BufferPrototype, byte_range: ByteRequest | None = None + ) -> bytes: + """ + Retrieve raw bytes from the store synchronously. + + This is a synchronous wrapper around ``get_bytes_async()``. It should only + be called from non-async code. For async contexts, use ``get_bytes_async()`` + instead. + + Parameters + ---------- + key : str, optional + The key identifying the data to retrieve. Defaults to an empty string. + prototype : BufferPrototype + The buffer prototype to use for reading the data. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + Can be a ``RangeByteRequest``, ``OffsetByteRequest``, or ``SuffixByteRequest``. + + Returns + ------- + bytes + The raw bytes stored at the given key. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + + Warnings + -------- + Do not call this method from async functions. Use ``get_bytes_async()`` instead + to avoid blocking the event loop. + + See Also + -------- + get_bytes_async : Asynchronous version of this method. + get_json : Synchronous method for retrieving and parsing JSON data. + + Examples + -------- + >>> store = MemoryStore() + >>> store.set("data", Buffer.from_bytes(b"hello world")) + >>> data = store.get_bytes("data", prototype=default_buffer_prototype()) + >>> print(data) + b'hello world' + """ + + return sync(self.get_bytes_async(key, prototype=prototype, byte_range=byte_range)) + + async def get_json_async( + self, key: str, *, prototype: BufferPrototype, byte_range: ByteRequest | None = None + ) -> Any: + """ + Retrieve and parse JSON data from the store asynchronously. + + This is a convenience method that retrieves bytes from the store and + parses them as JSON. Commonly used for reading Zarr metadata files + like ``zarr.json``. + + Parameters + ---------- + key : str + The key identifying the JSON data to retrieve. + prototype : BufferPrototype + The buffer prototype to use for reading the data. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + Can be a ``RangeByteRequest``, ``OffsetByteRequest``, or ``SuffixByteRequest``. + Note: Using byte ranges with JSON may result in invalid JSON. + + Returns + ------- + Any + The parsed JSON data. This follows the behavior of ``json.loads()`` and + can be any JSON-serializable type: dict, list, str, int, float, bool, or None. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + json.JSONDecodeError + If the stored data is not valid JSON. + + See Also + -------- + get_bytes_async : Method for retrieving raw bytes without parsing. + get_json : Synchronous version of this method. + + Examples + -------- + >>> store = await MemoryStore.open() + >>> metadata = {"zarr_format": 3, "node_type": "array"} + >>> await store.set("zarr.json", Buffer.from_bytes(json.dumps(metadata).encode())) + >>> data = await store.get_json_async("zarr.json", prototype=default_buffer_prototype()) + >>> print(data) + {'zarr_format': 3, 'node_type': 'array'} + """ + + return json.loads( + await self.get_bytes_async(key, prototype=prototype, byte_range=byte_range) + ) + + def get_json( + self, key: str = "", *, prototype: BufferPrototype, byte_range: ByteRequest | None = None + ) -> Any: + """ + Retrieve and parse JSON data from the store synchronously. + + This is a synchronous wrapper around ``get_json_async()``. It should only + be called from non-async code. For async contexts, use ``get_json_async()`` + instead. + + Parameters + ---------- + key : str, optional + The key identifying the JSON data to retrieve. Defaults to an empty string. + prototype : BufferPrototype + The buffer prototype to use for reading the data. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + Can be a ``RangeByteRequest``, ``OffsetByteRequest``, or ``SuffixByteRequest``. + Note: Using byte ranges with JSON may result in invalid JSON. + + Returns + ------- + Any + The parsed JSON data. This follows the behavior of ``json.loads()`` and + can be any JSON-serializable type: dict, list, str, int, float, bool, or None. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + json.JSONDecodeError + If the stored data is not valid JSON. + + Warnings + -------- + Do not call this method from async functions. Use ``get_json_async()`` instead + to avoid blocking the event loop. + + See Also + -------- + get_json_async : Asynchronous version of this method. + get_bytes : Synchronous method for retrieving raw bytes without parsing. + + Examples + -------- + >>> store = MemoryStore() + >>> metadata = {"zarr_format": 3, "node_type": "array"} + >>> store.set("zarr.json", Buffer.from_bytes(json.dumps(metadata).encode())) + >>> data = store.get_json("zarr.json", prototype=default_buffer_prototype()) + >>> print(data) + {'zarr_format': 3, 'node_type': 'array'} + """ + + return sync(self.get_json_async(key, prototype=prototype, byte_range=byte_range)) + @abstractmethod async def get_partial_values( self, @@ -278,7 +489,7 @@ async def _set_many(self, values: Iterable[tuple[str, Buffer]]) -> None: """ Insert multiple (key, value) pairs into storage. """ - await gather(*starmap(self.set, values)) + await asyncio.gather(*starmap(self.set, values)) @property def supports_consolidated_metadata(self) -> bool: diff --git a/src/zarr/storage/_common.py b/src/zarr/storage/_common.py index d762097cc3..b8dcaff3be 100644 --- a/src/zarr/storage/_common.py +++ b/src/zarr/storage/_common.py @@ -228,6 +228,113 @@ async def is_empty(self) -> bool: """ return await self.store.is_empty(self.path) + async def get_bytes_async( + self, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> bytes: + """ + Retrieve raw bytes from the store path asynchronously. + + This is a convenience method that wraps ``get()`` and converts the result + to bytes. The ``prototype`` parameter is optional and defaults to the + standard buffer prototype. + + Parameters + ---------- + prototype : BufferPrototype, optional + The buffer prototype to use for reading the data. If None, uses + ``default_buffer_prototype()``. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + Can be a ``RangeByteRequest``, ``OffsetByteRequest``, or ``SuffixByteRequest``. + + Returns + ------- + bytes + The raw bytes stored at this path. + + Raises + ------ + FileNotFoundError + If the path does not exist in the store. + + See Also + -------- + get : Lower-level method that returns a Buffer object. + get_json_async : Asynchronous method for retrieving and parsing JSON data. + + Examples + -------- + >>> store = await MemoryStore.open() + >>> path = StorePath(store, "data") + >>> await path.set(Buffer.from_bytes(b"hello world")) + >>> data = await path.get_bytes_async() + >>> print(data) + b'hello world' + """ + if prototype is None: + prototype = default_buffer_prototype() + return await self.store.get_bytes_async( + self.path, prototype=prototype, byte_range=byte_range + ) + + async def get_json_async( + self, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> Any: + """ + Retrieve and parse JSON data from the store path asynchronously. + + This is a convenience method that retrieves bytes from the store and + parses them as JSON. The ``prototype`` parameter is optional and defaults + to the standard buffer prototype. + + Parameters + ---------- + prototype : BufferPrototype, optional + The buffer prototype to use for reading the data. If None, uses + ``default_buffer_prototype()``. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + Can be a ``RangeByteRequest``, ``OffsetByteRequest``, or ``SuffixByteRequest``. + Note: Using byte ranges with JSON may result in invalid JSON. + + Returns + ------- + Any + The parsed JSON data. This follows the behavior of ``json.loads()`` and + can be any JSON-serializable type: dict, list, str, int, float, bool, or None. + + Raises + ------ + FileNotFoundError + If the path does not exist in the store. + json.JSONDecodeError + If the stored data is not valid JSON. + + See Also + -------- + get_bytes_async : Method for retrieving raw bytes without parsing. + get : Lower-level method that returns a Buffer object. + + Examples + -------- + >>> store = await MemoryStore.open() + >>> path = StorePath(store, "zarr.json") + >>> metadata = {"zarr_format": 3, "node_type": "array"} + >>> await path.set(Buffer.from_bytes(json.dumps(metadata).encode())) + >>> data = await path.get_json_async() + >>> print(data) + {'zarr_format': 3, 'node_type': 'array'} + """ + if prototype is None: + prototype = default_buffer_prototype() + return await self.store.get_json_async( + self.path, prototype=prototype, byte_range=byte_range + ) + def __truediv__(self, other: str) -> StorePath: """Combine this store path with another path""" return self.__class__(self.store, _dereference_path(self.path, other)) diff --git a/src/zarr/storage/_local.py b/src/zarr/storage/_local.py index f64da71bb4..13c86a2f22 100644 --- a/src/zarr/storage/_local.py +++ b/src/zarr/storage/_local.py @@ -8,7 +8,7 @@ import sys import uuid from pathlib import Path -from typing import TYPE_CHECKING, BinaryIO, Literal, Self +from typing import TYPE_CHECKING, Any, BinaryIO, Literal, Self from zarr.abc.store import ( ByteRequest, @@ -306,6 +306,236 @@ async def list_dir(self, prefix: str) -> AsyncIterator[str]: except (FileNotFoundError, NotADirectoryError): pass + async def get_bytes_async( + self, + key: str = "", + *, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> bytes: + """ + Retrieve raw bytes from the local store asynchronously. + + This is a convenience override that makes the ``prototype`` parameter optional + by defaulting to the standard buffer prototype. See the base ``Store.get_bytes_async`` + for full documentation. + + Parameters + ---------- + key : str, optional + The key identifying the data to retrieve. Defaults to an empty string. + prototype : BufferPrototype, optional + The buffer prototype to use for reading the data. If None, uses + ``default_buffer_prototype()``. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + + Returns + ------- + bytes + The raw bytes stored at the given key. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + + See Also + -------- + Store.get_bytes_async : Base implementation with full documentation. + get_bytes : Synchronous version of this method. + + Examples + -------- + >>> store = await LocalStore.open("data") + >>> await store.set("data", Buffer.from_bytes(b"hello")) + >>> # No need to specify prototype for LocalStore + >>> data = await store.get_bytes_async("data") + >>> print(data) + b'hello' + """ + if prototype is None: + prototype = default_buffer_prototype() + return await super().get_bytes_async(key, prototype=prototype, byte_range=byte_range) + + def get_bytes( + self, + key: str = "", + *, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> bytes: + """ + Retrieve raw bytes from the local store synchronously. + + This is a convenience override that makes the ``prototype`` parameter optional + by defaulting to the standard buffer prototype. See the base ``Store.get_bytes`` + for full documentation. + + Parameters + ---------- + key : str, optional + The key identifying the data to retrieve. Defaults to an empty string. + prototype : BufferPrototype, optional + The buffer prototype to use for reading the data. If None, uses + ``default_buffer_prototype()``. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + + Returns + ------- + bytes + The raw bytes stored at the given key. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + + Warnings + -------- + Do not call this method from async functions. Use ``get_bytes_async()`` instead. + + See Also + -------- + Store.get_bytes : Base implementation with full documentation. + get_bytes_async : Asynchronous version of this method. + + Examples + -------- + >>> store = LocalStore("data") + >>> store.set("data", Buffer.from_bytes(b"hello")) + >>> # No need to specify prototype for LocalStore + >>> data = store.get_bytes("data") + >>> print(data) + b'hello' + """ + if prototype is None: + prototype = default_buffer_prototype() + return super().get_bytes(key, prototype=prototype, byte_range=byte_range) + + async def get_json_async( + self, + key: str = "", + *, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> Any: + """ + Retrieve and parse JSON data from the local store asynchronously. + + This is a convenience override that makes the ``prototype`` parameter optional + by defaulting to the standard buffer prototype. See the base ``Store.get_json_async`` + for full documentation. + + Parameters + ---------- + key : str, optional + The key identifying the JSON data to retrieve. Defaults to an empty string. + prototype : BufferPrototype, optional + The buffer prototype to use for reading the data. If None, uses + ``default_buffer_prototype()``. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + Note: Using byte ranges with JSON may result in invalid JSON. + + Returns + ------- + Any + The parsed JSON data. This follows the behavior of ``json.loads()`` and + can be any JSON-serializable type: dict, list, str, int, float, bool, or None. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + json.JSONDecodeError + If the stored data is not valid JSON. + + See Also + -------- + Store.get_json_async : Base implementation with full documentation. + get_json : Synchronous version of this method. + get_bytes_async : Method for retrieving raw bytes without parsing. + + Examples + -------- + >>> store = await LocalStore.open("data") + >>> import json + >>> metadata = {"zarr_format": 3, "node_type": "array"} + >>> await store.set("zarr.json", Buffer.from_bytes(json.dumps(metadata).encode())) + >>> # No need to specify prototype for LocalStore + >>> data = await store.get_json_async("zarr.json") + >>> print(data) + {'zarr_format': 3, 'node_type': 'array'} + """ + if prototype is None: + prototype = default_buffer_prototype() + return await super().get_json_async(key, prototype=prototype, byte_range=byte_range) + + def get_json( + self, + key: str = "", + *, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> Any: + """ + Retrieve and parse JSON data from the local store synchronously. + + This is a convenience override that makes the ``prototype`` parameter optional + by defaulting to the standard buffer prototype. See the base ``Store.get_json`` + for full documentation. + + Parameters + ---------- + key : str, optional + The key identifying the JSON data to retrieve. Defaults to an empty string. + prototype : BufferPrototype, optional + The buffer prototype to use for reading the data. If None, uses + ``default_buffer_prototype()``. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + Note: Using byte ranges with JSON may result in invalid JSON. + + Returns + ------- + Any + The parsed JSON data. This follows the behavior of ``json.loads()`` and + can be any JSON-serializable type: dict, list, str, int, float, bool, or None. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + json.JSONDecodeError + If the stored data is not valid JSON. + + Warnings + -------- + Do not call this method from async functions. Use ``get_json_async()`` instead. + + See Also + -------- + Store.get_json : Base implementation with full documentation. + get_json_async : Asynchronous version of this method. + get_bytes : Method for retrieving raw bytes without parsing. + + Examples + -------- + >>> store = LocalStore("data") + >>> import json + >>> metadata = {"zarr_format": 3, "node_type": "array"} + >>> store.set("zarr.json", Buffer.from_bytes(json.dumps(metadata).encode())) + >>> # No need to specify prototype for LocalStore + >>> data = store.get_json("zarr.json") + >>> print(data) + {'zarr_format': 3, 'node_type': 'array'} + """ + if prototype is None: + prototype = default_buffer_prototype() + return super().get_json(key, prototype=prototype, byte_range=byte_range) + async def move(self, dest_root: Path | str) -> None: """ Move the store to another path. The old root directory is deleted. diff --git a/src/zarr/storage/_memory.py b/src/zarr/storage/_memory.py index 904be922d7..b56771f62a 100644 --- a/src/zarr/storage/_memory.py +++ b/src/zarr/storage/_memory.py @@ -1,7 +1,7 @@ from __future__ import annotations from logging import getLogger -from typing import TYPE_CHECKING, Self +from typing import TYPE_CHECKING, Any, Self from zarr.abc.store import ByteRequest, Store from zarr.core.buffer import Buffer, gpu @@ -175,6 +175,236 @@ async def list_dir(self, prefix: str) -> AsyncIterator[str]: for key in keys_unique: yield key + async def get_bytes_async( + self, + key: str = "", + *, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> bytes: + """ + Retrieve raw bytes from the memory store asynchronously. + + This is a convenience override that makes the ``prototype`` parameter optional + by defaulting to the standard buffer prototype. See the base ``Store.get_bytes_async`` + for full documentation. + + Parameters + ---------- + key : str, optional + The key identifying the data to retrieve. Defaults to an empty string. + prototype : BufferPrototype, optional + The buffer prototype to use for reading the data. If None, uses + ``default_buffer_prototype()``. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + + Returns + ------- + bytes + The raw bytes stored at the given key. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + + See Also + -------- + Store.get_bytes_async : Base implementation with full documentation. + get_bytes : Synchronous version of this method. + + Examples + -------- + >>> store = await MemoryStore.open() + >>> await store.set("data", Buffer.from_bytes(b"hello")) + >>> # No need to specify prototype for MemoryStore + >>> data = await store.get_bytes_async("data") + >>> print(data) + b'hello' + """ + if prototype is None: + prototype = default_buffer_prototype() + return await super().get_bytes_async(key, prototype=prototype, byte_range=byte_range) + + def get_bytes( + self, + key: str = "", + *, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> bytes: + """ + Retrieve raw bytes from the memory store synchronously. + + This is a convenience override that makes the ``prototype`` parameter optional + by defaulting to the standard buffer prototype. See the base ``Store.get_bytes`` + for full documentation. + + Parameters + ---------- + key : str, optional + The key identifying the data to retrieve. Defaults to an empty string. + prototype : BufferPrototype, optional + The buffer prototype to use for reading the data. If None, uses + ``default_buffer_prototype()``. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + + Returns + ------- + bytes + The raw bytes stored at the given key. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + + Warnings + -------- + Do not call this method from async functions. Use ``get_bytes_async()`` instead. + + See Also + -------- + Store.get_bytes : Base implementation with full documentation. + get_bytes_async : Asynchronous version of this method. + + Examples + -------- + >>> store = MemoryStore() + >>> store.set("data", Buffer.from_bytes(b"hello")) + >>> # No need to specify prototype for MemoryStore + >>> data = store.get_bytes("data") + >>> print(data) + b'hello' + """ + if prototype is None: + prototype = default_buffer_prototype() + return super().get_bytes(key, prototype=prototype, byte_range=byte_range) + + async def get_json_async( + self, + key: str = "", + *, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> Any: + """ + Retrieve and parse JSON data from the memory store asynchronously. + + This is a convenience override that makes the ``prototype`` parameter optional + by defaulting to the standard buffer prototype. See the base ``Store.get_json_async`` + for full documentation. + + Parameters + ---------- + key : str, optional + The key identifying the JSON data to retrieve. Defaults to an empty string. + prototype : BufferPrototype, optional + The buffer prototype to use for reading the data. If None, uses + ``default_buffer_prototype()``. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + Note: Using byte ranges with JSON may result in invalid JSON. + + Returns + ------- + Any + The parsed JSON data. This follows the behavior of ``json.loads()`` and + can be any JSON-serializable type: dict, list, str, int, float, bool, or None. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + json.JSONDecodeError + If the stored data is not valid JSON. + + See Also + -------- + Store.get_json_async : Base implementation with full documentation. + get_json : Synchronous version of this method. + get_bytes_async : Method for retrieving raw bytes without parsing. + + Examples + -------- + >>> store = await MemoryStore.open() + >>> import json + >>> metadata = {"zarr_format": 3, "node_type": "array"} + >>> await store.set("zarr.json", Buffer.from_bytes(json.dumps(metadata).encode())) + >>> # No need to specify prototype for MemoryStore + >>> data = await store.get_json_async("zarr.json") + >>> print(data) + {'zarr_format': 3, 'node_type': 'array'} + """ + if prototype is None: + prototype = default_buffer_prototype() + return await super().get_json_async(key, prototype=prototype, byte_range=byte_range) + + def get_json( + self, + key: str = "", + *, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> Any: + """ + Retrieve and parse JSON data from the memory store synchronously. + + This is a convenience override that makes the ``prototype`` parameter optional + by defaulting to the standard buffer prototype. See the base ``Store.get_json`` + for full documentation. + + Parameters + ---------- + key : str, optional + The key identifying the JSON data to retrieve. Defaults to an empty string. + prototype : BufferPrototype, optional + The buffer prototype to use for reading the data. If None, uses + ``default_buffer_prototype()``. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + Note: Using byte ranges with JSON may result in invalid JSON. + + Returns + ------- + Any + The parsed JSON data. This follows the behavior of ``json.loads()`` and + can be any JSON-serializable type: dict, list, str, int, float, bool, or None. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + json.JSONDecodeError + If the stored data is not valid JSON. + + Warnings + -------- + Do not call this method from async functions. Use ``get_json_async()`` instead. + + See Also + -------- + Store.get_json : Base implementation with full documentation. + get_json_async : Asynchronous version of this method. + get_bytes : Method for retrieving raw bytes without parsing. + + Examples + -------- + >>> store = MemoryStore() + >>> import json + >>> metadata = {"zarr_format": 3, "node_type": "array"} + >>> store.set("zarr.json", Buffer.from_bytes(json.dumps(metadata).encode())) + >>> # No need to specify prototype for MemoryStore + >>> data = store.get_json("zarr.json") + >>> print(data) + {'zarr_format': 3, 'node_type': 'array'} + """ + if prototype is None: + prototype = default_buffer_prototype() + return super().get_json(key, prototype=prototype, byte_range=byte_range) + class GpuMemoryStore(MemoryStore): """ diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index ad3b80da41..bee28639a2 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import json import pickle from abc import abstractmethod from typing import TYPE_CHECKING, Generic, TypeVar @@ -23,7 +24,7 @@ SuffixByteRequest, ) from zarr.core.buffer import Buffer, default_buffer_prototype -from zarr.core.sync import _collect_aiterator +from zarr.core.sync import _collect_aiterator, sync from zarr.storage._utils import _normalize_byte_range_index from zarr.testing.utils import assert_bytes_equal @@ -526,6 +527,44 @@ async def test_set_if_not_exists(self, store: S) -> None: result = await store.get("k2", default_buffer_prototype()) assert result == new + async def test_get_bytes_async(self, store: S) -> None: + """ + Test that the get_bytes_async method reads bytes. + """ + data = b"hello world" + key = "zarr.json" + await self.set(store, key, self.buffer_cls.from_bytes(data)) + assert await store.get_bytes_async(key, prototype=default_buffer_prototype()) == data + + def test_get_bytes_sync(self, store: S) -> None: + """ + Test that the get_bytes method reads bytes. + """ + data = b"hello world" + key = "zarr.json" + sync(self.set(store, key, self.buffer_cls.from_bytes(data))) + assert store.get_bytes(key, prototype=default_buffer_prototype()) == data + + async def test_get_json_async(self, store: S) -> None: + """ + Test that the get_bytes_async method reads json. + """ + data = {"foo": "bar"} + data_bytes = json.dumps(data).encode("utf-8") + key = "zarr.json" + await self.set(store, key, self.buffer_cls.from_bytes(data_bytes)) + assert await store.get_json_async(key, prototype=default_buffer_prototype()) == data + + def test_get_json_sync(self, store: S) -> None: + """ + Test that the get_json method reads json. + """ + data = {"foo": "bar"} + data_bytes = json.dumps(data).encode("utf-8") + key = "zarr.json" + sync(self.set(store, key, self.buffer_cls.from_bytes(data_bytes))) + assert store.get_json(key, prototype=default_buffer_prototype()) == data + class LatencyStore(WrapperStore[Store]): """ diff --git a/tests/test_store/test_local.py b/tests/test_store/test_local.py index 6756bc83d9..35d48e3f95 100644 --- a/tests/test_store/test_local.py +++ b/tests/test_store/test_local.py @@ -150,3 +150,85 @@ def test_atomic_write_exclusive_preexisting(tmp_path: pathlib.Path) -> None: f.write(b"abc") assert path.read_bytes() == b"xyz" assert list(path.parent.iterdir()) == [path] # no temp files + + +async def test_get_bytes_with_prototype_none(tmp_path: pathlib.Path) -> None: + """Test that get_bytes_async works with prototype=None.""" + from zarr.core.buffer import cpu + from zarr.core.buffer.core import default_buffer_prototype + + store = await LocalStore.open(root=tmp_path) + data = b"hello world" + key = "test_key" + await store.set(key, cpu.Buffer.from_bytes(data)) + + # Test with None (default) + result_none = await store.get_bytes_async(key) + assert result_none == data + + # Test with explicit prototype + result_proto = await store.get_bytes_async(key, prototype=default_buffer_prototype()) + assert result_proto == data + + +def test_get_bytes_sync_with_prototype_none(tmp_path: pathlib.Path) -> None: + """Test that get_bytes works with prototype=None.""" + from zarr.core.buffer import cpu + from zarr.core.buffer.core import default_buffer_prototype + from zarr.core.sync import sync + + store = sync(LocalStore.open(root=tmp_path)) + data = b"hello world" + key = "test_key" + sync(store.set(key, cpu.Buffer.from_bytes(data))) + + # Test with None (default) + result_none = store.get_bytes(key) + assert result_none == data + + # Test with explicit prototype + result_proto = store.get_bytes(key, prototype=default_buffer_prototype()) + assert result_proto == data + + +async def test_get_json_with_prototype_none(tmp_path: pathlib.Path) -> None: + """Test that get_json_async works with prototype=None.""" + import json + + from zarr.core.buffer import cpu + from zarr.core.buffer.core import default_buffer_prototype + + store = await LocalStore.open(root=tmp_path) + data = {"foo": "bar", "number": 42} + key = "test.json" + await store.set(key, cpu.Buffer.from_bytes(json.dumps(data).encode())) + + # Test with None (default) + result_none = await store.get_json_async(key) + assert result_none == data + + # Test with explicit prototype + result_proto = await store.get_json_async(key, prototype=default_buffer_prototype()) + assert result_proto == data + + +def test_get_json_sync_with_prototype_none(tmp_path: pathlib.Path) -> None: + """Test that get_json works with prototype=None.""" + import json + + from zarr.core.buffer import cpu + from zarr.core.buffer.core import default_buffer_prototype + from zarr.core.sync import sync + + store = sync(LocalStore.open(root=tmp_path)) + data = {"foo": "bar", "number": 42} + key = "test.json" + sync(store.set(key, cpu.Buffer.from_bytes(json.dumps(data).encode()))) + + # Test with None (default) + result_none = store.get_json(key) + assert result_none == data + + # Test with explicit prototype + result_proto = store.get_json(key, prototype=default_buffer_prototype()) + assert result_proto == data diff --git a/tests/test_store/test_memory.py b/tests/test_store/test_memory.py index 29fa9b2964..b56d9933d4 100644 --- a/tests/test_store/test_memory.py +++ b/tests/test_store/test_memory.py @@ -76,6 +76,76 @@ async def test_deterministic_size( np.testing.assert_array_equal(a[:3], 1) np.testing.assert_array_equal(a[3:], 0) + async def test_get_bytes_with_prototype_none(self, store: MemoryStore) -> None: + """Test that get_bytes_async works with prototype=None.""" + from zarr.core.buffer.core import default_buffer_prototype + + data = b"hello world" + key = "test_key" + await self.set(store, key, self.buffer_cls.from_bytes(data)) + + # Test with None (default) + result_none = await store.get_bytes_async(key) + assert result_none == data + + # Test with explicit prototype + result_proto = await store.get_bytes_async(key, prototype=default_buffer_prototype()) + assert result_proto == data + + def test_get_bytes_sync_with_prototype_none(self, store: MemoryStore) -> None: + """Test that get_bytes works with prototype=None.""" + from zarr.core.buffer.core import default_buffer_prototype + from zarr.core.sync import sync + + data = b"hello world" + key = "test_key" + sync(self.set(store, key, self.buffer_cls.from_bytes(data))) + + # Test with None (default) + result_none = store.get_bytes(key) + assert result_none == data + + # Test with explicit prototype + result_proto = store.get_bytes(key, prototype=default_buffer_prototype()) + assert result_proto == data + + async def test_get_json_with_prototype_none(self, store: MemoryStore) -> None: + """Test that get_json_async works with prototype=None.""" + import json + + from zarr.core.buffer.core import default_buffer_prototype + + data = {"foo": "bar", "number": 42} + key = "test.json" + await self.set(store, key, self.buffer_cls.from_bytes(json.dumps(data).encode())) + + # Test with None (default) + result_none = await store.get_json_async(key) + assert result_none == data + + # Test with explicit prototype + result_proto = await store.get_json_async(key, prototype=default_buffer_prototype()) + assert result_proto == data + + def test_get_json_sync_with_prototype_none(self, store: MemoryStore) -> None: + """Test that get_json works with prototype=None.""" + import json + + from zarr.core.buffer.core import default_buffer_prototype + from zarr.core.sync import sync + + data = {"foo": "bar", "number": 42} + key = "test.json" + sync(self.set(store, key, self.buffer_cls.from_bytes(json.dumps(data).encode()))) + + # Test with None (default) + result_none = store.get_json(key) + assert result_none == data + + # Test with explicit prototype + result_proto = store.get_json(key, prototype=default_buffer_prototype()) + assert result_proto == data + # TODO: fix this warning @pytest.mark.filterwarnings("ignore:Unclosed client session:ResourceWarning") From 7d26b8ee4d33c6c784e42784e1acf37c5836dd8e Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Sun, 4 Jan 2026 18:02:29 +0100 Subject: [PATCH 02/13] check for FileNotFoundError when a key is missing --- src/zarr/testing/store.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index bee28639a2..30ff376fb0 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -535,6 +535,8 @@ async def test_get_bytes_async(self, store: S) -> None: key = "zarr.json" await self.set(store, key, self.buffer_cls.from_bytes(data)) assert await store.get_bytes_async(key, prototype=default_buffer_prototype()) == data + with pytest.raises(FileNotFoundError): + await store.get_bytes_async("nonexistent_key", prototype=default_buffer_prototype()) def test_get_bytes_sync(self, store: S) -> None: """ From 971c3e4fb6dd6c895e49c2763be5c1b9164b9114 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Sun, 4 Jan 2026 18:02:39 +0100 Subject: [PATCH 03/13] remove storepath methods --- src/zarr/storage/_common.py | 107 ------------------------------------ 1 file changed, 107 deletions(-) diff --git a/src/zarr/storage/_common.py b/src/zarr/storage/_common.py index b8dcaff3be..d762097cc3 100644 --- a/src/zarr/storage/_common.py +++ b/src/zarr/storage/_common.py @@ -228,113 +228,6 @@ async def is_empty(self) -> bool: """ return await self.store.is_empty(self.path) - async def get_bytes_async( - self, - prototype: BufferPrototype | None = None, - byte_range: ByteRequest | None = None, - ) -> bytes: - """ - Retrieve raw bytes from the store path asynchronously. - - This is a convenience method that wraps ``get()`` and converts the result - to bytes. The ``prototype`` parameter is optional and defaults to the - standard buffer prototype. - - Parameters - ---------- - prototype : BufferPrototype, optional - The buffer prototype to use for reading the data. If None, uses - ``default_buffer_prototype()``. - byte_range : ByteRequest, optional - If specified, only retrieve a portion of the stored data. - Can be a ``RangeByteRequest``, ``OffsetByteRequest``, or ``SuffixByteRequest``. - - Returns - ------- - bytes - The raw bytes stored at this path. - - Raises - ------ - FileNotFoundError - If the path does not exist in the store. - - See Also - -------- - get : Lower-level method that returns a Buffer object. - get_json_async : Asynchronous method for retrieving and parsing JSON data. - - Examples - -------- - >>> store = await MemoryStore.open() - >>> path = StorePath(store, "data") - >>> await path.set(Buffer.from_bytes(b"hello world")) - >>> data = await path.get_bytes_async() - >>> print(data) - b'hello world' - """ - if prototype is None: - prototype = default_buffer_prototype() - return await self.store.get_bytes_async( - self.path, prototype=prototype, byte_range=byte_range - ) - - async def get_json_async( - self, - prototype: BufferPrototype | None = None, - byte_range: ByteRequest | None = None, - ) -> Any: - """ - Retrieve and parse JSON data from the store path asynchronously. - - This is a convenience method that retrieves bytes from the store and - parses them as JSON. The ``prototype`` parameter is optional and defaults - to the standard buffer prototype. - - Parameters - ---------- - prototype : BufferPrototype, optional - The buffer prototype to use for reading the data. If None, uses - ``default_buffer_prototype()``. - byte_range : ByteRequest, optional - If specified, only retrieve a portion of the stored data. - Can be a ``RangeByteRequest``, ``OffsetByteRequest``, or ``SuffixByteRequest``. - Note: Using byte ranges with JSON may result in invalid JSON. - - Returns - ------- - Any - The parsed JSON data. This follows the behavior of ``json.loads()`` and - can be any JSON-serializable type: dict, list, str, int, float, bool, or None. - - Raises - ------ - FileNotFoundError - If the path does not exist in the store. - json.JSONDecodeError - If the stored data is not valid JSON. - - See Also - -------- - get_bytes_async : Method for retrieving raw bytes without parsing. - get : Lower-level method that returns a Buffer object. - - Examples - -------- - >>> store = await MemoryStore.open() - >>> path = StorePath(store, "zarr.json") - >>> metadata = {"zarr_format": 3, "node_type": "array"} - >>> await path.set(Buffer.from_bytes(json.dumps(metadata).encode())) - >>> data = await path.get_json_async() - >>> print(data) - {'zarr_format': 3, 'node_type': 'array'} - """ - if prototype is None: - prototype = default_buffer_prototype() - return await self.store.get_json_async( - self.path, prototype=prototype, byte_range=byte_range - ) - def __truediv__(self, other: str) -> StorePath: """Combine this store path with another path""" return self.__class__(self.store, _dereference_path(self.path, other)) From d70a5e5277ca7a225d9fb0ac22945b5e1fcb8cc3 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Sun, 4 Jan 2026 18:24:39 +0100 Subject: [PATCH 04/13] changelog --- changes/3638.feature.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 changes/3638.feature.md diff --git a/changes/3638.feature.md b/changes/3638.feature.md new file mode 100644 index 0000000000..ad2276fd51 --- /dev/null +++ b/changes/3638.feature.md @@ -0,0 +1 @@ +Add methods for reading stored objects as bytes and JSON-decoded bytes to store classes. \ No newline at end of file From a21305887b7287d957811afb0a8ec8f894d8cfc7 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Thu, 8 Jan 2026 10:37:23 +0100 Subject: [PATCH 05/13] rename methods --- src/zarr/abc/store.py | 45 +++++++++++++------------- src/zarr/storage/_local.py | 26 +++++++-------- src/zarr/storage/_memory.py | 26 +++++++-------- src/zarr/testing/store.py | 12 +++---- tests/test_store/test_local.py | 56 ++++++++++++++------------------- tests/test_store/test_memory.py | 18 +++++------ 6 files changed, 85 insertions(+), 98 deletions(-) diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index 7d0589c836..e685e4b3b0 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -209,7 +209,7 @@ async def get( """ ... - async def get_bytes_async( + async def get_bytes( self, key: str, *, prototype: BufferPrototype, byte_range: ByteRequest | None = None ) -> bytes: """ @@ -242,7 +242,7 @@ async def get_bytes_async( -------- get : Lower-level method that returns a Buffer object. get_bytes : Synchronous version of this method. - get_json_async : Asynchronous method for retrieving and parsing JSON data. + get_json : Asynchronous method for retrieving and parsing JSON data. Examples -------- @@ -257,7 +257,7 @@ async def get_bytes_async( raise FileNotFoundError(key) return buffer.to_bytes() - def get_bytes( + def get_bytes_sync( self, key: str = "", *, prototype: BufferPrototype, byte_range: ByteRequest | None = None ) -> bytes: """ @@ -289,7 +289,7 @@ def get_bytes( Warnings -------- - Do not call this method from async functions. Use ``get_bytes_async()`` instead + Do not call this method from async functions. Use ``get_bytes()`` instead to avoid blocking the event loop. See Also @@ -300,23 +300,22 @@ def get_bytes( Examples -------- >>> store = MemoryStore() - >>> store.set("data", Buffer.from_bytes(b"hello world")) - >>> data = store.get_bytes("data", prototype=default_buffer_prototype()) + >>> await store.set("data", Buffer.from_bytes(b"hello world")) + >>> data = store.get_bytes_sync("data", prototype=default_buffer_prototype()) >>> print(data) b'hello world' """ - return sync(self.get_bytes_async(key, prototype=prototype, byte_range=byte_range)) + return sync(self.get_bytes(key, prototype=prototype, byte_range=byte_range)) - async def get_json_async( + async def get_json( self, key: str, *, prototype: BufferPrototype, byte_range: ByteRequest | None = None ) -> Any: """ Retrieve and parse JSON data from the store asynchronously. This is a convenience method that retrieves bytes from the store and - parses them as JSON. Commonly used for reading Zarr metadata files - like ``zarr.json``. + parses them as JSON. Parameters ---------- @@ -344,31 +343,29 @@ async def get_json_async( See Also -------- - get_bytes_async : Method for retrieving raw bytes without parsing. - get_json : Synchronous version of this method. + get_bytes : Method for retrieving raw bytes. + get_json_sync : Synchronous version of this method. Examples -------- >>> store = await MemoryStore.open() >>> metadata = {"zarr_format": 3, "node_type": "array"} >>> await store.set("zarr.json", Buffer.from_bytes(json.dumps(metadata).encode())) - >>> data = await store.get_json_async("zarr.json", prototype=default_buffer_prototype()) + >>> data = await store.get_json("zarr.json", prototype=default_buffer_prototype()) >>> print(data) {'zarr_format': 3, 'node_type': 'array'} """ - return json.loads( - await self.get_bytes_async(key, prototype=prototype, byte_range=byte_range) - ) + return json.loads(await self.get_bytes(key, prototype=prototype, byte_range=byte_range)) - def get_json( + def get_json_sync( self, key: str = "", *, prototype: BufferPrototype, byte_range: ByteRequest | None = None ) -> Any: """ Retrieve and parse JSON data from the store synchronously. - This is a synchronous wrapper around ``get_json_async()``. It should only - be called from non-async code. For async contexts, use ``get_json_async()`` + This is a synchronous wrapper around ``get_json()``. It should only + be called from non-async code. For async contexts, use ``get_json()`` instead. Parameters @@ -397,25 +394,25 @@ def get_json( Warnings -------- - Do not call this method from async functions. Use ``get_json_async()`` instead + Do not call this method from async functions. Use ``get_json()`` instead to avoid blocking the event loop. See Also -------- - get_json_async : Asynchronous version of this method. - get_bytes : Synchronous method for retrieving raw bytes without parsing. + get_json : Asynchronous version of this method. + get_bytes_sync : Synchronous method for retrieving raw bytes without parsing. Examples -------- >>> store = MemoryStore() >>> metadata = {"zarr_format": 3, "node_type": "array"} >>> store.set("zarr.json", Buffer.from_bytes(json.dumps(metadata).encode())) - >>> data = store.get_json("zarr.json", prototype=default_buffer_prototype()) + >>> data = store.get_json_sync("zarr.json", prototype=default_buffer_prototype()) >>> print(data) {'zarr_format': 3, 'node_type': 'array'} """ - return sync(self.get_json_async(key, prototype=prototype, byte_range=byte_range)) + return sync(self.get_json(key, prototype=prototype, byte_range=byte_range)) @abstractmethod async def get_partial_values( diff --git a/src/zarr/storage/_local.py b/src/zarr/storage/_local.py index 13c86a2f22..08681f2630 100644 --- a/src/zarr/storage/_local.py +++ b/src/zarr/storage/_local.py @@ -306,7 +306,7 @@ async def list_dir(self, prefix: str) -> AsyncIterator[str]: except (FileNotFoundError, NotADirectoryError): pass - async def get_bytes_async( + async def get_bytes( self, key: str = "", *, @@ -356,9 +356,9 @@ async def get_bytes_async( """ if prototype is None: prototype = default_buffer_prototype() - return await super().get_bytes_async(key, prototype=prototype, byte_range=byte_range) + return await super().get_bytes(key, prototype=prototype, byte_range=byte_range) - def get_bytes( + def get_bytes_sync( self, key: str = "", *, @@ -412,9 +412,9 @@ def get_bytes( """ if prototype is None: prototype = default_buffer_prototype() - return super().get_bytes(key, prototype=prototype, byte_range=byte_range) + return super().get_bytes_sync(key, prototype=prototype, byte_range=byte_range) - async def get_json_async( + async def get_json( self, key: str = "", *, @@ -425,7 +425,7 @@ async def get_json_async( Retrieve and parse JSON data from the local store asynchronously. This is a convenience override that makes the ``prototype`` parameter optional - by defaulting to the standard buffer prototype. See the base ``Store.get_json_async`` + by defaulting to the standard buffer prototype. See the base ``Store.get_json`` for full documentation. Parameters @@ -454,7 +454,7 @@ async def get_json_async( See Also -------- - Store.get_json_async : Base implementation with full documentation. + Store.get_json : Base implementation with full documentation. get_json : Synchronous version of this method. get_bytes_async : Method for retrieving raw bytes without parsing. @@ -465,15 +465,15 @@ async def get_json_async( >>> metadata = {"zarr_format": 3, "node_type": "array"} >>> await store.set("zarr.json", Buffer.from_bytes(json.dumps(metadata).encode())) >>> # No need to specify prototype for LocalStore - >>> data = await store.get_json_async("zarr.json") + >>> data = await store.get_json("zarr.json") >>> print(data) {'zarr_format': 3, 'node_type': 'array'} """ if prototype is None: prototype = default_buffer_prototype() - return await super().get_json_async(key, prototype=prototype, byte_range=byte_range) + return await super().get_json(key, prototype=prototype, byte_range=byte_range) - def get_json( + def get_json_sync( self, key: str = "", *, @@ -513,12 +513,12 @@ def get_json( Warnings -------- - Do not call this method from async functions. Use ``get_json_async()`` instead. + Do not call this method from async functions. Use ``get_json()`` instead. See Also -------- Store.get_json : Base implementation with full documentation. - get_json_async : Asynchronous version of this method. + get_json : Asynchronous version of this method. get_bytes : Method for retrieving raw bytes without parsing. Examples @@ -534,7 +534,7 @@ def get_json( """ if prototype is None: prototype = default_buffer_prototype() - return super().get_json(key, prototype=prototype, byte_range=byte_range) + return super().get_json_sync(key, prototype=prototype, byte_range=byte_range) async def move(self, dest_root: Path | str) -> None: """ diff --git a/src/zarr/storage/_memory.py b/src/zarr/storage/_memory.py index b56771f62a..5a2593eb25 100644 --- a/src/zarr/storage/_memory.py +++ b/src/zarr/storage/_memory.py @@ -175,7 +175,7 @@ async def list_dir(self, prefix: str) -> AsyncIterator[str]: for key in keys_unique: yield key - async def get_bytes_async( + async def get_bytes( self, key: str = "", *, @@ -225,9 +225,9 @@ async def get_bytes_async( """ if prototype is None: prototype = default_buffer_prototype() - return await super().get_bytes_async(key, prototype=prototype, byte_range=byte_range) + return await super().get_bytes(key, prototype=prototype, byte_range=byte_range) - def get_bytes( + def get_bytes_sync( self, key: str = "", *, @@ -281,9 +281,9 @@ def get_bytes( """ if prototype is None: prototype = default_buffer_prototype() - return super().get_bytes(key, prototype=prototype, byte_range=byte_range) + return super().get_bytes_sync(key, prototype=prototype, byte_range=byte_range) - async def get_json_async( + async def get_json( self, key: str = "", *, @@ -294,7 +294,7 @@ async def get_json_async( Retrieve and parse JSON data from the memory store asynchronously. This is a convenience override that makes the ``prototype`` parameter optional - by defaulting to the standard buffer prototype. See the base ``Store.get_json_async`` + by defaulting to the standard buffer prototype. See the base ``Store.get_json`` for full documentation. Parameters @@ -323,7 +323,7 @@ async def get_json_async( See Also -------- - Store.get_json_async : Base implementation with full documentation. + Store.get_json : Base implementation with full documentation. get_json : Synchronous version of this method. get_bytes_async : Method for retrieving raw bytes without parsing. @@ -334,15 +334,15 @@ async def get_json_async( >>> metadata = {"zarr_format": 3, "node_type": "array"} >>> await store.set("zarr.json", Buffer.from_bytes(json.dumps(metadata).encode())) >>> # No need to specify prototype for MemoryStore - >>> data = await store.get_json_async("zarr.json") + >>> data = await store.get_json("zarr.json") >>> print(data) {'zarr_format': 3, 'node_type': 'array'} """ if prototype is None: prototype = default_buffer_prototype() - return await super().get_json_async(key, prototype=prototype, byte_range=byte_range) + return await super().get_json(key, prototype=prototype, byte_range=byte_range) - def get_json( + def get_json_sync( self, key: str = "", *, @@ -382,12 +382,12 @@ def get_json( Warnings -------- - Do not call this method from async functions. Use ``get_json_async()`` instead. + Do not call this method from async functions. Use ``get_json()`` instead. See Also -------- Store.get_json : Base implementation with full documentation. - get_json_async : Asynchronous version of this method. + get_json : Asynchronous version of this method. get_bytes : Method for retrieving raw bytes without parsing. Examples @@ -403,7 +403,7 @@ def get_json( """ if prototype is None: prototype = default_buffer_prototype() - return super().get_json(key, prototype=prototype, byte_range=byte_range) + return super().get_json_sync(key, prototype=prototype, byte_range=byte_range) class GpuMemoryStore(MemoryStore): diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index 30ff376fb0..f0cb6dd48f 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -534,9 +534,9 @@ async def test_get_bytes_async(self, store: S) -> None: data = b"hello world" key = "zarr.json" await self.set(store, key, self.buffer_cls.from_bytes(data)) - assert await store.get_bytes_async(key, prototype=default_buffer_prototype()) == data + assert await store.get_bytes(key, prototype=default_buffer_prototype()) == data with pytest.raises(FileNotFoundError): - await store.get_bytes_async("nonexistent_key", prototype=default_buffer_prototype()) + await store.get_bytes("nonexistent_key", prototype=default_buffer_prototype()) def test_get_bytes_sync(self, store: S) -> None: """ @@ -545,9 +545,9 @@ def test_get_bytes_sync(self, store: S) -> None: data = b"hello world" key = "zarr.json" sync(self.set(store, key, self.buffer_cls.from_bytes(data))) - assert store.get_bytes(key, prototype=default_buffer_prototype()) == data + assert store.get_bytes_sync(key, prototype=default_buffer_prototype()) == data - async def test_get_json_async(self, store: S) -> None: + async def test_get_json(self, store: S) -> None: """ Test that the get_bytes_async method reads json. """ @@ -555,7 +555,7 @@ async def test_get_json_async(self, store: S) -> None: data_bytes = json.dumps(data).encode("utf-8") key = "zarr.json" await self.set(store, key, self.buffer_cls.from_bytes(data_bytes)) - assert await store.get_json_async(key, prototype=default_buffer_prototype()) == data + assert await store.get_json(key, prototype=default_buffer_prototype()) == data def test_get_json_sync(self, store: S) -> None: """ @@ -565,7 +565,7 @@ def test_get_json_sync(self, store: S) -> None: data_bytes = json.dumps(data).encode("utf-8") key = "zarr.json" sync(self.set(store, key, self.buffer_cls.from_bytes(data_bytes))) - assert store.get_json(key, prototype=default_buffer_prototype()) == data + assert store.get_json_sync(key, prototype=default_buffer_prototype()) == data class LatencyStore(WrapperStore[Store]): diff --git a/tests/test_store/test_local.py b/tests/test_store/test_local.py index 35d48e3f95..d6a97110ad 100644 --- a/tests/test_store/test_local.py +++ b/tests/test_store/test_local.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json import pathlib import re @@ -9,6 +10,8 @@ import zarr from zarr import create_array from zarr.core.buffer import Buffer, cpu +from zarr.core.buffer.core import BufferPrototype, default_buffer_prototype +from zarr.core.sync import sync from zarr.storage import LocalStore from zarr.storage._local import _atomic_write from zarr.testing.store import StoreTests @@ -153,9 +156,8 @@ def test_atomic_write_exclusive_preexisting(tmp_path: pathlib.Path) -> None: async def test_get_bytes_with_prototype_none(tmp_path: pathlib.Path) -> None: - """Test that get_bytes_async works with prototype=None.""" + """Test that get_bytes works with prototype=None.""" from zarr.core.buffer import cpu - from zarr.core.buffer.core import default_buffer_prototype store = await LocalStore.open(root=tmp_path) data = b"hello world" @@ -163,18 +165,17 @@ async def test_get_bytes_with_prototype_none(tmp_path: pathlib.Path) -> None: await store.set(key, cpu.Buffer.from_bytes(data)) # Test with None (default) - result_none = await store.get_bytes_async(key) + result_none = await store.get_bytes(key) assert result_none == data # Test with explicit prototype - result_proto = await store.get_bytes_async(key, prototype=default_buffer_prototype()) + result_proto = await store.get_bytes(key, prototype=default_buffer_prototype()) assert result_proto == data def test_get_bytes_sync_with_prototype_none(tmp_path: pathlib.Path) -> None: - """Test that get_bytes works with prototype=None.""" + """Test that get_bytes_sync works with prototype=None.""" from zarr.core.buffer import cpu - from zarr.core.buffer.core import default_buffer_prototype from zarr.core.sync import sync store = sync(LocalStore.open(root=tmp_path)) @@ -183,20 +184,19 @@ def test_get_bytes_sync_with_prototype_none(tmp_path: pathlib.Path) -> None: sync(store.set(key, cpu.Buffer.from_bytes(data))) # Test with None (default) - result_none = store.get_bytes(key) + result_none = store.get_bytes_sync(key) assert result_none == data # Test with explicit prototype - result_proto = store.get_bytes(key, prototype=default_buffer_prototype()) + result_proto = store.get_bytes_sync(key, prototype=default_buffer_prototype()) assert result_proto == data -async def test_get_json_with_prototype_none(tmp_path: pathlib.Path) -> None: - """Test that get_json_async works with prototype=None.""" - import json - - from zarr.core.buffer import cpu - from zarr.core.buffer.core import default_buffer_prototype +@pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) +async def test_get_json_with_prototype_none( + tmp_path: pathlib.Path, buffer_cls: None | BufferPrototype +) -> None: + """Test that get_json works with prototype=None.""" store = await LocalStore.open(root=tmp_path) data = {"foo": "bar", "number": 42} @@ -204,21 +204,15 @@ async def test_get_json_with_prototype_none(tmp_path: pathlib.Path) -> None: await store.set(key, cpu.Buffer.from_bytes(json.dumps(data).encode())) # Test with None (default) - result_none = await store.get_json_async(key) - assert result_none == data + result = await store.get_json(key, prototype=buffer_cls) + assert result == data - # Test with explicit prototype - result_proto = await store.get_json_async(key, prototype=default_buffer_prototype()) - assert result_proto == data - -def test_get_json_sync_with_prototype_none(tmp_path: pathlib.Path) -> None: - """Test that get_json works with prototype=None.""" - import json - - from zarr.core.buffer import cpu - from zarr.core.buffer.core import default_buffer_prototype - from zarr.core.sync import sync +@pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) +def test_get_json_sync_with_prototype( + tmp_path: pathlib.Path, buffer_cls: None | BufferPrototype +) -> None: + """Test that get_json_sync works with prototype=None.""" store = sync(LocalStore.open(root=tmp_path)) data = {"foo": "bar", "number": 42} @@ -226,9 +220,5 @@ def test_get_json_sync_with_prototype_none(tmp_path: pathlib.Path) -> None: sync(store.set(key, cpu.Buffer.from_bytes(json.dumps(data).encode()))) # Test with None (default) - result_none = store.get_json(key) - assert result_none == data - - # Test with explicit prototype - result_proto = store.get_json(key, prototype=default_buffer_prototype()) - assert result_proto == data + result = store.get_json_sync(key, prototype=buffer_cls) + assert result == data diff --git a/tests/test_store/test_memory.py b/tests/test_store/test_memory.py index b56d9933d4..c47c1adb12 100644 --- a/tests/test_store/test_memory.py +++ b/tests/test_store/test_memory.py @@ -85,11 +85,11 @@ async def test_get_bytes_with_prototype_none(self, store: MemoryStore) -> None: await self.set(store, key, self.buffer_cls.from_bytes(data)) # Test with None (default) - result_none = await store.get_bytes_async(key) + result_none = await store.get_bytes(key) assert result_none == data # Test with explicit prototype - result_proto = await store.get_bytes_async(key, prototype=default_buffer_prototype()) + result_proto = await store.get_bytes(key, prototype=default_buffer_prototype()) assert result_proto == data def test_get_bytes_sync_with_prototype_none(self, store: MemoryStore) -> None: @@ -102,15 +102,15 @@ def test_get_bytes_sync_with_prototype_none(self, store: MemoryStore) -> None: sync(self.set(store, key, self.buffer_cls.from_bytes(data))) # Test with None (default) - result_none = store.get_bytes(key) + result_none = store.get_bytes_sync(key) assert result_none == data # Test with explicit prototype - result_proto = store.get_bytes(key, prototype=default_buffer_prototype()) + result_proto = store.get_bytes_sync(key, prototype=default_buffer_prototype()) assert result_proto == data async def test_get_json_with_prototype_none(self, store: MemoryStore) -> None: - """Test that get_json_async works with prototype=None.""" + """Test that get_json works with prototype=None.""" import json from zarr.core.buffer.core import default_buffer_prototype @@ -120,11 +120,11 @@ async def test_get_json_with_prototype_none(self, store: MemoryStore) -> None: await self.set(store, key, self.buffer_cls.from_bytes(json.dumps(data).encode())) # Test with None (default) - result_none = await store.get_json_async(key) + result_none = await store.get_json(key) assert result_none == data # Test with explicit prototype - result_proto = await store.get_json_async(key, prototype=default_buffer_prototype()) + result_proto = await store.get_json(key, prototype=default_buffer_prototype()) assert result_proto == data def test_get_json_sync_with_prototype_none(self, store: MemoryStore) -> None: @@ -139,11 +139,11 @@ def test_get_json_sync_with_prototype_none(self, store: MemoryStore) -> None: sync(self.set(store, key, self.buffer_cls.from_bytes(json.dumps(data).encode()))) # Test with None (default) - result_none = store.get_json(key) + result_none = store.get_json_sync(key) assert result_none == data # Test with explicit prototype - result_proto = store.get_json(key, prototype=default_buffer_prototype()) + result_proto = store.get_json_sync(key, prototype=default_buffer_prototype()) assert result_proto == data From 38ff5172cc126d15abfa992ceee9cba81f7e9e3e Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Thu, 8 Jan 2026 11:02:07 +0100 Subject: [PATCH 06/13] continue renaming / test refactoring --- src/zarr/abc/store.py | 10 ++--- src/zarr/storage/_local.py | 22 ++++----- src/zarr/storage/_memory.py | 22 ++++----- src/zarr/testing/store.py | 8 ++-- tests/test_store/test_memory.py | 79 +++++++++++++-------------------- 5 files changed, 61 insertions(+), 80 deletions(-) diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index e685e4b3b0..0e98777ff5 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -248,7 +248,7 @@ async def get_bytes( -------- >>> store = await MemoryStore.open() >>> await store.set("data", Buffer.from_bytes(b"hello world")) - >>> data = await store.get_bytes_async("data", prototype=default_buffer_prototype()) + >>> data = await store.get_bytes("data", prototype=default_buffer_prototype()) >>> print(data) b'hello world' """ @@ -263,8 +263,8 @@ def get_bytes_sync( """ Retrieve raw bytes from the store synchronously. - This is a synchronous wrapper around ``get_bytes_async()``. It should only - be called from non-async code. For async contexts, use ``get_bytes_async()`` + This is a synchronous wrapper around ``get_bytes()``. It should only + be called from non-async code. For async contexts, use ``get_bytes()`` instead. Parameters @@ -294,8 +294,8 @@ def get_bytes_sync( See Also -------- - get_bytes_async : Asynchronous version of this method. - get_json : Synchronous method for retrieving and parsing JSON data. + get_bytes : Asynchronous version of this method. + get_json_sync : Synchronous method for retrieving and parsing JSON data. Examples -------- diff --git a/src/zarr/storage/_local.py b/src/zarr/storage/_local.py index 08681f2630..9fb3f8b6ad 100644 --- a/src/zarr/storage/_local.py +++ b/src/zarr/storage/_local.py @@ -317,7 +317,7 @@ async def get_bytes( Retrieve raw bytes from the local store asynchronously. This is a convenience override that makes the ``prototype`` parameter optional - by defaulting to the standard buffer prototype. See the base ``Store.get_bytes_async`` + by defaulting to the standard buffer prototype. See the base ``Store.get_bytes`` for full documentation. Parameters @@ -342,15 +342,15 @@ async def get_bytes( See Also -------- - Store.get_bytes_async : Base implementation with full documentation. - get_bytes : Synchronous version of this method. + Store.get_bytes : Base implementation with full documentation. + get_bytes_sync : Synchronous version of this method. Examples -------- >>> store = await LocalStore.open("data") >>> await store.set("data", Buffer.from_bytes(b"hello")) >>> # No need to specify prototype for LocalStore - >>> data = await store.get_bytes_async("data") + >>> data = await store.get_bytes("data") >>> print(data) b'hello' """ @@ -394,12 +394,12 @@ def get_bytes_sync( Warnings -------- - Do not call this method from async functions. Use ``get_bytes_async()`` instead. + Do not call this method from async functions. Use ``get_bytes()`` instead. See Also -------- - Store.get_bytes : Base implementation with full documentation. - get_bytes_async : Asynchronous version of this method. + Store.get_bytes_sync : Base implementation with full documentation. + get_bytes : Asynchronous version of this method. Examples -------- @@ -455,8 +455,8 @@ async def get_json( See Also -------- Store.get_json : Base implementation with full documentation. - get_json : Synchronous version of this method. - get_bytes_async : Method for retrieving raw bytes without parsing. + get_json_sync : Synchronous version of this method. + get_bytes : Method for retrieving raw bytes without parsing. Examples -------- @@ -517,9 +517,9 @@ def get_json_sync( See Also -------- - Store.get_json : Base implementation with full documentation. + Store.get_json_sync : Base implementation with full documentation. get_json : Asynchronous version of this method. - get_bytes : Method for retrieving raw bytes without parsing. + get_bytes_sync : Method for retrieving raw bytes without parsing. Examples -------- diff --git a/src/zarr/storage/_memory.py b/src/zarr/storage/_memory.py index 5a2593eb25..1568cc6736 100644 --- a/src/zarr/storage/_memory.py +++ b/src/zarr/storage/_memory.py @@ -186,7 +186,7 @@ async def get_bytes( Retrieve raw bytes from the memory store asynchronously. This is a convenience override that makes the ``prototype`` parameter optional - by defaulting to the standard buffer prototype. See the base ``Store.get_bytes_async`` + by defaulting to the standard buffer prototype. See the base ``Store.get_bytes`` for full documentation. Parameters @@ -211,15 +211,15 @@ async def get_bytes( See Also -------- - Store.get_bytes_async : Base implementation with full documentation. - get_bytes : Synchronous version of this method. + Store.get_bytes : Base implementation with full documentation. + get_bytes_sync : Synchronous version of this method. Examples -------- >>> store = await MemoryStore.open() >>> await store.set("data", Buffer.from_bytes(b"hello")) >>> # No need to specify prototype for MemoryStore - >>> data = await store.get_bytes_async("data") + >>> data = await store.get_bytes("data") >>> print(data) b'hello' """ @@ -263,12 +263,12 @@ def get_bytes_sync( Warnings -------- - Do not call this method from async functions. Use ``get_bytes_async()`` instead. + Do not call this method from async functions. Use ``get_bytes()`` instead. See Also -------- - Store.get_bytes : Base implementation with full documentation. - get_bytes_async : Asynchronous version of this method. + Store.get_bytes_sync : Base implementation with full documentation. + get_bytes : Asynchronous version of this method. Examples -------- @@ -324,8 +324,8 @@ async def get_json( See Also -------- Store.get_json : Base implementation with full documentation. - get_json : Synchronous version of this method. - get_bytes_async : Method for retrieving raw bytes without parsing. + get_json_sync : Synchronous version of this method. + get_bytes : Method for retrieving raw bytes without parsing. Examples -------- @@ -386,9 +386,9 @@ def get_json_sync( See Also -------- - Store.get_json : Base implementation with full documentation. + Store.get_json_sync : Base implementation with full documentation. get_json : Asynchronous version of this method. - get_bytes : Method for retrieving raw bytes without parsing. + get_bytes_sync : Method for retrieving raw bytes without parsing. Examples -------- diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index f0cb6dd48f..a56061ae12 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -527,9 +527,9 @@ async def test_set_if_not_exists(self, store: S) -> None: result = await store.get("k2", default_buffer_prototype()) assert result == new - async def test_get_bytes_async(self, store: S) -> None: + async def test_get_bytes(self, store: S) -> None: """ - Test that the get_bytes_async method reads bytes. + Test that the get_bytes method reads bytes. """ data = b"hello world" key = "zarr.json" @@ -540,7 +540,7 @@ async def test_get_bytes_async(self, store: S) -> None: def test_get_bytes_sync(self, store: S) -> None: """ - Test that the get_bytes method reads bytes. + Test that the get_bytes_sync method reads bytes. """ data = b"hello world" key = "zarr.json" @@ -549,7 +549,7 @@ def test_get_bytes_sync(self, store: S) -> None: async def test_get_json(self, store: S) -> None: """ - Test that the get_bytes_async method reads json. + Test that the get_json method reads json. """ data = {"foo": "bar"} data_bytes = json.dumps(data).encode("utf-8") diff --git a/tests/test_store/test_memory.py b/tests/test_store/test_memory.py index c47c1adb12..96b7fe9845 100644 --- a/tests/test_store/test_memory.py +++ b/tests/test_store/test_memory.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json import re from typing import TYPE_CHECKING, Any @@ -9,12 +10,14 @@ import zarr from zarr.core.buffer import Buffer, cpu, gpu +from zarr.core.sync import sync from zarr.errors import ZarrUserWarning from zarr.storage import GpuMemoryStore, MemoryStore from zarr.testing.store import StoreTests from zarr.testing.utils import gpu_test if TYPE_CHECKING: + from zarr.core.buffer import BufferPrototype from zarr.core.common import ZarrFormat @@ -76,75 +79,53 @@ async def test_deterministic_size( np.testing.assert_array_equal(a[:3], 1) np.testing.assert_array_equal(a[3:], 0) - async def test_get_bytes_with_prototype_none(self, store: MemoryStore) -> None: - """Test that get_bytes_async works with prototype=None.""" - from zarr.core.buffer.core import default_buffer_prototype - + @pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) + async def test_get_bytes_with_prototype_none( + self, store: MemoryStore, buffer_cls: None | BufferPrototype + ) -> None: + """Test that get_bytes works with prototype=None.""" data = b"hello world" key = "test_key" await self.set(store, key, self.buffer_cls.from_bytes(data)) - # Test with None (default) - result_none = await store.get_bytes(key) - assert result_none == data - - # Test with explicit prototype - result_proto = await store.get_bytes(key, prototype=default_buffer_prototype()) - assert result_proto == data - - def test_get_bytes_sync_with_prototype_none(self, store: MemoryStore) -> None: - """Test that get_bytes works with prototype=None.""" - from zarr.core.buffer.core import default_buffer_prototype - from zarr.core.sync import sync + result = await store.get_bytes(key, prototype=buffer_cls) + assert result == data + @pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) + def test_get_bytes_sync_with_prototype_none( + self, store: MemoryStore, buffer_cls: None | BufferPrototype + ) -> None: + """Test that get_bytes_sync works with prototype=None.""" data = b"hello world" key = "test_key" sync(self.set(store, key, self.buffer_cls.from_bytes(data))) - # Test with None (default) - result_none = store.get_bytes_sync(key) - assert result_none == data + result = store.get_bytes_sync(key, prototype=buffer_cls) + assert result == data - # Test with explicit prototype - result_proto = store.get_bytes_sync(key, prototype=default_buffer_prototype()) - assert result_proto == data - - async def test_get_json_with_prototype_none(self, store: MemoryStore) -> None: + @pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) + async def test_get_json_with_prototype_none( + self, store: MemoryStore, buffer_cls: None | BufferPrototype + ) -> None: """Test that get_json works with prototype=None.""" - import json - - from zarr.core.buffer.core import default_buffer_prototype - data = {"foo": "bar", "number": 42} key = "test.json" await self.set(store, key, self.buffer_cls.from_bytes(json.dumps(data).encode())) - # Test with None (default) - result_none = await store.get_json(key) - assert result_none == data - - # Test with explicit prototype - result_proto = await store.get_json(key, prototype=default_buffer_prototype()) - assert result_proto == data - - def test_get_json_sync_with_prototype_none(self, store: MemoryStore) -> None: - """Test that get_json works with prototype=None.""" - import json - - from zarr.core.buffer.core import default_buffer_prototype - from zarr.core.sync import sync + result = await store.get_json(key, prototype=buffer_cls) + assert result == data + @pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) + def test_get_json_sync_with_prototype_none( + self, store: MemoryStore, buffer_cls: None | BufferPrototype + ) -> None: + """Test that get_json_sync works with prototype=None.""" data = {"foo": "bar", "number": 42} key = "test.json" sync(self.set(store, key, self.buffer_cls.from_bytes(json.dumps(data).encode()))) - # Test with None (default) - result_none = store.get_json_sync(key) - assert result_none == data - - # Test with explicit prototype - result_proto = store.get_json_sync(key, prototype=default_buffer_prototype()) - assert result_proto == data + result = store.get_json_sync(key, prototype=buffer_cls) + assert result == data # TODO: fix this warning From bdc4ef864b3bcbe422ab981eab6ec94f7af3ac0a Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Thu, 8 Jan 2026 11:47:14 +0100 Subject: [PATCH 07/13] refactor new test functions --- tests/test_store/test_local.py | 122 ++++++++++++++------------------- 1 file changed, 52 insertions(+), 70 deletions(-) diff --git a/tests/test_store/test_local.py b/tests/test_store/test_local.py index d6a97110ad..fa4bc7cfc0 100644 --- a/tests/test_store/test_local.py +++ b/tests/test_store/test_local.py @@ -3,6 +3,7 @@ import json import pathlib import re +from typing import TYPE_CHECKING import numpy as np import pytest @@ -10,13 +11,15 @@ import zarr from zarr import create_array from zarr.core.buffer import Buffer, cpu -from zarr.core.buffer.core import BufferPrototype, default_buffer_prototype from zarr.core.sync import sync from zarr.storage import LocalStore from zarr.storage._local import _atomic_write from zarr.testing.store import StoreTests from zarr.testing.utils import assert_bytes_equal +if TYPE_CHECKING: + from zarr.core.buffer import BufferPrototype + class TestLocalStore(StoreTests[LocalStore, cpu.Buffer]): store_cls = LocalStore @@ -111,6 +114,54 @@ async def test_move( ): await store2.move(destination) + @pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) + async def test_get_bytes_with_prototype_none( + self, store: LocalStore, buffer_cls: None | BufferPrototype + ) -> None: + """Test that get_bytes works with prototype=None.""" + data = b"hello world" + key = "test_key" + await self.set(store, key, self.buffer_cls.from_bytes(data)) + + result = await store.get_bytes(key, prototype=buffer_cls) + assert result == data + + @pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) + def test_get_bytes_sync_with_prototype_none( + self, store: LocalStore, buffer_cls: None | BufferPrototype + ) -> None: + """Test that get_bytes_sync works with prototype=None.""" + data = b"hello world" + key = "test_key" + sync(self.set(store, key, self.buffer_cls.from_bytes(data))) + + result = store.get_bytes_sync(key, prototype=buffer_cls) + assert result == data + + @pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) + async def test_get_json_with_prototype_none( + self, store: LocalStore, buffer_cls: None | BufferPrototype + ) -> None: + """Test that get_json works with prototype=None.""" + data = {"foo": "bar", "number": 42} + key = "test.json" + await self.set(store, key, self.buffer_cls.from_bytes(json.dumps(data).encode())) + + result = await store.get_json(key, prototype=buffer_cls) + assert result == data + + @pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) + def test_get_json_sync_with_prototype_none( + self, store: LocalStore, buffer_cls: None | BufferPrototype + ) -> None: + """Test that get_json_sync works with prototype=None.""" + data = {"foo": "bar", "number": 42} + key = "test.json" + sync(self.set(store, key, self.buffer_cls.from_bytes(json.dumps(data).encode()))) + + result = store.get_json_sync(key, prototype=buffer_cls) + assert result == data + @pytest.mark.parametrize("exclusive", [True, False]) def test_atomic_write_successful(tmp_path: pathlib.Path, exclusive: bool) -> None: @@ -153,72 +204,3 @@ def test_atomic_write_exclusive_preexisting(tmp_path: pathlib.Path) -> None: f.write(b"abc") assert path.read_bytes() == b"xyz" assert list(path.parent.iterdir()) == [path] # no temp files - - -async def test_get_bytes_with_prototype_none(tmp_path: pathlib.Path) -> None: - """Test that get_bytes works with prototype=None.""" - from zarr.core.buffer import cpu - - store = await LocalStore.open(root=tmp_path) - data = b"hello world" - key = "test_key" - await store.set(key, cpu.Buffer.from_bytes(data)) - - # Test with None (default) - result_none = await store.get_bytes(key) - assert result_none == data - - # Test with explicit prototype - result_proto = await store.get_bytes(key, prototype=default_buffer_prototype()) - assert result_proto == data - - -def test_get_bytes_sync_with_prototype_none(tmp_path: pathlib.Path) -> None: - """Test that get_bytes_sync works with prototype=None.""" - from zarr.core.buffer import cpu - from zarr.core.sync import sync - - store = sync(LocalStore.open(root=tmp_path)) - data = b"hello world" - key = "test_key" - sync(store.set(key, cpu.Buffer.from_bytes(data))) - - # Test with None (default) - result_none = store.get_bytes_sync(key) - assert result_none == data - - # Test with explicit prototype - result_proto = store.get_bytes_sync(key, prototype=default_buffer_prototype()) - assert result_proto == data - - -@pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) -async def test_get_json_with_prototype_none( - tmp_path: pathlib.Path, buffer_cls: None | BufferPrototype -) -> None: - """Test that get_json works with prototype=None.""" - - store = await LocalStore.open(root=tmp_path) - data = {"foo": "bar", "number": 42} - key = "test.json" - await store.set(key, cpu.Buffer.from_bytes(json.dumps(data).encode())) - - # Test with None (default) - result = await store.get_json(key, prototype=buffer_cls) - assert result == data - - -@pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) -def test_get_json_sync_with_prototype( - tmp_path: pathlib.Path, buffer_cls: None | BufferPrototype -) -> None: - """Test that get_json_sync works with prototype=None.""" - - store = sync(LocalStore.open(root=tmp_path)) - data = {"foo": "bar", "number": 42} - key = "test.json" - sync(store.set(key, cpu.Buffer.from_bytes(json.dumps(data).encode()))) - - # Test with None (default) - result = store.get_json_sync(key, prototype=buffer_cls) - assert result == data From b110768d3557c83a34125cde6382ed4e358301ca Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Thu, 8 Jan 2026 16:03:43 +0100 Subject: [PATCH 08/13] add BufferLike as buffer parameter for store methods that allocate memory --- src/zarr/abc/store.py | 90 +++++++++++++++++++++------- src/zarr/experimental/cache_store.py | 18 +++--- src/zarr/storage/_common.py | 21 +++---- src/zarr/storage/_fsspec.py | 42 +++++++++---- src/zarr/storage/_local.py | 45 ++++++++------ src/zarr/storage/_logging.py | 8 +-- src/zarr/storage/_memory.py | 39 +++++++----- src/zarr/storage/_obstore.py | 39 +++++++++--- src/zarr/storage/_wrapper.py | 13 ++-- src/zarr/storage/_zip.py | 27 +++++++-- src/zarr/testing/store.py | 86 ++++++++++++++++++++++++-- tests/test_store/test_wrapper.py | 13 ++-- 12 files changed, 326 insertions(+), 115 deletions(-) diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index 0e98777ff5..a4eefecf3c 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -7,6 +7,7 @@ from itertools import starmap from typing import TYPE_CHECKING, Literal, Protocol, runtime_checkable +from zarr.core.buffer import Buffer, BufferPrototype from zarr.core.sync import sync if TYPE_CHECKING: @@ -14,9 +15,9 @@ from types import TracebackType from typing import Any, Self, TypeAlias - from zarr.core.buffer import Buffer, BufferPrototype +__all__ = ["BufferLike", "ByteGetter", "ByteSetter", "Store", "set_or_delete"] -__all__ = ["ByteGetter", "ByteSetter", "Store", "set_or_delete"] +BufferLike = type[Buffer] | BufferPrototype @dataclass @@ -183,11 +184,18 @@ def __eq__(self, value: object) -> bool: """Equality comparison.""" ... + @abstractmethod + def _get_default_buffer_class(self) -> type[Buffer]: + """ + Get the default buffer class for this store. + """ + ... + @abstractmethod async def get( self, key: str, - prototype: BufferPrototype, + prototype: BufferLike | None = None, byte_range: ByteRequest | None = None, ) -> Buffer | None: """Retrieve the value associated with a given key. @@ -195,8 +203,12 @@ async def get( Parameters ---------- key : str - prototype : BufferPrototype - The prototype of the output buffer. Stores may support a default buffer prototype. + prototype : BufferLike | None, optional + The prototype of the output buffer. + Can be either a Buffer class or an instance of `BufferPrototype`, in which the + `buffer` attribute will be used. + If `None`, the default buffer class for this store will be retrieved via the + ``_get_default_buffer_class`` method. byte_range : ByteRequest, optional ByteRequest may be one of the following. If not provided, all data associated with the key is retrieved. - RangeByteRequest(int, int): Request a specific range of bytes in the form (start, end). The end is exclusive. If the given range is zero-length or starts after the end of the object, an error will be returned. Additionally, if the range ends after the end of the object, the entire remainder of the object will be returned. Otherwise, the exact requested range will be returned. @@ -210,7 +222,11 @@ async def get( ... async def get_bytes( - self, key: str, *, prototype: BufferPrototype, byte_range: ByteRequest | None = None + self, + key: str, + *, + prototype: BufferLike | None = None, + byte_range: ByteRequest | None = None, ) -> bytes: """ Retrieve raw bytes from the store asynchronously. @@ -222,8 +238,12 @@ async def get_bytes( ---------- key : str The key identifying the data to retrieve. - prototype : BufferPrototype - The buffer prototype to use for reading the data. + prototype : BufferLike | None, optional + The prototype of the output buffer. + Can be either a Buffer class or an instance of `BufferPrototype`, in which the + `buffer` attribute will be used. + If `None`, the default buffer prototype for this store will be retrieved via the + ``_get_default_buffer_class`` method. byte_range : ByteRequest, optional If specified, only retrieve a portion of the stored data. Can be a ``RangeByteRequest``, ``OffsetByteRequest``, or ``SuffixByteRequest``. @@ -258,7 +278,11 @@ async def get_bytes( return buffer.to_bytes() def get_bytes_sync( - self, key: str = "", *, prototype: BufferPrototype, byte_range: ByteRequest | None = None + self, + key: str = "", + *, + prototype: BufferLike | None = None, + byte_range: ByteRequest | None = None, ) -> bytes: """ Retrieve raw bytes from the store synchronously. @@ -271,8 +295,12 @@ def get_bytes_sync( ---------- key : str, optional The key identifying the data to retrieve. Defaults to an empty string. - prototype : BufferPrototype - The buffer prototype to use for reading the data. + prototype : BufferLike | None, optional + The prototype of the output buffer. + Can be either a Buffer class or an instance of `BufferPrototype`, in which the + `buffer` attribute will be used. + If `None`, the default buffer prototype for this store will be retrieved via the + ``_get_default_buffer_class`` method. byte_range : ByteRequest, optional If specified, only retrieve a portion of the stored data. Can be a ``RangeByteRequest``, ``OffsetByteRequest``, or ``SuffixByteRequest``. @@ -309,7 +337,11 @@ def get_bytes_sync( return sync(self.get_bytes(key, prototype=prototype, byte_range=byte_range)) async def get_json( - self, key: str, *, prototype: BufferPrototype, byte_range: ByteRequest | None = None + self, + key: str, + *, + prototype: BufferLike | None = None, + byte_range: ByteRequest | None = None, ) -> Any: """ Retrieve and parse JSON data from the store asynchronously. @@ -321,8 +353,12 @@ async def get_json( ---------- key : str The key identifying the JSON data to retrieve. - prototype : BufferPrototype - The buffer prototype to use for reading the data. + prototype : BufferLike | None, optional + The prototype of the output buffer. + Can be either a Buffer class or an instance of `BufferPrototype`, in which the + `buffer` attribute will be used. + If `None`, the default buffer prototype for this store will be retrieved via the + ``_get_default_buffer_class`` method. byte_range : ByteRequest, optional If specified, only retrieve a portion of the stored data. Can be a ``RangeByteRequest``, ``OffsetByteRequest``, or ``SuffixByteRequest``. @@ -359,7 +395,11 @@ async def get_json( return json.loads(await self.get_bytes(key, prototype=prototype, byte_range=byte_range)) def get_json_sync( - self, key: str = "", *, prototype: BufferPrototype, byte_range: ByteRequest | None = None + self, + key: str = "", + *, + prototype: BufferLike | None = None, + byte_range: ByteRequest | None = None, ) -> Any: """ Retrieve and parse JSON data from the store synchronously. @@ -372,8 +412,12 @@ def get_json_sync( ---------- key : str, optional The key identifying the JSON data to retrieve. Defaults to an empty string. - prototype : BufferPrototype - The buffer prototype to use for reading the data. + prototype : BufferLike | None, optional + The prototype of the output buffer. + Can be either a Buffer class or an instance of `BufferPrototype`, in which the + `buffer` attribute will be used. + If `None`, the default buffer prototype for this store will be retrieved via the + ``_get_default_buffer_class`` method. byte_range : ByteRequest, optional If specified, only retrieve a portion of the stored data. Can be a ``RangeByteRequest``, ``OffsetByteRequest``, or ``SuffixByteRequest``. @@ -417,15 +461,19 @@ def get_json_sync( @abstractmethod async def get_partial_values( self, - prototype: BufferPrototype, + prototype: BufferLike | None, key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: """Retrieve possibly partial values from given key_ranges. Parameters ---------- - prototype : BufferPrototype - The prototype of the output buffer. Stores may support a default buffer prototype. + prototype : BufferLike | None + The prototype of the output buffer. + Can be either a Buffer class or an instance of `BufferPrototype`, in which the + `buffer` attribute will be used. + If `None`, the default buffer prototype for this store will be retrieved via the + ``_get_default_buffer_class`` method. key_ranges : Iterable[tuple[str, tuple[int | None, int | None]]] Ordered set of key, range pairs, a key may occur multiple times with different ranges @@ -597,7 +645,7 @@ def close(self) -> None: self._is_open = False async def _get_many( - self, requests: Iterable[tuple[str, BufferPrototype, ByteRequest | None]] + self, requests: Iterable[tuple[str, BufferLike | None, ByteRequest | None]] ) -> AsyncGenerator[tuple[str, Buffer | None], None]: """ Retrieve a collection of objects from storage. In general this method does not guarantee diff --git a/src/zarr/experimental/cache_store.py b/src/zarr/experimental/cache_store.py index 3456c94320..e696e0eb0f 100644 --- a/src/zarr/experimental/cache_store.py +++ b/src/zarr/experimental/cache_store.py @@ -6,13 +6,13 @@ from collections import OrderedDict from typing import TYPE_CHECKING, Any, Literal -from zarr.abc.store import ByteRequest, Store +from zarr.abc.store import BufferLike, ByteRequest, Store from zarr.storage._wrapper import WrapperStore logger = logging.getLogger(__name__) if TYPE_CHECKING: - from zarr.core.buffer.core import Buffer, BufferPrototype + from zarr.core.buffer.core import Buffer class CacheStore(WrapperStore[Store]): @@ -218,7 +218,7 @@ def _remove_from_tracking(self, key: str) -> None: self._key_sizes.pop(key, None) async def _get_try_cache( - self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None + self, key: str, prototype: BufferLike | None, byte_range: ByteRequest | None = None ) -> Buffer | None: """Try to get data from cache first, falling back to source store.""" maybe_cached_result = await self._cache.get(key, prototype, byte_range) @@ -246,7 +246,7 @@ async def _get_try_cache( return maybe_fresh_result async def _get_no_cache( - self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None + self, key: str, prototype: BufferLike | None, byte_range: ByteRequest | None = None ) -> Buffer | None: """Get data directly from source store and update cache.""" self._misses += 1 @@ -265,7 +265,7 @@ async def _get_no_cache( async def get( self, key: str, - prototype: BufferPrototype, + prototype: BufferLike | None = None, byte_range: ByteRequest | None = None, ) -> Buffer | None: """ @@ -275,8 +275,12 @@ async def get( ---------- key : str The key to retrieve - prototype : BufferPrototype - Buffer prototype for creating the result buffer + prototype : BufferLike | None, optional + The prototype of the output buffer. + Can be either a Buffer class or an instance of `BufferPrototype`, in which the + `buffer` attribute will be used. + If `None`, the default buffer class for this store will be retrieved via the + ``_get_default_buffer_class`` method. byte_range : ByteRequest, optional Byte range to retrieve diff --git a/src/zarr/storage/_common.py b/src/zarr/storage/_common.py index 4bea04f024..e381c65839 100644 --- a/src/zarr/storage/_common.py +++ b/src/zarr/storage/_common.py @@ -3,10 +3,10 @@ import importlib.util import json from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal, Self, TypeAlias +from typing import Any, Literal, Self, TypeAlias -from zarr.abc.store import ByteRequest, Store -from zarr.core.buffer import Buffer, default_buffer_prototype +from zarr.abc.store import BufferLike, ByteRequest, Store +from zarr.core.buffer import Buffer from zarr.core.common import ( ANY_ACCESS_MODE, ZARR_JSON, @@ -26,9 +26,6 @@ else: FSMap = None -if TYPE_CHECKING: - from zarr.core.buffer import BufferPrototype - def _dereference_path(root: str, path: str) -> str: if not isinstance(root, str): @@ -145,7 +142,7 @@ async def open(cls, store: Store, path: str, mode: AccessModeLiteral | None = No async def get( self, - prototype: BufferPrototype | None = None, + prototype: BufferLike | None = None, byte_range: ByteRequest | None = None, ) -> Buffer | None: """ @@ -153,8 +150,12 @@ async def get( Parameters ---------- - prototype : BufferPrototype, optional - The buffer prototype to use when reading the bytes. + prototype : BufferLike | None, optional + The prototype of the output buffer. + Can be either a Buffer class or an instance of `BufferPrototype`, in which the + `buffer` attribute will be used. + If `None`, the default buffer class for this store will be retrieved via the + store's ``_get_default_buffer_class`` method. byte_range : ByteRequest, optional The range of bytes to read. @@ -164,7 +165,7 @@ async def get( The read bytes, or None if the key does not exist. """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self.store._get_default_buffer_class() return await self.store.get(self.path, prototype=prototype, byte_range=byte_range) async def set(self, value: Buffer) -> None: diff --git a/src/zarr/storage/_fsspec.py b/src/zarr/storage/_fsspec.py index f9e4ed375d..c8a80a9554 100644 --- a/src/zarr/storage/_fsspec.py +++ b/src/zarr/storage/_fsspec.py @@ -8,13 +8,15 @@ from packaging.version import parse as parse_version from zarr.abc.store import ( + BufferLike, ByteRequest, OffsetByteRequest, RangeByteRequest, Store, SuffixByteRequest, ) -from zarr.core.buffer import Buffer +from zarr.core.buffer import Buffer, BufferPrototype +from zarr.core.buffer.core import default_buffer_prototype from zarr.errors import ZarrUserWarning from zarr.storage._common import _dereference_path @@ -25,8 +27,6 @@ from fsspec.asyn import AsyncFileSystem from fsspec.mapping import FSMap - from zarr.core.buffer import BufferPrototype - ALLOWED_EXCEPTIONS: tuple[type[Exception], ...] = ( FileNotFoundError, @@ -273,22 +273,34 @@ def __eq__(self, other: object) -> bool: and self.fs == other.fs ) + def _get_default_buffer_class(self) -> type[Buffer]: + # docstring inherited + return default_buffer_prototype().buffer + async def get( self, key: str, - prototype: BufferPrototype, + prototype: BufferLike | None = None, byte_range: ByteRequest | None = None, ) -> Buffer | None: # docstring inherited if not self._is_open: await self._open() + if prototype is None: + prototype = self._get_default_buffer_class() + # Extract buffer class from BufferLike + if isinstance(prototype, BufferPrototype): + buffer_cls = prototype.buffer + else: + buffer_cls = prototype + path = _dereference_path(self.path, key) try: if byte_range is None: - value = prototype.buffer.from_bytes(await self.fs._cat_file(path)) + value = buffer_cls.from_bytes(await self.fs._cat_file(path)) elif isinstance(byte_range, RangeByteRequest): - value = prototype.buffer.from_bytes( + value = buffer_cls.from_bytes( await self.fs._cat_file( path, start=byte_range.start, @@ -296,11 +308,11 @@ async def get( ) ) elif isinstance(byte_range, OffsetByteRequest): - value = prototype.buffer.from_bytes( + value = buffer_cls.from_bytes( await self.fs._cat_file(path, start=byte_range.offset, end=None) ) elif isinstance(byte_range, SuffixByteRequest): - value = prototype.buffer.from_bytes( + value = buffer_cls.from_bytes( await self.fs._cat_file(path, start=-byte_range.suffix, end=None) ) else: @@ -310,7 +322,7 @@ async def get( except OSError as e: if "not satisfiable" in str(e): # this is an s3-specific condition we probably don't want to leak - return prototype.buffer.from_bytes(b"") + return buffer_cls.from_bytes(b"") raise else: return value @@ -367,10 +379,18 @@ async def exists(self, key: str) -> bool: async def get_partial_values( self, - prototype: BufferPrototype, + prototype: BufferLike | None, key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited + if prototype is None: + prototype = self._get_default_buffer_class() + # Extract buffer class from BufferLike + if isinstance(prototype, BufferPrototype): + buffer_cls = prototype.buffer + else: + buffer_cls = prototype + if key_ranges: # _cat_ranges expects a list of paths, start, and end ranges, so we need to reformat each ByteRequest. key_ranges = list(key_ranges) @@ -403,7 +423,7 @@ async def get_partial_values( if isinstance(r, Exception) and not isinstance(r, self.allowed_exceptions): raise r - return [None if isinstance(r, Exception) else prototype.buffer.from_bytes(r) for r in res] + return [None if isinstance(r, Exception) else buffer_cls.from_bytes(r) for r in res] async def list(self) -> AsyncIterator[str]: # docstring inherited diff --git a/src/zarr/storage/_local.py b/src/zarr/storage/_local.py index 9fb3f8b6ad..351b69b275 100644 --- a/src/zarr/storage/_local.py +++ b/src/zarr/storage/_local.py @@ -11,37 +11,42 @@ from typing import TYPE_CHECKING, Any, BinaryIO, Literal, Self from zarr.abc.store import ( + BufferLike, ByteRequest, OffsetByteRequest, RangeByteRequest, Store, SuffixByteRequest, ) -from zarr.core.buffer import Buffer +from zarr.core.buffer import Buffer, BufferPrototype from zarr.core.buffer.core import default_buffer_prototype from zarr.core.common import AccessModeLiteral, concurrent_map if TYPE_CHECKING: from collections.abc import AsyncIterator, Iterable, Iterator - from zarr.core.buffer import BufferPrototype +def _get(path: Path, prototype: BufferLike, byte_range: ByteRequest | None) -> Buffer: + # Extract buffer class from BufferLike + if isinstance(prototype, BufferPrototype): + buffer_cls = prototype.buffer + else: + buffer_cls = prototype -def _get(path: Path, prototype: BufferPrototype, byte_range: ByteRequest | None) -> Buffer: if byte_range is None: - return prototype.buffer.from_bytes(path.read_bytes()) + return buffer_cls.from_bytes(path.read_bytes()) with path.open("rb") as f: size = f.seek(0, io.SEEK_END) if isinstance(byte_range, RangeByteRequest): f.seek(byte_range.start) - return prototype.buffer.from_bytes(f.read(byte_range.end - f.tell())) + return buffer_cls.from_bytes(f.read(byte_range.end - f.tell())) elif isinstance(byte_range, OffsetByteRequest): f.seek(byte_range.offset) elif isinstance(byte_range, SuffixByteRequest): f.seek(max(0, size - byte_range.suffix)) else: raise TypeError(f"Unexpected byte_range, got {byte_range}.") - return prototype.buffer.from_bytes(f.read()) + return buffer_cls.from_bytes(f.read()) if sys.platform == "win32": @@ -187,15 +192,19 @@ def __repr__(self) -> str: def __eq__(self, other: object) -> bool: return isinstance(other, type(self)) and self.root == other.root + def _get_default_buffer_class(self) -> type[Buffer]: + # docstring inherited + return default_buffer_prototype().buffer + async def get( self, key: str, - prototype: BufferPrototype | None = None, + prototype: BufferLike | None = None, byte_range: ByteRequest | None = None, ) -> Buffer | None: # docstring inherited if prototype is None: - prototype = default_buffer_prototype() + prototype = self._get_default_buffer_class() if not self._is_open: await self._open() assert isinstance(key, str) @@ -208,10 +217,12 @@ async def get( async def get_partial_values( self, - prototype: BufferPrototype, + prototype: BufferLike | None, key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited + if prototype is None: + prototype = self._get_default_buffer_class() args = [] for key, byte_range in key_ranges: assert isinstance(key, str) @@ -310,7 +321,7 @@ async def get_bytes( self, key: str = "", *, - prototype: BufferPrototype | None = None, + prototype: BufferLike | None = None, byte_range: ByteRequest | None = None, ) -> bytes: """ @@ -355,14 +366,14 @@ async def get_bytes( b'hello' """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self._get_default_buffer_class() return await super().get_bytes(key, prototype=prototype, byte_range=byte_range) def get_bytes_sync( self, key: str = "", *, - prototype: BufferPrototype | None = None, + prototype: BufferLike | None = None, byte_range: ByteRequest | None = None, ) -> bytes: """ @@ -411,14 +422,14 @@ def get_bytes_sync( b'hello' """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self._get_default_buffer_class() return super().get_bytes_sync(key, prototype=prototype, byte_range=byte_range) async def get_json( self, key: str = "", *, - prototype: BufferPrototype | None = None, + prototype: BufferLike | None = None, byte_range: ByteRequest | None = None, ) -> Any: """ @@ -470,14 +481,14 @@ async def get_json( {'zarr_format': 3, 'node_type': 'array'} """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self._get_default_buffer_class() return await super().get_json(key, prototype=prototype, byte_range=byte_range) def get_json_sync( self, key: str = "", *, - prototype: BufferPrototype | None = None, + prototype: BufferLike | None = None, byte_range: ByteRequest | None = None, ) -> Any: """ @@ -533,7 +544,7 @@ def get_json_sync( {'zarr_format': 3, 'node_type': 'array'} """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self._get_default_buffer_class() return super().get_json_sync(key, prototype=prototype, byte_range=byte_range) async def move(self, dest_root: Path | str) -> None: diff --git a/src/zarr/storage/_logging.py b/src/zarr/storage/_logging.py index dd20d49ae5..7d82dac948 100644 --- a/src/zarr/storage/_logging.py +++ b/src/zarr/storage/_logging.py @@ -8,14 +8,14 @@ from contextlib import contextmanager from typing import TYPE_CHECKING, Any, Self, TypeVar -from zarr.abc.store import Store +from zarr.abc.store import BufferLike, Store from zarr.storage._wrapper import WrapperStore if TYPE_CHECKING: from collections.abc import AsyncGenerator, Generator, Iterable from zarr.abc.store import ByteRequest - from zarr.core.buffer import Buffer, BufferPrototype + from zarr.core.buffer import Buffer counter: defaultdict[str, int] @@ -165,7 +165,7 @@ def __eq__(self, other: object) -> bool: async def get( self, key: str, - prototype: BufferPrototype, + prototype: BufferLike | None = None, byte_range: ByteRequest | None = None, ) -> Buffer | None: # docstring inherited @@ -174,7 +174,7 @@ async def get( async def get_partial_values( self, - prototype: BufferPrototype, + prototype: BufferLike | None, key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited diff --git a/src/zarr/storage/_memory.py b/src/zarr/storage/_memory.py index 1568cc6736..15ee3855df 100644 --- a/src/zarr/storage/_memory.py +++ b/src/zarr/storage/_memory.py @@ -3,8 +3,8 @@ from logging import getLogger from typing import TYPE_CHECKING, Any, Self -from zarr.abc.store import ByteRequest, Store -from zarr.core.buffer import Buffer, gpu +from zarr.abc.store import BufferLike, ByteRequest, Store +from zarr.core.buffer import Buffer, BufferPrototype, gpu from zarr.core.buffer.core import default_buffer_prototype from zarr.core.common import concurrent_map from zarr.storage._utils import _normalize_byte_range_index @@ -12,8 +12,6 @@ if TYPE_CHECKING: from collections.abc import AsyncIterator, Iterable, MutableMapping - from zarr.core.buffer import BufferPrototype - logger = getLogger(__name__) @@ -60,6 +58,10 @@ def with_read_only(self, read_only: bool = False) -> MemoryStore: read_only=read_only, ) + def _get_default_buffer_class(self) -> type[Buffer]: + # docstring inherited + return default_buffer_prototype().buffer + async def clear(self) -> None: # docstring inherited self._store_dict.clear() @@ -80,25 +82,30 @@ def __eq__(self, other: object) -> bool: async def get( self, key: str, - prototype: BufferPrototype | None = None, + prototype: BufferLike | None = None, byte_range: ByteRequest | None = None, ) -> Buffer | None: # docstring inherited if prototype is None: - prototype = default_buffer_prototype() + prototype = self._get_default_buffer_class() + # Extract buffer class from BufferLike + if isinstance(prototype, BufferPrototype): + buffer_cls = prototype.buffer + else: + buffer_cls = prototype if not self._is_open: await self._open() assert isinstance(key, str) try: value = self._store_dict[key] start, stop = _normalize_byte_range_index(value, byte_range) - return prototype.buffer.from_buffer(value[start:stop]) + return buffer_cls.from_buffer(value[start:stop]) except KeyError: return None async def get_partial_values( self, - prototype: BufferPrototype, + prototype: BufferLike | None, key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited @@ -179,7 +186,7 @@ async def get_bytes( self, key: str = "", *, - prototype: BufferPrototype | None = None, + prototype: BufferLike | None = None, byte_range: ByteRequest | None = None, ) -> bytes: """ @@ -224,14 +231,14 @@ async def get_bytes( b'hello' """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self._get_default_buffer_class() return await super().get_bytes(key, prototype=prototype, byte_range=byte_range) def get_bytes_sync( self, key: str = "", *, - prototype: BufferPrototype | None = None, + prototype: BufferLike | None = None, byte_range: ByteRequest | None = None, ) -> bytes: """ @@ -280,14 +287,14 @@ def get_bytes_sync( b'hello' """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self._get_default_buffer_class() return super().get_bytes_sync(key, prototype=prototype, byte_range=byte_range) async def get_json( self, key: str = "", *, - prototype: BufferPrototype | None = None, + prototype: BufferLike | None = None, byte_range: ByteRequest | None = None, ) -> Any: """ @@ -339,14 +346,14 @@ async def get_json( {'zarr_format': 3, 'node_type': 'array'} """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self._get_default_buffer_class() return await super().get_json(key, prototype=prototype, byte_range=byte_range) def get_json_sync( self, key: str = "", *, - prototype: BufferPrototype | None = None, + prototype: BufferLike | None = None, byte_range: ByteRequest | None = None, ) -> Any: """ @@ -402,7 +409,7 @@ def get_json_sync( {'zarr_format': 3, 'node_type': 'array'} """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self._get_default_buffer_class() return super().get_json_sync(key, prototype=prototype, byte_range=byte_range) diff --git a/src/zarr/storage/_obstore.py b/src/zarr/storage/_obstore.py index 5c2197ecf6..baa469c430 100644 --- a/src/zarr/storage/_obstore.py +++ b/src/zarr/storage/_obstore.py @@ -7,12 +7,15 @@ from typing import TYPE_CHECKING, Generic, Self, TypedDict, TypeVar from zarr.abc.store import ( + BufferLike, ByteRequest, OffsetByteRequest, RangeByteRequest, Store, SuffixByteRequest, ) +from zarr.core.buffer import BufferPrototype +from zarr.core.buffer.core import default_buffer_prototype from zarr.core.common import concurrent_map from zarr.core.config import config @@ -23,7 +26,7 @@ from obstore import ListResult, ListStream, ObjectMeta, OffsetRange, SuffixRange from obstore.store import ObjectStore as _UpstreamObjectStore - from zarr.core.buffer import Buffer, BufferPrototype + from zarr.core.buffer import Buffer __all__ = ["ObjectStore"] @@ -94,26 +97,40 @@ def __setstate__(self, state: dict[Any, Any]) -> None: state["store"] = pickle.loads(state["store"]) self.__dict__.update(state) + def _get_default_buffer_class(self) -> type[Buffer]: + # docstring inherited + from zarr.core.buffer.core import default_buffer_prototype + + return default_buffer_prototype().buffer + async def get( - self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None + self, key: str, prototype: BufferLike | None = None, byte_range: ByteRequest | None = None ) -> Buffer | None: # docstring inherited import obstore as obs + if prototype is None: + prototype = self._get_default_buffer_class() + # Extract buffer class from BufferLike + if isinstance(prototype, BufferPrototype): + buffer_cls = prototype.buffer + else: + buffer_cls = prototype + try: if byte_range is None: resp = await obs.get_async(self.store, key) - return prototype.buffer.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] + return buffer_cls.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] elif isinstance(byte_range, RangeByteRequest): bytes = await obs.get_range_async( self.store, key, start=byte_range.start, end=byte_range.end ) - return prototype.buffer.from_bytes(bytes) # type: ignore[arg-type] + return buffer_cls.from_bytes(bytes) # type: ignore[arg-type] elif isinstance(byte_range, OffsetByteRequest): resp = await obs.get_async( self.store, key, options={"range": {"offset": byte_range.offset}} ) - return prototype.buffer.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] + return buffer_cls.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] elif isinstance(byte_range, SuffixByteRequest): # some object stores (Azure) don't support suffix requests. In this # case, our workaround is to first get the length of the object and then @@ -122,7 +139,7 @@ async def get( resp = await obs.get_async( self.store, key, options={"range": {"suffix": byte_range.suffix}} ) - return prototype.buffer.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] + return buffer_cls.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] except obs.exceptions.NotSupportedError: head_resp = await obs.head_async(self.store, key) file_size = head_resp["size"] @@ -133,7 +150,7 @@ async def get( start=file_size - suffix_len, length=suffix_len, ) - return prototype.buffer.from_bytes(buffer) # type: ignore[arg-type] + return buffer_cls.from_bytes(buffer) # type: ignore[arg-type] else: raise ValueError(f"Unexpected byte_range, got {byte_range}") except _ALLOWED_EXCEPTIONS: @@ -141,10 +158,16 @@ async def get( async def get_partial_values( self, - prototype: BufferPrototype, + prototype: BufferLike | None, key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited + if prototype is None: + prototype = self._get_default_buffer_class() + # Extract buffer class from BufferLike - _get_partial_values expects BufferPrototype + if not isinstance(prototype, BufferPrototype): + # Convert raw buffer class to BufferPrototype + prototype = default_buffer_prototype() return await _get_partial_values(self.store, prototype=prototype, key_ranges=key_ranges) async def exists(self, key: str) -> bool: diff --git a/src/zarr/storage/_wrapper.py b/src/zarr/storage/_wrapper.py index 64a5b2d83c..b105dcf0d2 100644 --- a/src/zarr/storage/_wrapper.py +++ b/src/zarr/storage/_wrapper.py @@ -9,9 +9,8 @@ from zarr.abc.buffer import Buffer from zarr.abc.store import ByteRequest - from zarr.core.buffer import BufferPrototype -from zarr.abc.store import Store +from zarr.abc.store import BufferLike, Store T_Store = TypeVar("T_Store", bound=Store) @@ -84,14 +83,18 @@ def __str__(self) -> str: def __repr__(self) -> str: return f"WrapperStore({self._store.__class__.__name__}, '{self._store}')" + def _get_default_buffer_class(self) -> type[Buffer]: + # docstring inherited + return self._store._get_default_buffer_class() + async def get( - self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None + self, key: str, prototype: BufferLike | None = None, byte_range: ByteRequest | None = None ) -> Buffer | None: return await self._store.get(key, prototype, byte_range) async def get_partial_values( self, - prototype: BufferPrototype, + prototype: BufferLike | None, key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: return await self._store.get_partial_values(prototype, key_ranges) @@ -139,7 +142,7 @@ def close(self) -> None: self._store.close() async def _get_many( - self, requests: Iterable[tuple[str, BufferPrototype, ByteRequest | None]] + self, requests: Iterable[tuple[str, BufferLike | None, ByteRequest | None]] ) -> AsyncGenerator[tuple[str, Buffer | None], None]: async for req in self._store._get_many(requests): yield req diff --git a/src/zarr/storage/_zip.py b/src/zarr/storage/_zip.py index 72bf9e335a..64fe902e52 100644 --- a/src/zarr/storage/_zip.py +++ b/src/zarr/storage/_zip.py @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Literal from zarr.abc.store import ( + BufferLike, ByteRequest, OffsetByteRequest, RangeByteRequest, @@ -16,6 +17,7 @@ SuffixByteRequest, ) from zarr.core.buffer import Buffer, BufferPrototype +from zarr.core.buffer.core import default_buffer_prototype if TYPE_CHECKING: from collections.abc import AsyncIterator, Iterable @@ -143,22 +145,31 @@ def __repr__(self) -> str: def __eq__(self, other: object) -> bool: return isinstance(other, type(self)) and self.path == other.path + def _get_default_buffer_class(self) -> type[Buffer]: + # docstring inherited + return default_buffer_prototype().buffer + def _get( self, key: str, - prototype: BufferPrototype, + prototype: BufferLike, byte_range: ByteRequest | None = None, ) -> Buffer | None: if not self._is_open: self._sync_open() + # Extract buffer class from BufferLike + if isinstance(prototype, BufferPrototype): + buffer_cls = prototype.buffer + else: + buffer_cls = prototype # docstring inherited try: with self._zf.open(key) as f: # will raise KeyError if byte_range is None: - return prototype.buffer.from_bytes(f.read()) + return buffer_cls.from_bytes(f.read()) elif isinstance(byte_range, RangeByteRequest): f.seek(byte_range.start) - return prototype.buffer.from_bytes(f.read(byte_range.end - f.tell())) + return buffer_cls.from_bytes(f.read(byte_range.end - f.tell())) size = f.seek(0, os.SEEK_END) if isinstance(byte_range, OffsetByteRequest): f.seek(byte_range.offset) @@ -166,17 +177,19 @@ def _get( f.seek(max(0, size - byte_range.suffix)) else: raise TypeError(f"Unexpected byte_range, got {byte_range}.") - return prototype.buffer.from_bytes(f.read()) + return buffer_cls.from_bytes(f.read()) except KeyError: return None async def get( self, key: str, - prototype: BufferPrototype, + prototype: BufferLike | None = None, byte_range: ByteRequest | None = None, ) -> Buffer | None: # docstring inherited + if prototype is None: + prototype = self._get_default_buffer_class() assert isinstance(key, str) with self._lock: @@ -184,10 +197,12 @@ async def get( async def get_partial_values( self, - prototype: BufferPrototype, + prototype: BufferLike | None, key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited + if prototype is None: + prototype = self._get_default_buffer_class() out = [] with self._lock: for key, byte_range in key_ranges: diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index a56061ae12..7613dee04a 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -12,11 +12,11 @@ from typing import Any from zarr.abc.store import ByteRequest - from zarr.core.buffer.core import BufferPrototype import pytest from zarr.abc.store import ( + BufferLike, ByteRequest, OffsetByteRequest, RangeByteRequest, @@ -244,6 +244,32 @@ async def test_get_raises(self, store: S) -> None: with pytest.raises((ValueError, TypeError), match=r"Unexpected byte_range, got.*"): await store.get("c/0", prototype=default_buffer_prototype(), byte_range=(0, 2)) # type: ignore[arg-type] + @pytest.mark.parametrize( + "prototype", + [ + None, # Should use store's default buffer class + default_buffer_prototype(), # BufferPrototype instance + default_buffer_prototype().buffer, # Raw Buffer class (cpu.Buffer) + ], + ids=["prototype=None", "prototype=BufferPrototype", "prototype=Buffer"], + ) + async def test_get_with_buffer_like(self, store: S, prototype: BufferLike | None) -> None: + """ + Test that store.get() works with all BufferLike variants: + - None (uses store's default) + - BufferPrototype instance + - Raw Buffer class + """ + data = b"\x01\x02\x03\x04" + key = "test_buffer_like" + data_buf = self.buffer_cls.from_bytes(data) + await self.set(store, key, data_buf) + + # Get with the parametrized prototype + observed = await store.get(key, prototype=prototype) + assert observed is not None + assert_bytes_equal(observed, data_buf) + async def test_get_many(self, store: S) -> None: """ Ensure that multiple keys can be retrieved at once with the _get_many method. @@ -376,6 +402,54 @@ async def test_get_partial_values( obs.to_bytes() == exp.to_bytes() for obs, exp in zip(observed, expected, strict=True) ) + @pytest.mark.parametrize( + "prototype", + [ + None, # Should use store's default buffer class + default_buffer_prototype(), # BufferPrototype instance + default_buffer_prototype().buffer, # Raw Buffer class (cpu.Buffer) + ], + ids=["prototype=None", "prototype=BufferPrototype", "prototype=Buffer"], + ) + async def test_get_partial_values_with_buffer_like( + self, store: S, prototype: BufferLike | None + ) -> None: + """ + Test that store.get_partial_values() works with all BufferLike variants: + - None (uses store's default) + - BufferPrototype instance + - Raw Buffer class + """ + key_ranges: list[tuple[str, ByteRequest | None]] = [ + ("c/0", RangeByteRequest(0, 2)), + ("c/1", None), + ("c/2", SuffixByteRequest(2)), + ] + + # put all of the data + for key, _ in key_ranges: + await self.set(store, key, self.buffer_cls.from_bytes(bytes(key, encoding="utf-8"))) + + # read back with the parametrized prototype + observed_maybe = await store.get_partial_values(prototype=prototype, key_ranges=key_ranges) + + observed: list[Buffer] = [] + expected: list[Buffer] = [] + + for obs in observed_maybe: + assert obs is not None + observed.append(obs) + + for idx in range(len(observed)): + key, byte_range = key_ranges[idx] + result = await store.get(key, prototype=prototype, byte_range=byte_range) + assert result is not None + expected.append(result) + + assert all( + obs.to_bytes() == exp.to_bytes() for obs, exp in zip(observed, expected, strict=True) + ) + async def test_exists(self, store: S) -> None: assert not await store.exists("foo") await store.set("foo/zarr.json", self.buffer_cls.from_bytes(b"bar")) @@ -604,7 +678,7 @@ async def set(self, key: str, value: Buffer) -> None: await self._store.set(key, value) async def get( - self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None + self, key: str, prototype: BufferLike | None = None, byte_range: ByteRequest | None = None ) -> Buffer | None: """ Add latency to the ``get`` method. @@ -615,8 +689,12 @@ async def get( ---------- key : str The key to get - prototype : BufferPrototype - The BufferPrototype to use. + prototype : BufferLike | None, optional + The prototype of the output buffer. + Can be either a Buffer class or an instance of `BufferPrototype`, in which the + `buffer` attribute will be used. + If `None`, the default buffer class for this store will be retrieved via the + ``_get_default_buffer_class`` method. byte_range : ByteRequest, optional An optional byte range. diff --git a/tests/test_store/test_wrapper.py b/tests/test_store/test_wrapper.py index b34a63d5d0..c5f2240297 100644 --- a/tests/test_store/test_wrapper.py +++ b/tests/test_store/test_wrapper.py @@ -4,7 +4,7 @@ import pytest -from zarr.abc.store import ByteRequest, Store +from zarr.abc.store import BufferLike, ByteRequest, Store from zarr.core.buffer import Buffer from zarr.core.buffer.cpu import Buffer as CPUBuffer from zarr.core.buffer.cpu import buffer_prototype @@ -14,8 +14,6 @@ if TYPE_CHECKING: from pathlib import Path - from zarr.core.buffer.core import BufferPrototype - class StoreKwargs(TypedDict): store: LocalStore @@ -111,10 +109,13 @@ async def test_wrapped_get(store: Store, capsys: pytest.CaptureFixture[str]) -> # define a class that prints when it sets class NoisyGetter(WrapperStore[Any]): async def get( - self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None - ) -> None: + self, + key: str, + prototype: BufferLike | None = None, + byte_range: ByteRequest | None = None, + ) -> Buffer | None: print(f"getting {key}") - await super().get(key, prototype=prototype, byte_range=byte_range) + return await super().get(key, prototype=prototype, byte_range=byte_range) key = "foo" value = CPUBuffer.from_bytes(b"bar") From 6b9de9db2d6594f26a65374b0b25cdcb12b15d7b Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Thu, 8 Jan 2026 16:57:38 +0100 Subject: [PATCH 09/13] implement default on store abc --- src/zarr/abc/store.py | 6 +++--- src/zarr/storage/_fsspec.py | 5 ----- src/zarr/storage/_local.py | 5 ----- src/zarr/storage/_memory.py | 5 ----- src/zarr/storage/_obstore.py | 6 ------ src/zarr/storage/_wrapper.py | 4 ---- src/zarr/storage/_zip.py | 5 ----- 7 files changed, 3 insertions(+), 33 deletions(-) diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index a4eefecf3c..25aaba4aa9 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Literal, Protocol, runtime_checkable from zarr.core.buffer import Buffer, BufferPrototype +from zarr.core.buffer.core import default_buffer_prototype from zarr.core.sync import sync if TYPE_CHECKING: @@ -184,12 +185,11 @@ def __eq__(self, value: object) -> bool: """Equality comparison.""" ... - @abstractmethod def _get_default_buffer_class(self) -> type[Buffer]: """ - Get the default buffer class for this store. + Get the default buffer class. """ - ... + return default_buffer_prototype().buffer @abstractmethod async def get( diff --git a/src/zarr/storage/_fsspec.py b/src/zarr/storage/_fsspec.py index c8a80a9554..b16712c786 100644 --- a/src/zarr/storage/_fsspec.py +++ b/src/zarr/storage/_fsspec.py @@ -16,7 +16,6 @@ SuffixByteRequest, ) from zarr.core.buffer import Buffer, BufferPrototype -from zarr.core.buffer.core import default_buffer_prototype from zarr.errors import ZarrUserWarning from zarr.storage._common import _dereference_path @@ -273,10 +272,6 @@ def __eq__(self, other: object) -> bool: and self.fs == other.fs ) - def _get_default_buffer_class(self) -> type[Buffer]: - # docstring inherited - return default_buffer_prototype().buffer - async def get( self, key: str, diff --git a/src/zarr/storage/_local.py b/src/zarr/storage/_local.py index 351b69b275..f991765723 100644 --- a/src/zarr/storage/_local.py +++ b/src/zarr/storage/_local.py @@ -19,7 +19,6 @@ SuffixByteRequest, ) from zarr.core.buffer import Buffer, BufferPrototype -from zarr.core.buffer.core import default_buffer_prototype from zarr.core.common import AccessModeLiteral, concurrent_map if TYPE_CHECKING: @@ -192,10 +191,6 @@ def __repr__(self) -> str: def __eq__(self, other: object) -> bool: return isinstance(other, type(self)) and self.root == other.root - def _get_default_buffer_class(self) -> type[Buffer]: - # docstring inherited - return default_buffer_prototype().buffer - async def get( self, key: str, diff --git a/src/zarr/storage/_memory.py b/src/zarr/storage/_memory.py index 15ee3855df..c28dc910b4 100644 --- a/src/zarr/storage/_memory.py +++ b/src/zarr/storage/_memory.py @@ -5,7 +5,6 @@ from zarr.abc.store import BufferLike, ByteRequest, Store from zarr.core.buffer import Buffer, BufferPrototype, gpu -from zarr.core.buffer.core import default_buffer_prototype from zarr.core.common import concurrent_map from zarr.storage._utils import _normalize_byte_range_index @@ -58,10 +57,6 @@ def with_read_only(self, read_only: bool = False) -> MemoryStore: read_only=read_only, ) - def _get_default_buffer_class(self) -> type[Buffer]: - # docstring inherited - return default_buffer_prototype().buffer - async def clear(self) -> None: # docstring inherited self._store_dict.clear() diff --git a/src/zarr/storage/_obstore.py b/src/zarr/storage/_obstore.py index baa469c430..aff000afe9 100644 --- a/src/zarr/storage/_obstore.py +++ b/src/zarr/storage/_obstore.py @@ -97,12 +97,6 @@ def __setstate__(self, state: dict[Any, Any]) -> None: state["store"] = pickle.loads(state["store"]) self.__dict__.update(state) - def _get_default_buffer_class(self) -> type[Buffer]: - # docstring inherited - from zarr.core.buffer.core import default_buffer_prototype - - return default_buffer_prototype().buffer - async def get( self, key: str, prototype: BufferLike | None = None, byte_range: ByteRequest | None = None ) -> Buffer | None: diff --git a/src/zarr/storage/_wrapper.py b/src/zarr/storage/_wrapper.py index b105dcf0d2..ca3609009e 100644 --- a/src/zarr/storage/_wrapper.py +++ b/src/zarr/storage/_wrapper.py @@ -83,10 +83,6 @@ def __str__(self) -> str: def __repr__(self) -> str: return f"WrapperStore({self._store.__class__.__name__}, '{self._store}')" - def _get_default_buffer_class(self) -> type[Buffer]: - # docstring inherited - return self._store._get_default_buffer_class() - async def get( self, key: str, prototype: BufferLike | None = None, byte_range: ByteRequest | None = None ) -> Buffer | None: diff --git a/src/zarr/storage/_zip.py b/src/zarr/storage/_zip.py index 64fe902e52..0348eeedd8 100644 --- a/src/zarr/storage/_zip.py +++ b/src/zarr/storage/_zip.py @@ -17,7 +17,6 @@ SuffixByteRequest, ) from zarr.core.buffer import Buffer, BufferPrototype -from zarr.core.buffer.core import default_buffer_prototype if TYPE_CHECKING: from collections.abc import AsyncIterator, Iterable @@ -145,10 +144,6 @@ def __repr__(self) -> str: def __eq__(self, other: object) -> bool: return isinstance(other, type(self)) and self.path == other.path - def _get_default_buffer_class(self) -> type[Buffer]: - # docstring inherited - return default_buffer_prototype().buffer - def _get( self, key: str, From 281538a2966fa15c229a6de6352c2fefd01b53d7 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Thu, 8 Jan 2026 18:00:08 +0100 Subject: [PATCH 10/13] consolidate prototype testing --- src/zarr/testing/store.py | 108 +++++++++----------------------------- 1 file changed, 25 insertions(+), 83 deletions(-) diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index 7613dee04a..55e3687f20 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -23,7 +23,7 @@ Store, SuffixByteRequest, ) -from zarr.core.buffer import Buffer, default_buffer_prototype +from zarr.core.buffer import Buffer, cpu, default_buffer_prototype from zarr.core.sync import _collect_aiterator, sync from zarr.storage._utils import _normalize_byte_range_index from zarr.testing.utils import assert_bytes_equal @@ -202,6 +202,15 @@ async def test_with_read_only_store(self, open_kwargs: dict[str, Any]) -> None: ): await reader.delete("foo") + @pytest.mark.parametrize( + "prototype", + [ + None, # Should use store's default buffer class + default_buffer_prototype(), # BufferPrototype instance + default_buffer_prototype().buffer, # Raw Buffer class (cpu.Buffer) + ], + ids=["prototype=None", "prototype=BufferPrototype", "prototype=Buffer"], + ) @pytest.mark.parametrize("key", ["c/0", "foo/c/0.0", "foo/0/0"]) @pytest.mark.parametrize( ("data", "byte_range"), @@ -213,13 +222,15 @@ async def test_with_read_only_store(self, open_kwargs: dict[str, Any]) -> None: (b"", None), ], ) - async def test_get(self, store: S, key: str, data: bytes, byte_range: ByteRequest) -> None: + async def test_get( + self, store: S, key: str, data: bytes, byte_range: ByteRequest, prototype: BufferLike | None + ) -> None: """ Ensure that data can be read from the store using the store.get method. """ data_buf = self.buffer_cls.from_bytes(data) await self.set(store, key, data_buf) - observed = await store.get(key, prototype=default_buffer_prototype(), byte_range=byte_range) + observed = await store.get(key, prototype=prototype, byte_range=byte_range) start, stop = _normalize_byte_range_index(data_buf, byte_range=byte_range) expected = data_buf[start:stop] assert_bytes_equal(observed, expected) @@ -244,32 +255,6 @@ async def test_get_raises(self, store: S) -> None: with pytest.raises((ValueError, TypeError), match=r"Unexpected byte_range, got.*"): await store.get("c/0", prototype=default_buffer_prototype(), byte_range=(0, 2)) # type: ignore[arg-type] - @pytest.mark.parametrize( - "prototype", - [ - None, # Should use store's default buffer class - default_buffer_prototype(), # BufferPrototype instance - default_buffer_prototype().buffer, # Raw Buffer class (cpu.Buffer) - ], - ids=["prototype=None", "prototype=BufferPrototype", "prototype=Buffer"], - ) - async def test_get_with_buffer_like(self, store: S, prototype: BufferLike | None) -> None: - """ - Test that store.get() works with all BufferLike variants: - - None (uses store's default) - - BufferPrototype instance - - Raw Buffer class - """ - data = b"\x01\x02\x03\x04" - key = "test_buffer_like" - data_buf = self.buffer_cls.from_bytes(data) - await self.set(store, key, data_buf) - - # Get with the parametrized prototype - observed = await store.get(key, prototype=prototype) - assert observed is not None - assert_bytes_equal(observed, data_buf) - async def test_get_many(self, store: S) -> None: """ Ensure that multiple keys can be retrieved at once with the _get_many method. @@ -358,6 +343,15 @@ async def test_set_many(self, store: S) -> None: for k, v in store_dict.items(): assert (await self.get(store, k)).to_bytes() == v.to_bytes() + @pytest.mark.parametrize( + "prototype", + [ + None, # Should use store's default buffer class + default_buffer_prototype(), # BufferPrototype instance + default_buffer_prototype().buffer, # Raw Buffer class (cpu.Buffer) + ], + ids=["prototype=None", "prototype=BufferPrototype", "prototype=Buffer"], + ) @pytest.mark.parametrize( "key_ranges", [ @@ -372,65 +366,13 @@ async def test_set_many(self, store: S) -> None: ], ) async def test_get_partial_values( - self, store: S, key_ranges: list[tuple[str, ByteRequest]] + self, store: S, key_ranges: list[tuple[str, ByteRequest]], prototype: BufferLike | None ) -> None: # put all of the data for key, _ in key_ranges: await self.set(store, key, self.buffer_cls.from_bytes(bytes(key, encoding="utf-8"))) # read back just part of it - observed_maybe = await store.get_partial_values( - prototype=default_buffer_prototype(), key_ranges=key_ranges - ) - - observed: list[Buffer] = [] - expected: list[Buffer] = [] - - for obs in observed_maybe: - assert obs is not None - observed.append(obs) - - for idx in range(len(observed)): - key, byte_range = key_ranges[idx] - result = await store.get( - key, prototype=default_buffer_prototype(), byte_range=byte_range - ) - assert result is not None - expected.append(result) - - assert all( - obs.to_bytes() == exp.to_bytes() for obs, exp in zip(observed, expected, strict=True) - ) - - @pytest.mark.parametrize( - "prototype", - [ - None, # Should use store's default buffer class - default_buffer_prototype(), # BufferPrototype instance - default_buffer_prototype().buffer, # Raw Buffer class (cpu.Buffer) - ], - ids=["prototype=None", "prototype=BufferPrototype", "prototype=Buffer"], - ) - async def test_get_partial_values_with_buffer_like( - self, store: S, prototype: BufferLike | None - ) -> None: - """ - Test that store.get_partial_values() works with all BufferLike variants: - - None (uses store's default) - - BufferPrototype instance - - Raw Buffer class - """ - key_ranges: list[tuple[str, ByteRequest | None]] = [ - ("c/0", RangeByteRequest(0, 2)), - ("c/1", None), - ("c/2", SuffixByteRequest(2)), - ] - - # put all of the data - for key, _ in key_ranges: - await self.set(store, key, self.buffer_cls.from_bytes(bytes(key, encoding="utf-8"))) - - # read back with the parametrized prototype observed_maybe = await store.get_partial_values(prototype=prototype, key_ranges=key_ranges) observed: list[Buffer] = [] @@ -442,7 +384,7 @@ async def test_get_partial_values_with_buffer_like( for idx in range(len(observed)): key, byte_range = key_ranges[idx] - result = await store.get(key, prototype=prototype, byte_range=byte_range) + result = await store.get(key, prototype=cpu.Buffer, byte_range=byte_range) assert result is not None expected.append(result) From 690423953d9a8ff32960031fcbef21aefde91d48 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Mon, 26 Jan 2026 13:34:41 +0100 Subject: [PATCH 11/13] remove as much as possible default_buffer_prototype() invocation --- src/zarr/abc/store.py | 35 ++++++++++---------- src/zarr/core/common.py | 17 ++++++++++ src/zarr/core/group.py | 35 +++++++++++--------- src/zarr/core/metadata/io.py | 7 ++-- src/zarr/core/metadata/v2.py | 9 ++++-- src/zarr/core/metadata/v3.py | 7 ++-- src/zarr/experimental/cache_store.py | 11 ++++--- src/zarr/storage/_common.py | 4 +-- src/zarr/storage/_fsspec.py | 23 ++++++------- src/zarr/storage/_local.py | 32 +++++++++--------- src/zarr/storage/_logging.py | 6 ++-- src/zarr/storage/_memory.py | 30 ++++++++--------- src/zarr/storage/_obstore.py | 9 ++++-- src/zarr/storage/_wrapper.py | 11 ++++--- src/zarr/storage/_zip.py | 8 ++--- src/zarr/testing/store.py | 36 +++++++++++++-------- tests/test_api/test_asynchronous.py | 4 +-- tests/test_array.py | 2 +- tests/test_codecs/test_blosc.py | 6 ++-- tests/test_codecs/test_codecs.py | 21 ++++++------ tests/test_codecs/test_sharding.py | 6 ++-- tests/test_experimental/test_cache_store.py | 25 +++++++------- tests/test_group.py | 12 +++---- tests/test_indexing.py | 7 ++-- tests/test_metadata/test_consolidated.py | 8 ++--- tests/test_metadata/test_v2.py | 5 +-- tests/test_metadata/test_v3.py | 7 ++-- tests/test_properties.py | 18 +++-------- tests/test_store/test_fsspec.py | 6 ++-- tests/test_store/test_logging.py | 7 ++-- tests/test_store/test_wrapper.py | 4 +-- tests/test_store/test_zip.py | 4 +-- tests/test_v2.py | 5 ++- 33 files changed, 220 insertions(+), 207 deletions(-) diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index b4101ae80c..d58d36e0f4 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -8,17 +8,18 @@ from typing import TYPE_CHECKING, Literal, Protocol, runtime_checkable from zarr.core.buffer import Buffer, BufferPrototype -from zarr.core.buffer.core import default_buffer_prototype from zarr.core.sync import sync +from zarr.registry import get_buffer_class if TYPE_CHECKING: from collections.abc import AsyncGenerator, AsyncIterator, Iterable from types import TracebackType from typing import Any, Self, TypeAlias -__all__ = ["BufferLike", "ByteGetter", "ByteSetter", "Store", "set_or_delete"] +__all__ = ["BufferClassLike", "ByteGetter", "ByteSetter", "Store", "set_or_delete"] -BufferLike = type[Buffer] | BufferPrototype +BufferClassLike = type[Buffer] | BufferPrototype +"""An object that is or contains a Buffer class""" @dataclass @@ -189,13 +190,13 @@ def _get_default_buffer_class(self) -> type[Buffer]: """ Get the default buffer class. """ - return default_buffer_prototype().buffer + return get_buffer_class() @abstractmethod async def get( self, key: str, - prototype: BufferLike | None = None, + prototype: BufferClassLike | None = None, byte_range: ByteRequest | None = None, ) -> Buffer | None: """Retrieve the value associated with a given key. @@ -225,7 +226,7 @@ async def _get_bytes( self, key: str, *, - prototype: BufferLike | None = None, + prototype: BufferClassLike | None = None, byte_range: ByteRequest | None = None, ) -> bytes: """ @@ -268,7 +269,7 @@ async def _get_bytes( -------- >>> store = await MemoryStore.open() >>> await store.set("data", Buffer.from_bytes(b"hello world")) - >>> data = await store._get_bytes("data", prototype=default_buffer_prototype()) + >>> data = await store._get_bytes("data") >>> print(data) b'hello world' """ @@ -281,7 +282,7 @@ def _get_bytes_sync( self, key: str = "", *, - prototype: BufferLike | None = None, + prototype: BufferClassLike | None = None, byte_range: ByteRequest | None = None, ) -> bytes: """ @@ -329,7 +330,7 @@ def _get_bytes_sync( -------- >>> store = MemoryStore() >>> await store.set("data", Buffer.from_bytes(b"hello world")) - >>> data = store._get_bytes_sync("data", prototype=default_buffer_prototype()) + >>> data = store._get_bytes_sync("data") >>> print(data) b'hello world' """ @@ -340,7 +341,7 @@ async def _get_json( self, key: str, *, - prototype: BufferLike | None = None, + prototype: BufferClassLike | None = None, byte_range: ByteRequest | None = None, ) -> Any: """ @@ -387,7 +388,7 @@ async def _get_json( >>> store = await MemoryStore.open() >>> metadata = {"zarr_format": 3, "node_type": "array"} >>> await store.set("zarr.json", Buffer.from_bytes(json.dumps(metadata).encode())) - >>> data = await store._get_json("zarr.json", prototype=default_buffer_prototype()) + >>> data = await store._get_json("zarr.json") >>> print(data) {'zarr_format': 3, 'node_type': 'array'} """ @@ -398,7 +399,7 @@ def _get_json_sync( self, key: str = "", *, - prototype: BufferLike | None = None, + prototype: BufferClassLike | None = None, byte_range: ByteRequest | None = None, ) -> Any: """ @@ -451,7 +452,7 @@ def _get_json_sync( >>> store = MemoryStore() >>> metadata = {"zarr_format": 3, "node_type": "array"} >>> store.set("zarr.json", Buffer.from_bytes(json.dumps(metadata).encode())) - >>> data = store._get_json_sync("zarr.json", prototype=default_buffer_prototype()) + >>> data = store._get_json_sync("zarr.json") >>> print(data) {'zarr_format': 3, 'node_type': 'array'} """ @@ -461,7 +462,7 @@ def _get_json_sync( @abstractmethod async def get_partial_values( self, - prototype: BufferLike | None, + prototype: BufferClassLike | None, key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: """Retrieve possibly partial values from given key_ranges. @@ -645,7 +646,7 @@ def close(self) -> None: self._is_open = False async def _get_many( - self, requests: Iterable[tuple[str, BufferLike | None, ByteRequest | None]] + self, requests: Iterable[tuple[str, BufferClassLike | None, ByteRequest | None]] ) -> AsyncGenerator[tuple[str, Buffer | None], None]: """ Retrieve a collection of objects from storage. In general this method does not guarantee @@ -676,10 +677,8 @@ async def getsize(self, key: str) -> int: # Note to implementers: this default implementation is very inefficient since # it requires reading the entire object. Many systems will have ways to get the # size of an object without reading it. - # avoid circular import - from zarr.core.buffer.core import default_buffer_prototype - value = await self.get(key, prototype=default_buffer_prototype()) + value = await self.get(key) if value is None: raise FileNotFoundError(key) return len(value) diff --git a/src/zarr/core/common.py b/src/zarr/core/common.py index d38949657e..9a2bc2d2f3 100644 --- a/src/zarr/core/common.py +++ b/src/zarr/core/common.py @@ -23,12 +23,15 @@ from typing_extensions import ReadOnly +from zarr.core.buffer import Buffer, BufferPrototype from zarr.core.config import config as zarr_config from zarr.errors import ZarrRuntimeWarning if TYPE_CHECKING: from collections.abc import Awaitable, Callable, Iterator + from zarr.abc.store import BufferClassLike + ZARR_JSON = "zarr.json" ZARRAY_JSON = ".zarray" @@ -246,3 +249,17 @@ def _warn_order_kwarg() -> None: def _default_zarr_format() -> ZarrFormat: """Return the default zarr_version""" return cast("ZarrFormat", int(zarr_config.get("default_zarr_format", 3))) + + +def parse_bufferclasslike(obj: BufferClassLike | None) -> type[Buffer]: + """ + Take an optional BufferClassLike and return a Buffer class + """ + # Avoid a circular import. Temporary fix until we re-organize modules appropriately. + from zarr.registry import get_buffer_class + + if obj is None: + return get_buffer_class() + if isinstance(obj, BufferPrototype): + return obj.buffer + return obj diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 9b5fee275b..fadb700fd8 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -17,7 +17,7 @@ import zarr.api.asynchronous as async_api from zarr.abc.metadata import Metadata -from zarr.abc.store import Store, set_or_delete +from zarr.abc.store import BufferClassLike, Store, set_or_delete from zarr.core._info import GroupInfo from zarr.core.array import ( DEFAULT_FILL_VALUE, @@ -32,7 +32,6 @@ create_array, ) from zarr.core.attributes import Attributes -from zarr.core.buffer import default_buffer_prototype from zarr.core.common import ( JSON, ZARR_JSON, @@ -44,6 +43,7 @@ NodeType, ShapeLike, ZarrFormat, + parse_bufferclasslike, parse_shapelike, ) from zarr.core.config import config @@ -75,7 +75,7 @@ from typing import Any from zarr.core.array_spec import ArrayConfigLike - from zarr.core.buffer import Buffer, BufferPrototype + from zarr.core.buffer import Buffer from zarr.core.chunk_key_encodings import ChunkKeyEncodingLike from zarr.core.common import MemoryOrder from zarr.core.dtype import ZDTypeLike @@ -356,20 +356,25 @@ class GroupMetadata(Metadata): consolidated_metadata: ConsolidatedMetadata | None = None node_type: Literal["group"] = field(default="group", init=False) - def to_buffer_dict(self, prototype: BufferPrototype) -> dict[str, Buffer]: + def to_buffer_dict(self, prototype: BufferClassLike | None = None) -> dict[str, Buffer]: + """ + Convert the metadata document to a dict with string keys and `Buffer` values. + """ + buffer_cls = parse_bufferclasslike(prototype) + json_indent = config.get("json_indent") if self.zarr_format == 3: return { - ZARR_JSON: prototype.buffer.from_bytes( + ZARR_JSON: buffer_cls.from_bytes( json.dumps(self.to_dict(), indent=json_indent, allow_nan=True).encode() ) } else: items = { - ZGROUP_JSON: prototype.buffer.from_bytes( + ZGROUP_JSON: buffer_cls.from_bytes( json.dumps({"zarr_format": self.zarr_format}, indent=json_indent).encode() ), - ZATTRS_JSON: prototype.buffer.from_bytes( + ZATTRS_JSON: buffer_cls.from_bytes( json.dumps(self.attributes, indent=json_indent, allow_nan=True).encode() ), } @@ -396,7 +401,7 @@ def to_buffer_dict(self, prototype: BufferPrototype) -> dict[str, Buffer]: }, } - items[ZMETADATA_V2_JSON] = prototype.buffer.from_bytes( + items[ZMETADATA_V2_JSON] = buffer_cls.from_bytes( json.dumps( {"metadata": d, "zarr_consolidated_format": 1}, allow_nan=True ).encode() @@ -2029,7 +2034,7 @@ async def update_attributes_async(self, new_attributes: dict[str, Any]) -> Group new_metadata = replace(self.metadata, attributes=new_attributes) # Write new metadata - to_save = new_metadata.to_buffer_dict(default_buffer_prototype()) + to_save = new_metadata.to_buffer_dict() awaitables = [set_or_delete(self.store_path / key, value) for key, value in to_save.items()] await asyncio.gather(*awaitables) @@ -3615,9 +3620,7 @@ async def _read_metadata_v3(store: Store, path: str) -> ArrayV3Metadata | GroupM document stored at store_path.path / zarr.json. If no such document is found, raise a FileNotFoundError. """ - zarr_json_bytes = await store.get( - _join_paths([path, ZARR_JSON]), prototype=default_buffer_prototype() - ) + zarr_json_bytes = await store.get(_join_paths([path, ZARR_JSON])) if zarr_json_bytes is None: raise FileNotFoundError(path) else: @@ -3634,9 +3637,9 @@ async def _read_metadata_v2(store: Store, path: str) -> ArrayV2Metadata | GroupM # TODO: consider first fetching array metadata, and only fetching group metadata when we don't # find an array zarray_bytes, zgroup_bytes, zattrs_bytes = await asyncio.gather( - store.get(_join_paths([path, ZARRAY_JSON]), prototype=default_buffer_prototype()), - store.get(_join_paths([path, ZGROUP_JSON]), prototype=default_buffer_prototype()), - store.get(_join_paths([path, ZATTRS_JSON]), prototype=default_buffer_prototype()), + store.get(_join_paths([path, ZARRAY_JSON])), + store.get(_join_paths([path, ZGROUP_JSON])), + store.get(_join_paths([path, ZATTRS_JSON])), ) if zattrs_bytes is None: @@ -3850,7 +3853,7 @@ def _persist_metadata( Prepare to save a metadata document to storage, returning a tuple of coroutines that must be awaited. """ - to_save = metadata.to_buffer_dict(default_buffer_prototype()) + to_save = metadata.to_buffer_dict() return tuple( _set_return_key(store=store, key=_join_paths([path, key]), value=value, semaphore=semaphore) for key, value in to_save.items() diff --git a/src/zarr/core/metadata/io.py b/src/zarr/core/metadata/io.py index 7b63f5493b..0a04063cd7 100644 --- a/src/zarr/core/metadata/io.py +++ b/src/zarr/core/metadata/io.py @@ -4,7 +4,6 @@ from typing import TYPE_CHECKING from zarr.abc.store import set_or_delete -from zarr.core.buffer.core import default_buffer_prototype from zarr.errors import ContainsArrayError from zarr.storage._common import StorePath, ensure_no_existing_node @@ -51,7 +50,7 @@ async def save_metadata( ------ ValueError """ - to_save = metadata.to_buffer_dict(default_buffer_prototype()) + to_save = metadata.to_buffer_dict() set_awaitables = [set_or_delete(store_path / key, value) for key, value in to_save.items()] if ensure_parents: @@ -71,9 +70,7 @@ async def save_metadata( set_awaitables.extend( [ (parent_store_path / key).set_if_not_exists(value) - for key, value in parent_metadata.to_buffer_dict( - default_buffer_prototype() - ).items() + for key, value in parent_metadata.to_buffer_dict().items() ] ) diff --git a/src/zarr/core/metadata/v2.py b/src/zarr/core/metadata/v2.py index 3204543426..153305df32 100644 --- a/src/zarr/core/metadata/v2.py +++ b/src/zarr/core/metadata/v2.py @@ -18,6 +18,7 @@ import numpy.typing as npt + from zarr.abc.store import BufferClassLike from zarr.core.buffer import Buffer, BufferPrototype from zarr.core.dtype.wrapper import ( TBaseDType, @@ -39,6 +40,7 @@ ZARRAY_JSON, ZATTRS_JSON, MemoryOrder, + parse_bufferclasslike, parse_shapelike, ) from zarr.core.config import config, parse_indexing_order @@ -125,15 +127,16 @@ def chunk_grid(self) -> RegularChunkGrid: def shards(self) -> tuple[int, ...] | None: return None - def to_buffer_dict(self, prototype: BufferPrototype) -> dict[str, Buffer]: + def to_buffer_dict(self, prototype: BufferClassLike | None = None) -> dict[str, Buffer]: + buffer_cls = parse_bufferclasslike(prototype) zarray_dict = self.to_dict() zattrs_dict = zarray_dict.pop("attributes", {}) json_indent = config.get("json_indent") return { - ZARRAY_JSON: prototype.buffer.from_bytes( + ZARRAY_JSON: buffer_cls.from_bytes( json.dumps(zarray_dict, indent=json_indent, allow_nan=True).encode() ), - ZATTRS_JSON: prototype.buffer.from_bytes( + ZATTRS_JSON: buffer_cls.from_bytes( json.dumps(zattrs_dict, indent=json_indent, allow_nan=True).encode() ), } diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py index 5ce155bd9a..3d4957c17f 100644 --- a/src/zarr/core/metadata/v3.py +++ b/src/zarr/core/metadata/v3.py @@ -11,6 +11,7 @@ if TYPE_CHECKING: from typing import Self + from zarr.abc.store import BufferClassLike from zarr.core.buffer import Buffer, BufferPrototype from zarr.core.chunk_grids import ChunkGrid from zarr.core.common import JSON @@ -35,6 +36,7 @@ ZARR_JSON, DimensionNames, NamedConfig, + parse_bufferclasslike, parse_named_configuration, parse_shapelike, ) @@ -345,11 +347,12 @@ def get_chunk_spec( def encode_chunk_key(self, chunk_coords: tuple[int, ...]) -> str: return self.chunk_key_encoding.encode_chunk_key(chunk_coords) - def to_buffer_dict(self, prototype: BufferPrototype) -> dict[str, Buffer]: + def to_buffer_dict(self, prototype: BufferClassLike | None = None) -> dict[str, Buffer]: + buffer_cls = parse_bufferclasslike(prototype) json_indent = config.get("json_indent") d = self.to_dict() return { - ZARR_JSON: prototype.buffer.from_bytes( + ZARR_JSON: buffer_cls.from_bytes( json.dumps(d, allow_nan=True, indent=json_indent).encode() ) } diff --git a/src/zarr/experimental/cache_store.py b/src/zarr/experimental/cache_store.py index e696e0eb0f..e55cef16b2 100644 --- a/src/zarr/experimental/cache_store.py +++ b/src/zarr/experimental/cache_store.py @@ -6,7 +6,7 @@ from collections import OrderedDict from typing import TYPE_CHECKING, Any, Literal -from zarr.abc.store import BufferLike, ByteRequest, Store +from zarr.abc.store import BufferClassLike, ByteRequest, Store from zarr.storage._wrapper import WrapperStore logger = logging.getLogger(__name__) @@ -218,7 +218,7 @@ def _remove_from_tracking(self, key: str) -> None: self._key_sizes.pop(key, None) async def _get_try_cache( - self, key: str, prototype: BufferLike | None, byte_range: ByteRequest | None = None + self, key: str, prototype: BufferClassLike | None, byte_range: ByteRequest | None = None ) -> Buffer | None: """Try to get data from cache first, falling back to source store.""" maybe_cached_result = await self._cache.get(key, prototype, byte_range) @@ -246,7 +246,10 @@ async def _get_try_cache( return maybe_fresh_result async def _get_no_cache( - self, key: str, prototype: BufferLike | None, byte_range: ByteRequest | None = None + self, + key: str, + prototype: BufferClassLike | None = None, + byte_range: ByteRequest | None = None, ) -> Buffer | None: """Get data directly from source store and update cache.""" self._misses += 1 @@ -265,7 +268,7 @@ async def _get_no_cache( async def get( self, key: str, - prototype: BufferLike | None = None, + prototype: BufferClassLike | None = None, byte_range: ByteRequest | None = None, ) -> Buffer | None: """ diff --git a/src/zarr/storage/_common.py b/src/zarr/storage/_common.py index e381c65839..ddb6a38bc7 100644 --- a/src/zarr/storage/_common.py +++ b/src/zarr/storage/_common.py @@ -5,7 +5,7 @@ from pathlib import Path from typing import Any, Literal, Self, TypeAlias -from zarr.abc.store import BufferLike, ByteRequest, Store +from zarr.abc.store import BufferClassLike, ByteRequest, Store from zarr.core.buffer import Buffer from zarr.core.common import ( ANY_ACCESS_MODE, @@ -142,7 +142,7 @@ async def open(cls, store: Store, path: str, mode: AccessModeLiteral | None = No async def get( self, - prototype: BufferLike | None = None, + prototype: BufferClassLike | None = None, byte_range: ByteRequest | None = None, ) -> Buffer | None: """ diff --git a/src/zarr/storage/_fsspec.py b/src/zarr/storage/_fsspec.py index b16712c786..faeacacd97 100644 --- a/src/zarr/storage/_fsspec.py +++ b/src/zarr/storage/_fsspec.py @@ -8,14 +8,15 @@ from packaging.version import parse as parse_version from zarr.abc.store import ( - BufferLike, + BufferClassLike, ByteRequest, OffsetByteRequest, RangeByteRequest, Store, SuffixByteRequest, ) -from zarr.core.buffer import Buffer, BufferPrototype +from zarr.core.buffer import Buffer +from zarr.core.common import parse_bufferclasslike from zarr.errors import ZarrUserWarning from zarr.storage._common import _dereference_path @@ -275,19 +276,16 @@ def __eq__(self, other: object) -> bool: async def get( self, key: str, - prototype: BufferLike | None = None, + prototype: BufferClassLike | None = None, byte_range: ByteRequest | None = None, ) -> Buffer | None: # docstring inherited if not self._is_open: await self._open() if prototype is None: - prototype = self._get_default_buffer_class() - # Extract buffer class from BufferLike - if isinstance(prototype, BufferPrototype): - buffer_cls = prototype.buffer + buffer_cls = self._get_default_buffer_class() else: - buffer_cls = prototype + buffer_cls = parse_bufferclasslike(prototype) path = _dereference_path(self.path, key) @@ -374,17 +372,14 @@ async def exists(self, key: str) -> bool: async def get_partial_values( self, - prototype: BufferLike | None, + prototype: BufferClassLike | None, key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited if prototype is None: - prototype = self._get_default_buffer_class() - # Extract buffer class from BufferLike - if isinstance(prototype, BufferPrototype): - buffer_cls = prototype.buffer + buffer_cls = self._get_default_buffer_class() else: - buffer_cls = prototype + buffer_cls = parse_bufferclasslike(prototype) if key_ranges: # _cat_ranges expects a list of paths, start, and end ranges, so we need to reformat each ByteRequest. diff --git a/src/zarr/storage/_local.py b/src/zarr/storage/_local.py index 99750ae3b5..b26f00f8ed 100644 --- a/src/zarr/storage/_local.py +++ b/src/zarr/storage/_local.py @@ -11,7 +11,7 @@ from typing import TYPE_CHECKING, Any, BinaryIO, Literal, Self from zarr.abc.store import ( - BufferLike, + BufferClassLike, ByteRequest, OffsetByteRequest, RangeByteRequest, @@ -25,7 +25,7 @@ from collections.abc import AsyncIterator, Iterable, Iterator -def _get(path: Path, prototype: BufferLike, byte_range: ByteRequest | None) -> Buffer: +def _get(path: Path, prototype: BufferClassLike, byte_range: ByteRequest | None) -> Buffer: # Extract buffer class from BufferLike if isinstance(prototype, BufferPrototype): buffer_cls = prototype.buffer @@ -194,7 +194,7 @@ def __eq__(self, other: object) -> bool: async def get( self, key: str, - prototype: BufferLike | None = None, + prototype: BufferClassLike | None = None, byte_range: ByteRequest | None = None, ) -> Buffer | None: # docstring inherited @@ -212,7 +212,7 @@ async def get( async def get_partial_values( self, - prototype: BufferLike | None, + prototype: BufferClassLike | None, key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited @@ -316,7 +316,7 @@ async def _get_bytes( self, key: str = "", *, - prototype: BufferLike | None = None, + prototype: BufferClassLike | None = None, byte_range: ByteRequest | None = None, ) -> bytes: """ @@ -330,8 +330,8 @@ async def _get_bytes( ---------- key : str, optional The key identifying the data to retrieve. Defaults to an empty string. - prototype : BufferPrototype, optional - The buffer prototype to use for reading the data. If None, uses + prototype : BufferClassLike, optional + A specification of the buffer class to use for reading the data. If None, uses ``default_buffer_prototype()``. byte_range : ByteRequest, optional If specified, only retrieve a portion of the stored data. @@ -368,7 +368,7 @@ def _get_bytes_sync( self, key: str = "", *, - prototype: BufferLike | None = None, + prototype: BufferClassLike | None = None, byte_range: ByteRequest | None = None, ) -> bytes: """ @@ -382,8 +382,8 @@ def _get_bytes_sync( ---------- key : str, optional The key identifying the data to retrieve. Defaults to an empty string. - prototype : BufferPrototype, optional - The buffer prototype to use for reading the data. If None, uses + prototype : BufferClassLike, optional + A specification of the buffer class to use for reading the data. If None, uses ``default_buffer_prototype()``. byte_range : ByteRequest, optional If specified, only retrieve a portion of the stored data. @@ -424,7 +424,7 @@ async def _get_json( self, key: str = "", *, - prototype: BufferLike | None = None, + prototype: BufferClassLike | None = None, byte_range: ByteRequest | None = None, ) -> Any: """ @@ -438,8 +438,8 @@ async def _get_json( ---------- key : str, optional The key identifying the JSON data to retrieve. Defaults to an empty string. - prototype : BufferPrototype, optional - The buffer prototype to use for reading the data. If None, uses + prototype : BufferClassLike, optional + A specification of the buffer class to use for reading the data. If None, uses ``default_buffer_prototype()``. byte_range : ByteRequest, optional If specified, only retrieve a portion of the stored data. @@ -483,7 +483,7 @@ def _get_json_sync( self, key: str = "", *, - prototype: BufferLike | None = None, + prototype: BufferClassLike | None = None, byte_range: ByteRequest | None = None, ) -> Any: """ @@ -497,8 +497,8 @@ def _get_json_sync( ---------- key : str, optional The key identifying the JSON data to retrieve. Defaults to an empty string. - prototype : BufferPrototype, optional - The buffer prototype to use for reading the data. If None, uses + prototype : BufferClassLike, optional + A specification of the buffer class to use for reading the data. If None, uses ``default_buffer_prototype()``. byte_range : ByteRequest, optional If specified, only retrieve a portion of the stored data. diff --git a/src/zarr/storage/_logging.py b/src/zarr/storage/_logging.py index 7d82dac948..b3b5030cd0 100644 --- a/src/zarr/storage/_logging.py +++ b/src/zarr/storage/_logging.py @@ -8,7 +8,7 @@ from contextlib import contextmanager from typing import TYPE_CHECKING, Any, Self, TypeVar -from zarr.abc.store import BufferLike, Store +from zarr.abc.store import BufferClassLike, Store from zarr.storage._wrapper import WrapperStore if TYPE_CHECKING: @@ -165,7 +165,7 @@ def __eq__(self, other: object) -> bool: async def get( self, key: str, - prototype: BufferLike | None = None, + prototype: BufferClassLike | None = None, byte_range: ByteRequest | None = None, ) -> Buffer | None: # docstring inherited @@ -174,7 +174,7 @@ async def get( async def get_partial_values( self, - prototype: BufferLike | None, + prototype: BufferClassLike | None, key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited diff --git a/src/zarr/storage/_memory.py b/src/zarr/storage/_memory.py index aa00f664c9..e21c9d7c42 100644 --- a/src/zarr/storage/_memory.py +++ b/src/zarr/storage/_memory.py @@ -3,7 +3,7 @@ from logging import getLogger from typing import TYPE_CHECKING, Any, Self -from zarr.abc.store import BufferLike, ByteRequest, Store +from zarr.abc.store import BufferClassLike, ByteRequest, Store from zarr.core.buffer import Buffer, BufferPrototype, gpu from zarr.core.common import concurrent_map from zarr.storage._utils import _normalize_byte_range_index @@ -77,7 +77,7 @@ def __eq__(self, other: object) -> bool: async def get( self, key: str, - prototype: BufferLike | None = None, + prototype: BufferClassLike | None = None, byte_range: ByteRequest | None = None, ) -> Buffer | None: # docstring inherited @@ -100,7 +100,7 @@ async def get( async def get_partial_values( self, - prototype: BufferLike | None, + prototype: BufferClassLike | None, key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited @@ -181,7 +181,7 @@ async def _get_bytes( self, key: str = "", *, - prototype: BufferLike | None = None, + prototype: BufferClassLike | None = None, byte_range: ByteRequest | None = None, ) -> bytes: """ @@ -195,8 +195,8 @@ async def _get_bytes( ---------- key : str, optional The key identifying the data to retrieve. Defaults to an empty string. - prototype : BufferPrototype, optional - The buffer prototype to use for reading the data. If None, uses + prototype : BufferClassLike, optional + A specification of the buffer class to use for reading the data. If None, uses ``default_buffer_prototype()``. byte_range : ByteRequest, optional If specified, only retrieve a portion of the stored data. @@ -233,7 +233,7 @@ def _get_bytes_sync( self, key: str = "", *, - prototype: BufferLike | None = None, + prototype: BufferClassLike | None = None, byte_range: ByteRequest | None = None, ) -> bytes: """ @@ -247,8 +247,8 @@ def _get_bytes_sync( ---------- key : str, optional The key identifying the data to retrieve. Defaults to an empty string. - prototype : BufferPrototype, optional - The buffer prototype to use for reading the data. If None, uses + prototype : BufferClassLike, optional + A specification of the buffer class to use for reading the data. If None, uses ``default_buffer_prototype()``. byte_range : ByteRequest, optional If specified, only retrieve a portion of the stored data. @@ -289,7 +289,7 @@ async def _get_json( self, key: str = "", *, - prototype: BufferLike | None = None, + prototype: BufferClassLike | None = None, byte_range: ByteRequest | None = None, ) -> Any: """ @@ -303,8 +303,8 @@ async def _get_json( ---------- key : str, optional The key identifying the JSON data to retrieve. Defaults to an empty string. - prototype : BufferPrototype, optional - The buffer prototype to use for reading the data. If None, uses + prototype : BufferClassLike, optional + A specification of the buffer class to use for reading the data. If None, uses ``default_buffer_prototype()``. byte_range : ByteRequest, optional If specified, only retrieve a portion of the stored data. @@ -348,7 +348,7 @@ def _get_json_sync( self, key: str = "", *, - prototype: BufferLike | None = None, + prototype: BufferClassLike | None = None, byte_range: ByteRequest | None = None, ) -> Any: """ @@ -362,8 +362,8 @@ def _get_json_sync( ---------- key : str, optional The key identifying the JSON data to retrieve. Defaults to an empty string. - prototype : BufferPrototype, optional - The buffer prototype to use for reading the data. If None, uses + prototype : BufferClassLike, optional + A specification of the buffer class to use for reading the data. If None, uses ``default_buffer_prototype()``. byte_range : ByteRequest, optional If specified, only retrieve a portion of the stored data. diff --git a/src/zarr/storage/_obstore.py b/src/zarr/storage/_obstore.py index aff000afe9..f836bb58a1 100644 --- a/src/zarr/storage/_obstore.py +++ b/src/zarr/storage/_obstore.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Generic, Self, TypedDict, TypeVar from zarr.abc.store import ( - BufferLike, + BufferClassLike, ByteRequest, OffsetByteRequest, RangeByteRequest, @@ -98,7 +98,10 @@ def __setstate__(self, state: dict[Any, Any]) -> None: self.__dict__.update(state) async def get( - self, key: str, prototype: BufferLike | None = None, byte_range: ByteRequest | None = None + self, + key: str, + prototype: BufferClassLike | None = None, + byte_range: ByteRequest | None = None, ) -> Buffer | None: # docstring inherited import obstore as obs @@ -152,7 +155,7 @@ async def get( async def get_partial_values( self, - prototype: BufferLike | None, + prototype: BufferClassLike | None, key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited diff --git a/src/zarr/storage/_wrapper.py b/src/zarr/storage/_wrapper.py index ca3609009e..13b75edc2f 100644 --- a/src/zarr/storage/_wrapper.py +++ b/src/zarr/storage/_wrapper.py @@ -10,7 +10,7 @@ from zarr.abc.buffer import Buffer from zarr.abc.store import ByteRequest -from zarr.abc.store import BufferLike, Store +from zarr.abc.store import BufferClassLike, Store T_Store = TypeVar("T_Store", bound=Store) @@ -84,13 +84,16 @@ def __repr__(self) -> str: return f"WrapperStore({self._store.__class__.__name__}, '{self._store}')" async def get( - self, key: str, prototype: BufferLike | None = None, byte_range: ByteRequest | None = None + self, + key: str, + prototype: BufferClassLike | None = None, + byte_range: ByteRequest | None = None, ) -> Buffer | None: return await self._store.get(key, prototype, byte_range) async def get_partial_values( self, - prototype: BufferLike | None, + prototype: BufferClassLike | None, key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: return await self._store.get_partial_values(prototype, key_ranges) @@ -138,7 +141,7 @@ def close(self) -> None: self._store.close() async def _get_many( - self, requests: Iterable[tuple[str, BufferLike | None, ByteRequest | None]] + self, requests: Iterable[tuple[str, BufferClassLike | None, ByteRequest | None]] ) -> AsyncGenerator[tuple[str, Buffer | None], None]: async for req in self._store._get_many(requests): yield req diff --git a/src/zarr/storage/_zip.py b/src/zarr/storage/_zip.py index 0348eeedd8..eb3b33f514 100644 --- a/src/zarr/storage/_zip.py +++ b/src/zarr/storage/_zip.py @@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, Literal from zarr.abc.store import ( - BufferLike, + BufferClassLike, ByteRequest, OffsetByteRequest, RangeByteRequest, @@ -147,7 +147,7 @@ def __eq__(self, other: object) -> bool: def _get( self, key: str, - prototype: BufferLike, + prototype: BufferClassLike, byte_range: ByteRequest | None = None, ) -> Buffer | None: if not self._is_open: @@ -179,7 +179,7 @@ def _get( async def get( self, key: str, - prototype: BufferLike | None = None, + prototype: BufferClassLike | None = None, byte_range: ByteRequest | None = None, ) -> Buffer | None: # docstring inherited @@ -192,7 +192,7 @@ async def get( async def get_partial_values( self, - prototype: BufferLike | None, + prototype: BufferClassLike | None, key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index 9337f86e11..bc6a8e09e6 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -16,7 +16,7 @@ import pytest from zarr.abc.store import ( - BufferLike, + BufferClassLike, ByteRequest, OffsetByteRequest, RangeByteRequest, @@ -108,7 +108,7 @@ async def test_serializable_store(self, store: S) -> None: data_buf = self.buffer_cls.from_bytes(b"\x01\x02\x03\x04") key = "foo" await store.set(key, data_buf) - observed = await store.get(key, prototype=default_buffer_prototype()) + observed = await store.get(key) assert_bytes_equal(observed, data_buf) def test_store_read_only(self, store: S) -> None: @@ -223,7 +223,12 @@ async def test_with_read_only_store(self, open_kwargs: dict[str, Any]) -> None: ], ) async def test_get( - self, store: S, key: str, data: bytes, byte_range: ByteRequest, prototype: BufferLike | None + self, + store: S, + key: str, + data: bytes, + byte_range: ByteRequest, + prototype: BufferClassLike | None, ) -> None: """ Ensure that data can be read from the store using the store.get method. @@ -243,7 +248,7 @@ async def test_get_not_open(self, store_not_open: S) -> None: data_buf = self.buffer_cls.from_bytes(b"\x01\x02\x03\x04") key = "c/0" await self.set(store_not_open, key, data_buf) - observed = await store_not_open.get(key, prototype=default_buffer_prototype()) + observed = await store_not_open.get(key) assert_bytes_equal(observed, data_buf) async def test_get_raises(self, store: S) -> None: @@ -267,7 +272,7 @@ async def test_get_many(self, store: S) -> None: store._get_many( zip( keys, - (default_buffer_prototype(),) * len(keys), + (None,) * len(keys), (None,) * len(keys), strict=False, ) @@ -366,7 +371,7 @@ async def test_set_many(self, store: S) -> None: ], ) async def test_get_partial_values( - self, store: S, key_ranges: list[tuple[str, ByteRequest]], prototype: BufferLike | None + self, store: S, key_ranges: list[tuple[str, ByteRequest]], prototype: BufferClassLike | None ) -> None: # put all of the data for key, _ in key_ranges: @@ -535,12 +540,12 @@ async def test_set_if_not_exists(self, store: S) -> None: new = self.buffer_cls.from_bytes(b"1111") await store.set_if_not_exists("k", new) # no error - result = await store.get(key, default_buffer_prototype()) + result = await store.get(key) assert result == data_buf await store.set_if_not_exists("k2", new) # no error - result = await store.get("k2", default_buffer_prototype()) + result = await store.get("k2") assert result == new async def test_get_bytes(self, store: S) -> None: @@ -550,9 +555,9 @@ async def test_get_bytes(self, store: S) -> None: data = b"hello world" key = "zarr.json" await self.set(store, key, self.buffer_cls.from_bytes(data)) - assert await store._get_bytes(key, prototype=default_buffer_prototype()) == data + assert await store._get_bytes(key) == data with pytest.raises(FileNotFoundError): - await store._get_bytes("nonexistent_key", prototype=default_buffer_prototype()) + await store._get_bytes("nonexistent_key") def test_get_bytes_sync(self, store: S) -> None: """ @@ -561,7 +566,7 @@ def test_get_bytes_sync(self, store: S) -> None: data = b"hello world" key = "zarr.json" sync(self.set(store, key, self.buffer_cls.from_bytes(data))) - assert store._get_bytes_sync(key, prototype=default_buffer_prototype()) == data + assert store._get_bytes_sync(key) == data async def test_get_json(self, store: S) -> None: """ @@ -571,7 +576,7 @@ async def test_get_json(self, store: S) -> None: data_bytes = json.dumps(data).encode("utf-8") key = "zarr.json" await self.set(store, key, self.buffer_cls.from_bytes(data_bytes)) - assert await store._get_json(key, prototype=default_buffer_prototype()) == data + assert await store._get_json(key) == data def test_get_json_sync(self, store: S) -> None: """ @@ -581,7 +586,7 @@ def test_get_json_sync(self, store: S) -> None: data_bytes = json.dumps(data).encode("utf-8") key = "zarr.json" sync(self.set(store, key, self.buffer_cls.from_bytes(data_bytes))) - assert store._get_json_sync(key, prototype=default_buffer_prototype()) == data + assert store._get_json_sync(key) == data class LatencyStore(WrapperStore[Store]): @@ -620,7 +625,10 @@ async def set(self, key: str, value: Buffer) -> None: await self._store.set(key, value) async def get( - self, key: str, prototype: BufferLike | None = None, byte_range: ByteRequest | None = None + self, + key: str, + prototype: BufferClassLike | None = None, + byte_range: ByteRequest | None = None, ) -> Buffer | None: """ Add latency to the ``get`` method. diff --git a/tests/test_api/test_asynchronous.py b/tests/test_api/test_asynchronous.py index 362195e858..e8035493ae 100644 --- a/tests/test_api/test_asynchronous.py +++ b/tests/test_api/test_asynchronous.py @@ -9,7 +9,7 @@ from zarr import create_array from zarr.api.asynchronous import _get_shape_chunks, _like_args, group, open -from zarr.core.buffer.core import default_buffer_prototype +from zarr.buffer import cpu from zarr.core.group import AsyncGroup if TYPE_CHECKING: @@ -101,7 +101,7 @@ async def test_open_no_array() -> None: This behavior makes no sense but we should still test it. """ store = { - "zarr.json": default_buffer_prototype().buffer.from_bytes( + "zarr.json": cpu.Buffer.from_bytes( json.dumps({"zarr_format": 3, "node_type": "group"}).encode("utf-8") ) } diff --git a/tests/test_array.py b/tests/test_array.py index 67be294827..7b17e63610 100644 --- a/tests/test_array.py +++ b/tests/test_array.py @@ -941,7 +941,7 @@ def test_write_empty_chunks_negative_zero( async def test_special_complex_fill_values_roundtrip(fill_value: Any, expected: list[Any]) -> None: store = MemoryStore() zarr.create_array(store=store, shape=(1,), dtype=np.complex64, fill_value=fill_value) - content = await store.get("zarr.json", prototype=default_buffer_prototype()) + content = await store.get("zarr.json") assert content is not None actual = json.loads(content.to_bytes()) assert actual["fill_value"] == expected diff --git a/tests/test_codecs/test_blosc.py b/tests/test_codecs/test_blosc.py index 6f4821f8b1..639e5c2ca5 100644 --- a/tests/test_codecs/test_blosc.py +++ b/tests/test_codecs/test_blosc.py @@ -28,7 +28,7 @@ async def test_blosc_evolve(dtype: str) -> None: fill_value=0, compressors=BloscCodec(), ) - buf = await store.get(f"{path}/zarr.json", prototype=default_buffer_prototype()) + buf = await store.get(f"{path}/zarr.json") assert buf is not None zarr_json = json.loads(buf.to_bytes()) blosc_configuration_json = zarr_json["codecs"][1]["configuration"] @@ -49,7 +49,7 @@ async def test_blosc_evolve(dtype: str) -> None: fill_value=0, compressors=BloscCodec(), ) - buf = await store.get(f"{path2}/zarr.json", prototype=default_buffer_prototype()) + buf = await store.get(f"{path2}/zarr.json") assert buf is not None zarr_json = json.loads(buf.to_bytes()) blosc_configuration_json = zarr_json["codecs"][0]["configuration"]["codecs"][1]["configuration"] @@ -99,7 +99,7 @@ async def test_typesize() -> None: a = np.arange(1000000, dtype=np.uint64) codecs = [zarr.codecs.BytesCodec(), zarr.codecs.BloscCodec()] z = zarr.array(a, chunks=(10000), codecs=codecs) - data = await z.store.get("c/0", prototype=default_buffer_prototype()) + data = await z.store.get("c/0") assert data is not None bytes = data.to_bytes() size = len(bytes) diff --git a/tests/test_codecs/test_codecs.py b/tests/test_codecs/test_codecs.py index eae7168d49..3994a8a55e 100644 --- a/tests/test_codecs/test_codecs.py +++ b/tests/test_codecs/test_codecs.py @@ -17,7 +17,6 @@ ShardingCodec, TransposeCodec, ) -from zarr.core.buffer import default_buffer_prototype from zarr.core.indexing import BasicSelection, morton_order_iter from zarr.core.metadata.v3 import ArrayV3Metadata from zarr.dtype import UInt8 @@ -253,7 +252,7 @@ async def test_delete_empty_chunks(store: Store) -> None: await _AsyncArrayProxy(a)[:16, :16].set(np.zeros((16, 16))) await _AsyncArrayProxy(a)[:16, :16].set(data) assert np.array_equal(await _AsyncArrayProxy(a)[:16, :16].get(), data) - assert await store.get(f"{path}/c0/0", prototype=default_buffer_prototype()) is None + assert await store.get(f"{path}/c0/0") is None @pytest.mark.parametrize("store", ["local", "memory"], indirect=["store"]) @@ -289,7 +288,7 @@ async def test_dimension_names(store: Store) -> None: assert isinstance(meta := (await AsyncArray.open(spath2)).metadata, ArrayV3Metadata) assert meta.dimension_names is None - zarr_json_buffer = await store.get(f"{path2}/zarr.json", prototype=default_buffer_prototype()) + zarr_json_buffer = await store.get(f"{path2}/zarr.json") assert zarr_json_buffer is not None assert "dimension_names" not in json.loads(zarr_json_buffer.to_bytes()) @@ -351,15 +350,15 @@ async def test_resize(store: Store) -> None: ) await _AsyncArrayProxy(a)[:16, :18].set(data) - assert await store.get(f"{path}/1.1", prototype=default_buffer_prototype()) is not None - assert await store.get(f"{path}/0.0", prototype=default_buffer_prototype()) is not None - assert await store.get(f"{path}/0.1", prototype=default_buffer_prototype()) is not None - assert await store.get(f"{path}/1.0", prototype=default_buffer_prototype()) is not None + assert await store.get(f"{path}/1.1") is not None + assert await store.get(f"{path}/0.0") is not None + assert await store.get(f"{path}/0.1") is not None + assert await store.get(f"{path}/1.0") is not None await a.resize((10, 12)) assert a.metadata.shape == (10, 12) assert a.shape == (10, 12) - assert await store.get(f"{path}/0.0", prototype=default_buffer_prototype()) is not None - assert await store.get(f"{path}/0.1", prototype=default_buffer_prototype()) is not None - assert await store.get(f"{path}/1.0", prototype=default_buffer_prototype()) is None - assert await store.get(f"{path}/1.1", prototype=default_buffer_prototype()) is None + assert await store.get(f"{path}/0.0") is not None + assert await store.get(f"{path}/0.1") is not None + assert await store.get(f"{path}/1.0") is None + assert await store.get(f"{path}/1.1") is None diff --git a/tests/test_codecs/test_sharding.py b/tests/test_codecs/test_sharding.py index 7eb4deccbf..11fae5fbbc 100644 --- a/tests/test_codecs/test_sharding.py +++ b/tests/test_codecs/test_sharding.py @@ -17,7 +17,7 @@ ShardingCodecIndexLocation, TransposeCodec, ) -from zarr.core.buffer import NDArrayLike, default_buffer_prototype +from zarr.core.buffer import NDArrayLike from zarr.errors import ZarrUserWarning from zarr.storage import StorePath, ZipStore @@ -398,8 +398,8 @@ async def test_delete_empty_shards(store: Store) -> None: data = np.ones((16, 16), dtype="uint16") data[:8, :8] = 0 assert np.array_equal(data, await _AsyncArrayProxy(a)[:, :].get()) - assert await store.get(f"{path}/c/1/0", prototype=default_buffer_prototype()) is None - chunk_bytes = await store.get(f"{path}/c/0/0", prototype=default_buffer_prototype()) + assert await store.get(f"{path}/c/1/0") is None + chunk_bytes = await store.get(f"{path}/c/0/0") assert chunk_bytes is not None assert len(chunk_bytes) == 16 * 2 + 8 * 8 * 2 + 4 diff --git a/tests/test_experimental/test_cache_store.py b/tests/test_experimental/test_cache_store.py index d4a45f78f1..26a0a9931e 100644 --- a/tests/test_experimental/test_cache_store.py +++ b/tests/test_experimental/test_cache_store.py @@ -8,7 +8,6 @@ import pytest from zarr.abc.store import Store -from zarr.core.buffer.core import default_buffer_prototype from zarr.core.buffer.cpu import Buffer as CPUBuffer from zarr.experimental.cache_store import CacheStore from zarr.storage import MemoryStore @@ -43,7 +42,7 @@ async def test_basic_caching(self, cached_store: CacheStore, source_store: Store assert await cached_store._cache.exists("test_key") # Retrieve and verify caching works - result = await cached_store.get("test_key", default_buffer_prototype()) + result = await cached_store.get("test_key") assert result is not None assert result.to_bytes() == b"test data" @@ -56,7 +55,7 @@ async def test_cache_miss_and_population( await source_store.set("source_key", test_data) # First access should miss cache but populate it - result = await cached_store.get("source_key", default_buffer_prototype()) + result = await cached_store.get("source_key") assert result is not None assert result.to_bytes() == b"source data" @@ -91,7 +90,7 @@ async def test_cache_expiration(self) -> None: # Skip freshness check if method doesn't exist await asyncio.sleep(1.1) # Just verify the data is still accessible - result = await cached_store.get("expire_key", default_buffer_prototype()) + result = await cached_store.get("expire_key") assert result is not None async def test_cache_set_data_false(self, source_store: Store, cache_store: Store) -> None: @@ -170,7 +169,7 @@ async def test_stale_cache_refresh(self) -> None: await source_store.set("refresh_key", new_data) # Access should refresh from source when cache is stale - result = await cached_store.get("refresh_key", default_buffer_prototype()) + result = await cached_store.get("refresh_key") assert result is not None assert result.to_bytes() == b"new data" @@ -204,7 +203,7 @@ async def test_cache_returns_cached_data_for_performance( cached_store.key_insert_times["orphan_key"] = time.monotonic() # Cache should return data for performance (no source verification) - result = await cached_store.get("orphan_key", default_buffer_prototype()) + result = await cached_store.get("orphan_key") assert result is not None assert result.to_bytes() == b"orphaned data" @@ -230,7 +229,7 @@ async def test_cache_coherency_through_expiration(self) -> None: await source_store.delete("coherency_key") # Cache should still return cached data (performance optimization) - result = await cached_store.get("coherency_key", default_buffer_prototype()) + result = await cached_store.get("coherency_key") assert result is not None assert result.to_bytes() == b"original data" @@ -238,7 +237,7 @@ async def test_cache_coherency_through_expiration(self) -> None: await asyncio.sleep(1.1) # Now stale cache should be refreshed from source - result = await cached_store.get("coherency_key", default_buffer_prototype()) + result = await cached_store.get("coherency_key") assert result is None # Key no longer exists in source async def test_cache_info(self, cached_store: CacheStore) -> None: @@ -447,7 +446,7 @@ async def test_get_nonexistent_key(self) -> None: cached_store = CacheStore(source_store, cache_store=cache_store) # Try to get nonexistent key - result = await cached_store.get("nonexistent", default_buffer_prototype()) + result = await cached_store.get("nonexistent") assert result is None # Should not create any cache entries @@ -544,7 +543,7 @@ async def test_get_no_cache_delete_tracking(self) -> None: assert "phantom_key" in cached_store.key_insert_times # Now try to get it - since it's not in source, should clean up tracking - result = await cached_store._get_no_cache("phantom_key", default_buffer_prototype()) + result = await cached_store._get_no_cache("phantom_key") assert result is None # Should have cleaned up tracking @@ -626,7 +625,7 @@ async def test_concurrent_get_and_evict(self) -> None: # Concurrent: read key1 while adding key3 (triggers eviction) async def read_key() -> None: for _ in range(100): - await cached_store.get("key1", default_buffer_prototype()) + await cached_store.get("key1") async def write_key() -> None: for i in range(10): @@ -802,11 +801,11 @@ async def test_cache_stats_method(self) -> None: await source_store.set("key1", buffer) # First get is a miss (not in cache yet) - result1 = await cached_store.get("key1", default_buffer_prototype()) + result1 = await cached_store.get("key1") assert result1 is not None # Second get is a hit (now in cache) - result2 = await cached_store.get("key1", default_buffer_prototype()) + result2 = await cached_store.get("key1") assert result2 is not None stats = cached_store.cache_stats() diff --git a/tests/test_group.py b/tests/test_group.py index 6f1f4e68fa..471aa91336 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -20,9 +20,9 @@ import zarr.storage from zarr import Array, AsyncArray, AsyncGroup, Group from zarr.abc.store import Store +from zarr.buffer import cpu from zarr.core import sync_group from zarr.core._info import GroupInfo -from zarr.core.buffer import default_buffer_prototype from zarr.core.config import config as zarr_config from zarr.core.dtype.common import unpack_dtype_json from zarr.core.dtype.npy.int import UInt8 @@ -196,7 +196,7 @@ def test_group_members(store: Store, zarr_format: ZarrFormat, consolidated_metad sync( store.set( f"{path}/extra_object-1", - default_buffer_prototype().buffer.from_bytes(b"000000"), + cpu.Buffer.from_bytes(b"000000"), ) ) # add an extra object under a directory-like prefix in the domain of the group. @@ -205,7 +205,7 @@ def test_group_members(store: Store, zarr_format: ZarrFormat, consolidated_metad sync( store.set( f"{path}/extra_directory/extra_object-2", - default_buffer_prototype().buffer.from_bytes(b"000000"), + cpu.Buffer.from_bytes(b"000000"), ) ) @@ -1481,12 +1481,10 @@ def test_open_mutable_mapping_sync(): async def test_open_ambiguous_node(): - zarr_json_bytes = default_buffer_prototype().buffer.from_bytes( + zarr_json_bytes = cpu.Buffer.from_bytes( json.dumps({"zarr_format": 3, "node_type": "group"}).encode("utf-8") ) - zgroup_bytes = default_buffer_prototype().buffer.from_bytes( - json.dumps({"zarr_format": 2}).encode("utf-8") - ) + zgroup_bytes = cpu.Buffer.from_bytes(json.dumps({"zarr_format": 2}).encode("utf-8")) store: dict[str, Buffer] = {"zarr.json": zarr_json_bytes, ".zgroup": zgroup_bytes} with pytest.warns( ZarrUserWarning, diff --git a/tests/test_indexing.py b/tests/test_indexing.py index c0bf7dd270..631f364e32 100644 --- a/tests/test_indexing.py +++ b/tests/test_indexing.py @@ -12,7 +12,6 @@ import zarr from zarr import Array -from zarr.core.buffer import default_buffer_prototype from zarr.core.indexing import ( BasicSelection, CoordinateSelection, @@ -155,7 +154,7 @@ def test_get_basic_selection_0d(store: StorePath, use_out: bool, value: Any, dty if use_out: # test out param - b = default_buffer_prototype().nd_buffer.from_numpy_array(np.zeros_like(arr_np)) + b = get_ndbuffer_class().from_numpy_array(np.zeros_like(arr_np)) arr_z.get_basic_selection(Ellipsis, out=b) assert_array_equal(arr_np, b.as_ndarray_like()) @@ -268,9 +267,7 @@ def _test_get_basic_selection( assert_array_equal(expect, actual) # test out param - b = default_buffer_prototype().nd_buffer.from_numpy_array( - np.empty(shape=expect.shape, dtype=expect.dtype) - ) + b = get_ndbuffer_class().from_numpy_array(np.empty(shape=expect.shape, dtype=expect.dtype)) z.get_basic_selection(selection, out=b) assert_array_equal(expect, b.as_numpy_array()) diff --git a/tests/test_metadata/test_consolidated.py b/tests/test_metadata/test_consolidated.py index 9e8b763ef7..b375314521 100644 --- a/tests/test_metadata/test_consolidated.py +++ b/tests/test_metadata/test_consolidated.py @@ -17,7 +17,7 @@ open, open_consolidated, ) -from zarr.core.buffer import cpu, default_buffer_prototype +from zarr.core.buffer import cpu from zarr.core.dtype import parse_dtype from zarr.core.group import ConsolidatedMetadata, GroupMetadata from zarr.core.metadata import ArrayV3Metadata @@ -192,9 +192,7 @@ async def test_consolidated(self, memory_store_with_hierarchy: Store) -> None: group4 = await open_consolidated(store=memory_store_with_hierarchy) assert group4.metadata == expected - buf = await memory_store_with_hierarchy.get( - "zarr.json", prototype=default_buffer_prototype() - ) + buf = await memory_store_with_hierarchy.get("zarr.json") assert buf is not None result_raw = json.loads(buf.to_bytes())["consolidated_metadata"] @@ -724,7 +722,7 @@ async def test_consolidated_metadata_encodes_special_chars( await zarr.api.asynchronous.consolidate_metadata(memory_store) root = await group(store=memory_store, zarr_format=zarr_format) - root_buffer = root.metadata.to_buffer_dict(default_buffer_prototype()) + root_buffer = root.metadata.to_buffer_dict(cpu.Buffer) if zarr_format == 2: root_metadata = json.loads(root_buffer[".zmetadata"].to_bytes().decode("utf-8"))["metadata"] diff --git a/tests/test_metadata/test_v2.py b/tests/test_metadata/test_v2.py index 8c3082e924..69fb8298ef 100644 --- a/tests/test_metadata/test_v2.py +++ b/tests/test_metadata/test_v2.py @@ -9,7 +9,6 @@ import zarr.api.asynchronous import zarr.storage from zarr.core.buffer import cpu -from zarr.core.buffer.core import default_buffer_prototype from zarr.core.dtype.npy.float import Float32, Float64 from zarr.core.dtype.npy.int import Int16 from zarr.core.group import ConsolidatedMetadata, GroupMetadata @@ -318,9 +317,7 @@ def test_zstd_checksum() -> None: compressors=compressor_config, zarr_format=2, ) - metadata = json.loads( - arr.metadata.to_buffer_dict(default_buffer_prototype())[".zarray"].to_bytes() - ) + metadata = json.loads(arr.metadata.to_buffer_dict()[".zarray"].to_bytes()) assert "checksum" not in metadata["compressor"] diff --git a/tests/test_metadata/test_v3.py b/tests/test_metadata/test_v3.py index 01ed921053..f0fd34accb 100644 --- a/tests/test_metadata/test_v3.py +++ b/tests/test_metadata/test_v3.py @@ -9,7 +9,6 @@ from zarr import consolidate_metadata, create_group from zarr.codecs.bytes import BytesCodec -from zarr.core.buffer import default_buffer_prototype from zarr.core.chunk_key_encodings import DefaultChunkKeyEncoding, V2ChunkKeyEncoding from zarr.core.config import config from zarr.core.dtype import UInt8, get_data_type_from_native_dtype @@ -261,7 +260,7 @@ def test_metadata_to_dict( def test_json_indent(indent: int) -> None: with config.set({"json_indent": indent}): m = GroupMetadata() - d = m.to_buffer_dict(default_buffer_prototype())["zarr.json"].to_bytes() + d = m.to_buffer_dict()["zarr.json"].to_bytes() assert d == json.dumps(json.loads(d), indent=indent).encode() @@ -283,7 +282,7 @@ async def test_datetime_metadata(fill_value: int, precision: Literal["ns", "D"]) } metadata = ArrayV3Metadata.from_dict(metadata_dict) # ensure there isn't a TypeError here. - d = metadata.to_buffer_dict(default_buffer_prototype()) + d = metadata.to_buffer_dict() result = json.loads(d["zarr.json"].to_bytes()) assert result["fill_value"] == fill_value @@ -321,7 +320,7 @@ async def test_special_float_fill_values(fill_value: str) -> None: "fill_value": fill_value, # this is not a valid fill value for uint8 } m = ArrayV3Metadata.from_dict(metadata_dict) - d = json.loads(m.to_buffer_dict(default_buffer_prototype())["zarr.json"].to_bytes()) + d = json.loads(m.to_buffer_dict()["zarr.json"].to_bytes()) assert m.fill_value is not None if fill_value == "NaN": assert np.isnan(m.fill_value) diff --git a/tests/test_properties.py b/tests/test_properties.py index 705cfd1b59..ecd9520743 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -6,8 +6,6 @@ import pytest from numpy.testing import assert_array_equal -from zarr.core.buffer import default_buffer_prototype - pytest.importorskip("hypothesis") import hypothesis.extra.numpy as npst @@ -91,15 +89,9 @@ def test_array_creates_implicit_groups(array): for i in range(len(ancestry)): parent = "/".join(ancestry[: i + 1]) if array.metadata.zarr_format == 2: - assert ( - sync(array.store.get(f"{parent}/.zgroup", prototype=default_buffer_prototype())) - is not None - ) + assert sync(array.store.get(f"{parent}/.zgroup")) is not None elif array.metadata.zarr_format == 3: - assert ( - sync(array.store.get(f"{parent}/zarr.json", prototype=default_buffer_prototype())) - is not None - ) + assert sync(array.store.get(f"{parent}/zarr.json")) is not None # this decorator removes timeout; not ideal but it should avoid intermittent CI failures @@ -211,10 +203,10 @@ async def test_roundtrip_array_metadata_from_store( prefixed with "0/", and then reads them back. The test asserts that each retrieved buffer exactly matches the original buffer. """ - asdict = meta.to_buffer_dict(prototype=default_buffer_prototype()) + asdict = meta.to_buffer_dict() for key, expected in asdict.items(): await store.set(f"0/{key}", expected) - actual = await store.get(f"0/{key}", prototype=default_buffer_prototype()) + actual = await store.get(f"0/{key}") assert actual == expected @@ -236,7 +228,7 @@ def test_roundtrip_array_metadata_from_json(data: st.DataObject, zarr_format: in cases like NaN, Infinity, complex numbers, and datetime values). """ metadata = data.draw(array_metadata(zarr_formats=st.just(zarr_format))) - buffer_dict = metadata.to_buffer_dict(prototype=default_buffer_prototype()) + buffer_dict = metadata.to_buffer_dict() if zarr_format == 2: zarray_dict = json.loads(buffer_dict[ZARRAY_JSON].to_bytes().decode()) diff --git a/tests/test_store/test_fsspec.py b/tests/test_store/test_fsspec.py index a2c07b7ed1..ed32815459 100644 --- a/tests/test_store/test_fsspec.py +++ b/tests/test_store/test_fsspec.py @@ -12,7 +12,7 @@ import zarr.api.asynchronous from zarr import Array from zarr.abc.store import OffsetByteRequest -from zarr.core.buffer import Buffer, cpu, default_buffer_prototype +from zarr.core.buffer import Buffer, cpu from zarr.core.sync import _collect_aiterator, sync from zarr.errors import ZarrUserWarning from zarr.storage import FsspecStore @@ -119,11 +119,11 @@ async def test_basic() -> None: data = b"hello" await store.set("foo", cpu.Buffer.from_bytes(data)) assert await store.exists("foo") - buf = await store.get("foo", prototype=default_buffer_prototype()) + buf = await store.get("foo") assert buf is not None assert buf.to_bytes() == data out = await store.get_partial_values( - prototype=default_buffer_prototype(), key_ranges=[("foo", OffsetByteRequest(1))] + key_ranges=[("foo", OffsetByteRequest(1))], prototype=cpu.Buffer ) assert out[0] is not None assert out[0].to_bytes() == data[1:] diff --git a/tests/test_store/test_logging.py b/tests/test_store/test_logging.py index fa566e45aa..46592f28c5 100644 --- a/tests/test_store/test_logging.py +++ b/tests/test_store/test_logging.py @@ -6,7 +6,7 @@ import pytest import zarr -from zarr.core.buffer import Buffer, cpu, default_buffer_prototype +from zarr.buffer import cpu from zarr.storage import LocalStore, LoggingStore from zarr.testing.store import StoreTests @@ -14,6 +14,7 @@ from pathlib import Path from zarr.abc.store import Store + from zarr.core.buffer.core import Buffer class StoreKwargs(TypedDict): @@ -68,7 +69,7 @@ async def test_default_handler( logging.getLogger().removeHandler(h) # Test logs are sent to stdout wrapped = LoggingStore(store=local_store) - buffer = default_buffer_prototype().buffer + buffer = cpu.Buffer res = await wrapped.set("foo/bar/c/0", buffer.from_bytes(b"\x01\x02\x03\x04")) # type: ignore[func-returns-value] assert res is None captured = capsys.readouterr() @@ -90,7 +91,7 @@ def test_is_open_setter_raises(self, store: LoggingStore[LocalStore]) -> None: @pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"]) async def test_logging_store(store: Store, caplog: pytest.LogCaptureFixture) -> None: wrapped = LoggingStore(store=store, log_level="DEBUG") - buffer = default_buffer_prototype().buffer + buffer = cpu.Buffer caplog.clear() res = await wrapped.set("foo/bar/c/0", buffer.from_bytes(b"\x01\x02\x03\x04")) # type: ignore[func-returns-value] diff --git a/tests/test_store/test_wrapper.py b/tests/test_store/test_wrapper.py index c5f2240297..c6ccf3cbb9 100644 --- a/tests/test_store/test_wrapper.py +++ b/tests/test_store/test_wrapper.py @@ -4,7 +4,7 @@ import pytest -from zarr.abc.store import BufferLike, ByteRequest, Store +from zarr.abc.store import BufferClassLike, ByteRequest, Store from zarr.core.buffer import Buffer from zarr.core.buffer.cpu import Buffer as CPUBuffer from zarr.core.buffer.cpu import buffer_prototype @@ -111,7 +111,7 @@ class NoisyGetter(WrapperStore[Any]): async def get( self, key: str, - prototype: BufferLike | None = None, + prototype: BufferClassLike | None = None, byte_range: ByteRequest | None = None, ) -> Buffer | None: print(f"getting {key}") diff --git a/tests/test_store/test_zip.py b/tests/test_store/test_zip.py index 744ee82945..a955f5fd22 100644 --- a/tests/test_store/test_zip.py +++ b/tests/test_store/test_zip.py @@ -11,7 +11,7 @@ import zarr from zarr import create_array -from zarr.core.buffer import Buffer, cpu, default_buffer_prototype +from zarr.core.buffer import Buffer, cpu from zarr.core.group import Group from zarr.storage import ZipStore from zarr.testing.store import StoreTests @@ -42,7 +42,7 @@ def store_kwargs(self) -> dict[str, str | bool]: return {"path": temp_path, "mode": "w", "read_only": False} async def get(self, store: ZipStore, key: str) -> Buffer: - buf = store._get(key, prototype=default_buffer_prototype()) + buf = store._get(key, prototype=cpu.Buffer) assert buf is not None return buf diff --git a/tests/test_v2.py b/tests/test_v2.py index cb990f6159..f06f5ab89e 100644 --- a/tests/test_v2.py +++ b/tests/test_v2.py @@ -13,7 +13,6 @@ import zarr.storage from zarr import config from zarr.abc.store import Store -from zarr.core.buffer.core import default_buffer_prototype from zarr.core.dtype import FixedLengthUTF32, Structured, VariableLengthUTF8 from zarr.core.dtype.npy.bytes import NullTerminatedBytes from zarr.core.dtype.wrapper import ZDType @@ -78,7 +77,7 @@ async def test_v2_encode_decode( name="foo", shape=(3,), chunks=(3,), dtype=dtype, fill_value=fill_value, compressor=None ) - result = await store.get("foo/.zarray", zarr.core.buffer.default_buffer_prototype()) + result = await store.get("foo/.zarray") assert result is not None serialized = json.loads(result.to_bytes()) @@ -183,7 +182,7 @@ def test_v2_non_contiguous(numpy_order: Literal["C", "F"], zarr_order: Literal[" arr[6:9, 3:6] = a[6:9, 3:6] # The slice on the RHS is important np.testing.assert_array_equal(arr[6:9, 3:6], a[6:9, 3:6]) - buf = sync(store.get("2.1", default_buffer_prototype())) + buf = sync(store.get("2.1")) assert buf is not None np.testing.assert_array_equal( a[6:9, 3:6], From 1f0322f815ce2c0a7ced0a033b792159c07d7120 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Mon, 26 Jan 2026 16:25:45 +0100 Subject: [PATCH 12/13] remove incorrect release note and add one for the changes here --- changes/3638.feature.md | 1 - changes/3644.feature.md | 5 +++++ 2 files changed, 5 insertions(+), 1 deletion(-) delete mode 100644 changes/3638.feature.md create mode 100644 changes/3644.feature.md diff --git a/changes/3638.feature.md b/changes/3638.feature.md deleted file mode 100644 index ad2276fd51..0000000000 --- a/changes/3638.feature.md +++ /dev/null @@ -1 +0,0 @@ -Add methods for reading stored objects as bytes and JSON-decoded bytes to store classes. \ No newline at end of file diff --git a/changes/3644.feature.md b/changes/3644.feature.md new file mode 100644 index 0000000000..196676d7aa --- /dev/null +++ b/changes/3644.feature.md @@ -0,0 +1,5 @@ +The `Store.get` and `Store.get_partial_values` methods now accept `None` as the `prototype` argument. When `prototype` is `None`, stores will use their default buffer class (typically `zarr.core.buffer.cpu.Buffer`). This simplifies the API for common use cases where the default buffer is sufficient. + +A new type alias `BufferClassLike` has been added, which accepts either a `Buffer` class or a `BufferPrototype` instance. + +**Breaking change for third-party store implementations:** If you have implemented a custom `Store` subclass, you must update your `get` and `get_partial_values` methods to handle `prototype=None`. To do this, override the `_get_default_buffer_class` method to return an appropriate default `Buffer` class, and update your method signatures to accept `BufferClassLike | None` instead of `BufferPrototype`. When `prototype` is `None`, call `self._get_default_buffer_class()` to obtain the buffer class. If `prototype` is a `BufferPrototype` instance, extract the buffer class via `prototype.buffer`. From c8061bf04c9a4748e13188931014d6d59c491f24 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Mon, 26 Jan 2026 16:28:47 +0100 Subject: [PATCH 13/13] fix mysterious linting errors --- src/zarr/core/_info.py | 2 +- src/zarr/core/dtype/common.py | 4 ++-- src/zarr/core/dtype/npy/structured.py | 2 +- src/zarr/core/dtype/wrapper.py | 6 +++--- tests/test_dtype/test_wrapper.py | 6 +++--- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/zarr/core/_info.py b/src/zarr/core/_info.py index fef424346a..0d8ea8fe70 100644 --- a/src/zarr/core/_info.py +++ b/src/zarr/core/_info.py @@ -69,7 +69,7 @@ def byte_info(size: int) -> str: @dataclasses.dataclass(kw_only=True, frozen=True, slots=True) -class ArrayInfo: +class ArrayInfo: # type: ignore[misc] """ Visual summary for an Array. diff --git a/src/zarr/core/dtype/common.py b/src/zarr/core/dtype/common.py index 6b70f595ba..c1590e7c46 100644 --- a/src/zarr/core/dtype/common.py +++ b/src/zarr/core/dtype/common.py @@ -125,7 +125,7 @@ def check_dtype_spec_v2(data: object) -> TypeGuard[DTypeSpec_V2]: DTypeSpec_V3 = str | NamedConfig[str, Mapping[str, object]] -def check_dtype_spec_v3(data: object) -> TypeGuard[DTypeSpec_V3]: +def check_dtype_spec_v3(data: object) -> TypeGuard[DTypeSpec_V3]: # type: ignore[valid-type] """ Type guard for narrowing the type of a python object to an instance of DTypeSpec_V3, i.e either a string or a dict with a "name" field that's a string and a @@ -141,7 +141,7 @@ def check_dtype_spec_v3(data: object) -> TypeGuard[DTypeSpec_V3]: return False -def unpack_dtype_json(data: DTypeSpec_V2 | DTypeSpec_V3) -> DTypeJSON: +def unpack_dtype_json(data: DTypeSpec_V2 | DTypeSpec_V3) -> DTypeJSON: # type: ignore[valid-type] """ Return the array metadata form of the dtype JSON representation. For the Zarr V3 form of dtype metadata, this is a no-op. For the Zarr V2 form of dtype metadata, this unpacks the dtype name. diff --git a/src/zarr/core/dtype/npy/structured.py b/src/zarr/core/dtype/npy/structured.py index 8bedee07ef..09c9c5b135 100644 --- a/src/zarr/core/dtype/npy/structured.py +++ b/src/zarr/core/dtype/npy/structured.py @@ -317,7 +317,7 @@ def to_json(self, zarr_format: ZarrFormat) -> StructuredJSON_V2 | StructuredJSON elif zarr_format == 3: v3_unstable_dtype_warning(self) fields = [ - [f_name, f_dtype.to_json(zarr_format=zarr_format)] # type: ignore[list-item] + [f_name, f_dtype.to_json(zarr_format=zarr_format)] for f_name, f_dtype in self.fields ] base_dict = { diff --git a/src/zarr/core/dtype/wrapper.py b/src/zarr/core/dtype/wrapper.py index fdc5f747f0..9fe41f8119 100644 --- a/src/zarr/core/dtype/wrapper.py +++ b/src/zarr/core/dtype/wrapper.py @@ -57,7 +57,7 @@ @dataclass(frozen=True, kw_only=True, slots=True) -class ZDType(ABC, Generic[TDType_co, TScalar_co]): +class ZDType(ABC, Generic[TDType_co, TScalar_co]): # type: ignore[misc] """ Abstract base class for wrapping native array data types, e.g. numpy dtypes @@ -169,10 +169,10 @@ def from_json(cls: type[Self], data: DTypeJSON, *, zarr_format: ZarrFormat) -> S def to_json(self, zarr_format: Literal[2]) -> DTypeSpec_V2: ... @overload - def to_json(self, zarr_format: Literal[3]) -> DTypeSpec_V3: ... + def to_json(self, zarr_format: Literal[3]) -> DTypeSpec_V3: ... # type: ignore[valid-type] @abstractmethod - def to_json(self, zarr_format: ZarrFormat) -> DTypeSpec_V2 | DTypeSpec_V3: + def to_json(self, zarr_format: ZarrFormat) -> DTypeSpec_V2 | DTypeSpec_V3: # type: ignore[valid-type] """ Serialize this ZDType to JSON. diff --git a/tests/test_dtype/test_wrapper.py b/tests/test_dtype/test_wrapper.py index cc365e86d4..7e33920eec 100644 --- a/tests/test_dtype/test_wrapper.py +++ b/tests/test_dtype/test_wrapper.py @@ -80,7 +80,7 @@ class BaseTestZDType: valid_json_v2: ClassVar[tuple[DTypeSpec_V2, ...]] = () invalid_json_v2: ClassVar[tuple[str | dict[str, object] | list[object], ...]] = () - valid_json_v3: ClassVar[tuple[DTypeSpec_V3, ...]] = () + valid_json_v3: ClassVar[tuple[DTypeSpec_V3, ...]] = () # type: ignore[valid-type] invalid_json_v3: ClassVar[tuple[str | dict[str, object], ...]] = () # for testing scalar round-trip serialization, we need a tuple of (data type json, scalar json) @@ -120,9 +120,9 @@ def test_from_json_roundtrip_v2(self, valid_json_v2: DTypeSpec_V2) -> None: assert zdtype.to_json(zarr_format=2) == valid_json_v2 @pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning") - def test_from_json_roundtrip_v3(self, valid_json_v3: DTypeSpec_V3) -> None: + def test_from_json_roundtrip_v3(self, valid_json_v3: DTypeSpec_V3) -> None: # type: ignore[valid-type] zdtype = self.test_cls.from_json(valid_json_v3, zarr_format=3) - assert zdtype.to_json(zarr_format=3) == valid_json_v3 + assert zdtype.to_json(zarr_format=3) == valid_json_v3 # type: ignore[operator] def test_scalar_roundtrip_v2(self, scalar_v2_params: tuple[ZDType[Any, Any], Any]) -> None: zdtype, scalar_json = scalar_v2_params