Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions changes/4001.misc.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Restore sharding write performance for shards with many chunks. The
`subchunk_write_order` feature inadvertently rebuilt the per-shard chunk
coordinate grid (up to tens of thousands of tuples) on every partial write;
these coordinates are now cached, restoring throughput to its previous level.
24 changes: 18 additions & 6 deletions src/zarr/codecs/sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,11 @@
BasicIndexer,
ChunkProjection,
SelectorTuple,
_lexicographic_order,
_lexicographic_order_keys,
c_order_iter,
get_indexer,
lexicographic_order_iter,
morton_order_iter,
)
from zarr.core.metadata.v3 import (
Expand Down Expand Up @@ -266,13 +269,19 @@ def __iter__(self) -> Iterator[tuple[int, ...]]:
def to_dict_vectorized(
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the following reads cleaner than adding another parameter. The reader already knows its own chunks_per_shard, so it can fetch both cached structures itself instead of having them threaded in. With this, the call site simplifies to just shard_reader.to_dict_vectorized().

I tested it still fixes the regression — same ~1.6–1.8× speedup over main on test_sharded_morton_write_single_chunk, and the difference vs the committed version is within run-to-run noise.

    def to_dict_vectorized(self) -> dict[tuple[int, ...], Buffer | None]:
        """Build a dict of chunk coordinates to buffers using vectorized lookup.

        The full per-shard chunk coordinate grid (both the array used for the
        vectorized index lookup and the plain tuples used as dict keys) is
        obtained from the cached ``_lexicographic_order`` /
        ``_lexicographic_order_keys`` helpers, keyed on
        ``self.index.chunks_per_shard``, so neither is rebuilt on every call.
        For a shard with tens of thousands of chunks this avoids reconstructing
        that many tuples on every partial write.

        Returns
        -------
        dict mapping chunk coordinate tuples to Buffer or None
            Every coordinate of the shard appears as a key; chunks that the
            index reports as not valid map to ``None``.
        """
        chunks_per_shard = self.index.chunks_per_shard
        # Both helpers return the same coordinates in the same (lexicographic)
        # order, so position ``i`` in one corresponds to position ``i`` in the
        # other and in the ``starts`` / ``ends`` / ``valid`` arrays below.
        chunk_coords_array = _lexicographic_order(chunks_per_shard)
        chunk_coords_keys = _lexicographic_order_keys(chunks_per_shard)
        starts, ends, valid = self.index.get_chunk_slices_vectorized(chunk_coords_array)

        result: dict[tuple[int, ...], Buffer | None] = {}
        for i, coords in enumerate(chunk_coords_keys):
            if valid[i]:
                # Convert to plain ints before slicing the shard buffer.
                result[coords] = self.buf[int(starts[i]) : int(ends[i])]
            else:
                result[coords] = None

        return result

self,
chunk_coords_array: npt.NDArray[np.integer[Any]],
chunk_coords_keys: tuple[tuple[int, ...], ...],
) -> dict[tuple[int, ...], Buffer | None]:
"""Build a dict of chunk coordinates to buffers using vectorized lookup.

Parameters
----------
chunk_coords_array : ndarray of shape (n_chunks, n_dims)
Array of chunk coordinates for vectorized index lookup.
chunk_coords_keys : tuple of coordinate tuples
The same coordinates as `chunk_coords_array`, in the same order, as
plain tuples for use as dict keys. Passed in (rather than derived
row-by-row from the array) so the cached value can be reused instead
of rebuilding 35k tuples on every write.

Returns
-------
Expand All @@ -281,11 +290,11 @@ def to_dict_vectorized(
starts, ends, valid = self.index.get_chunk_slices_vectorized(chunk_coords_array)

result: dict[tuple[int, ...], Buffer | None] = {}
for i, coords in enumerate(chunk_coords_array):
for i, coords in enumerate(chunk_coords_keys):
if valid[i]:
result[tuple(coords.ravel())] = self.buf[int(starts[i]) : int(ends[i])]
result[coords] = self.buf[int(starts[i]) : int(ends[i])]
else:
result[tuple(coords.ravel())] = None
result[coords] = None

return result

Expand Down Expand Up @@ -533,7 +542,7 @@ def _subchunk_order_iter(
case "morton":
subchunk_iter = morton_order_iter(chunks_per_shard)
case "lexicographic":
subchunk_iter = np.ndindex(chunks_per_shard)
subchunk_iter = lexicographic_order_iter(chunks_per_shard)
case "colexicographic":
subchunk_iter = (c[::-1] for c in np.ndindex(chunks_per_shard[::-1]))
case "unordered":
Expand Down Expand Up @@ -612,9 +621,12 @@ async def _encode_partial_single(
chunks_per_shard=chunks_per_shard,
)
shard_reader = shard_reader or _ShardReader.create_empty(chunks_per_shard)
# Use vectorized lookup for better performance
# Use vectorized lookup for better performance. The lexicographic
# coordinate array and keys are cached, so neither is rebuilt on
# every write.
shard_dict = shard_reader.to_dict_vectorized(
np.array(list(self._subchunk_order_iter(chunks_per_shard, "lexicographic")))
_lexicographic_order(chunks_per_shard),
_lexicographic_order_keys(chunks_per_shard),
)

await self.codec_pipeline.write(
Expand Down
27 changes: 27 additions & 0 deletions src/zarr/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1584,6 +1584,33 @@ def morton_order_iter(chunk_shape: tuple[int, ...]) -> Iterator[tuple[int, ...]]
return iter(_morton_order_keys(tuple(chunk_shape)))


@lru_cache(maxsize=16)
def _lexicographic_order(chunk_shape: tuple[int, ...]) -> npt.NDArray[np.intp]:
# Lexicographic (C-order) coordinates, computed vectorized and cached so that
# the sharding codec's per-shard chunk grid is not rebuilt on every call.
# Equivalent to `np.array(list(np.ndindex(chunk_shape)))` but without the
# Python-level iteration over every coordinate.
n_dims = len(chunk_shape)
if n_dims == 0:
# A 0-d shard holds a single chunk addressed by the empty coordinate, so
# the coordinate array has one row and zero columns. np.indices(()) cannot
# express this, so build it directly. Matches list(np.ndindex(())) == [()].
order = np.empty((1, 0), dtype=np.intp)
else:
order = np.indices(chunk_shape, dtype=np.intp).reshape(n_dims, -1).T
order.flags.writeable = False
return order


@lru_cache(maxsize=16)
def _lexicographic_order_keys(chunk_shape: tuple[int, ...]) -> tuple[tuple[int, ...], ...]:
    """Return the lexicographic coordinates of ``chunk_shape`` as plain tuples.

    The coordinates are the rows of ``_lexicographic_order(chunk_shape)``, in
    the same order, converted to tuples of Python ints so they can be used as
    dict keys. The result is cached per ``chunk_shape``.
    """
    # ndarray.tolist() converts every element to a Python int in one pass,
    # leaving only the row -> tuple conversion at Python level.
    return tuple(map(tuple, _lexicographic_order(chunk_shape).tolist()))


def lexicographic_order_iter(chunk_shape: tuple[int, ...]) -> Iterator[tuple[int, ...]]:
    """Iterate over the coordinates of ``chunk_shape`` in lexicographic order.

    Each item is a tuple of Python ints. The underlying coordinate tuples come
    from the cached ``_lexicographic_order_keys``; only a fresh iterator over
    them is created per call.
    """
    # Normalise to a tuple so the argument is hashable for the cached helper.
    return iter(_lexicographic_order_keys(tuple(chunk_shape)))


def c_order_iter(chunks_per_shard: tuple[int, ...]) -> Iterator[tuple[int, ...]]:
    """Iterate over all chunk coordinates of a shard in C (row-major) order.

    Parameters
    ----------
    chunks_per_shard : tuple of int
        Number of chunks along each dimension.

    Returns
    -------
    Iterator of coordinate tuples; the last dimension varies fastest. An empty
    ``chunks_per_shard`` yields the single empty coordinate ``()``.
    """
    # The Cartesian product of range(n) per dimension enumerates the grid with
    # the rightmost index advancing first.
    return itertools.product(*map(range, chunks_per_shard))

Expand Down
Loading