Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions changes/1806.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
Add `zarr.abc.store.Store.get_many`, a bulk counterpart to `Store.get` that
retrieves many values — each a whole key or a `(key, byte_range)` pair — in a
single call. It generalizes `Store.get_ranges` (many ranges of one key) to many
keys, yielding `(request_index, Buffer | None)` batches in completion order so a
store can coalesce reads that land in the same underlying object. The method is
defined on the `Store` ABC with a default implementation that fetches the
requests concurrently with `Store.get`, so every store inherits a working
version; stores whose backend can retrieve many objects together should override
it (`FsspecStore` does, coalescing via `fsspec`'s `cat_ranges`). Coalescing
tuning is left to each store rather than exposed on the interface. This restores
and generalizes the batched-fetch capability of the v2 `getitems` Store API.
65 changes: 65 additions & 0 deletions src/zarr/abc/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,71 @@ async def get_partial_values(
"""
...

async def get_many(
self,
requests: Sequence[tuple[str, ByteRequest | None] | str],
*,
prototype: BufferPrototype,
) -> AsyncIterator[Sequence[tuple[int, Buffer | None]]]:
"""Retrieve many values, possibly from different keys, at once.
This is the bulk counterpart to :meth:`get`: the whole set of requests

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rst -> mkdocs

is handed to the store in a single call, so an implementation can fetch
them together — for example by coalescing reads that land in the same
underlying object into fewer requests — rather than one at a time. It
generalizes :meth:`get_ranges` (which reads many ranges from a *single*
key) to many keys, each with an optional byte range.
Yields one batch per underlying I/O operation, each a sequence of
``(request_index, Buffer | None)`` tuples where ``request_index`` is the
position of the request in ``requests``. Every request is reported
exactly once across all batches; a ``None`` buffer means that key is
absent. Batches arrive in completion order, not request order, so
callers use the indices to reassemble results.
The default implementation fetches each request concurrently with
:meth:`get`, so every store gets a working version for free; stores
whose backend can retrieve many objects together (e.g.
:class:`~zarr.storage.FsspecStore`, which coalesces nearby reads via
``fsspec``) should override it. Anything specific to *how* a store
batches or coalesces (concurrency limits, gap thresholds, ...) is an
implementation concern of that store, not part of this interface.
Parameters
----------
requests : Sequence[tuple[str, ByteRequest | None] | str]
The values to retrieve. Each request is either a bare key (the
whole value) or a ``(key, byte_range)`` tuple; a ``byte_range`` of
``None`` also means the whole value. A key may appear more than
once with different ranges.
prototype : BufferPrototype
The prototype of the output buffers. Stores may support a default
buffer prototype.
Yields
------
Sequence[tuple[int, Buffer | None]]
One batch per underlying I/O operation, each a sequence of
``(request_index, Buffer | None)`` tuples.
"""
# Local imports to avoid an import cycle at module load time.
from zarr.core.common import concurrent_map
from zarr.core.config import config

indexed = [
(i, req, None) if isinstance(req, str) else (i, req[0], req[1])
for i, req in enumerate(requests)
]

async def _fetch(
index: int, key: str, byte_range: ByteRequest | None
) -> tuple[int, Buffer | None]:
return index, await self.get(key, prototype, byte_range)

results = await concurrent_map(indexed, _fetch, config.get("async.concurrency"))
for result in results:
yield [result]

@abstractmethod
async def exists(self, key: str) -> bool:
"""Check if a key exists in the store.
Expand Down
19 changes: 18 additions & 1 deletion src/zarr/storage/_fsspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
logger = getLogger(__name__)

if TYPE_CHECKING:
from collections.abc import AsyncIterator, Iterable
from collections.abc import AsyncIterator, Iterable, Sequence

from fsspec import AbstractFileSystem
from fsspec.asyn import AsyncFileSystem
Expand Down Expand Up @@ -459,6 +459,23 @@ async def get_partial_values(

return [None if isinstance(r, Exception) else prototype.buffer.from_bytes(r) for r in res]

async def get_many(
self,
requests: Sequence[tuple[str, ByteRequest | None] | str],
*,
prototype: BufferPrototype,
) -> AsyncIterator[Sequence[tuple[int, Buffer | None]]]:
# docstring inherited
# Hand the whole set of requests to fsspec in one call so it can
# coalesce nearby reads into fewer requests (via
# ``AbstractFileSystem._cat_ranges``), rather than issuing one request
# per key as the default Store.get_many does. get_partial_values
# returns results aligned to the input order, so we can index them
# directly and yield them as a single completed batch.
key_ranges = [(req, None) if isinstance(req, str) else req for req in requests]
values = await self.get_partial_values(prototype, key_ranges)
yield list(enumerate(values))

async def list(self) -> AsyncIterator[str]:
# docstring inherited
allfiles = await self.fs._find(self.path, detail=False, withdirs=False)
Expand Down
37 changes: 36 additions & 1 deletion src/zarr/testing/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ async def test_get_raises(self, store: S) -> None:
with pytest.raises((ValueError, TypeError), match=r"Unexpected byte_range, got.*"):
await store.get("c/0", prototype=default_buffer_prototype(), byte_range=(0, 2)) # type: ignore[arg-type]

async def test_get_many(self, store: S) -> None:
async def test_get_many_streaming(self, store: S) -> None:
"""
Ensure that multiple keys can be retrieved at once with the _get_many method.
"""
Expand Down Expand Up @@ -407,6 +407,41 @@ async def test_get_partial_values(
obs.to_bytes() == exp.to_bytes() for obs, exp in zip(observed, expected, strict=True)
)

async def test_get_many(self, store: S) -> None:
# put a handful of whole values
for key, data in {"c/0/0": b"aaaaa", "c/0/1": b"bb", "c/0/2": b"cccc"}.items():
await self.set(store, key, self.buffer_cls.from_bytes(data))

# mix bare keys, an explicit (key, None) tuple, a partial range, and a
# missing key. Each request must be reported exactly once, by index.
requests: list[tuple[str, ByteRequest | None] | str] = [
"c/0/0",
("c/0/1", None),
("c/0/0", RangeByteRequest(1, 3)),
"c/0/2",
"c/0/missing",
]
collected: dict[int, Buffer | None] = {}
async for batch in store.get_many(requests, prototype=default_buffer_prototype()):
for index, value in batch:
assert index not in collected # reported exactly once
collected[index] = value

assert set(collected) == set(range(len(requests)))
assert collected[4] is None # missing key -> None (not omitted)
expected = {0: b"aaaaa", 1: b"bb", 2: b"aa", 3: b"cccc"} # index 2 is "aaaaa"[1:3]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we define expected here, why don't we just assert that expected == collected?

for index, want in expected.items():
buffer = collected[index]
assert buffer is not None
assert buffer.to_bytes() == want

async def test_get_many_empty(self, store: S) -> None:
# an empty request is valid and yields no results
batches = [
batch async for batch in store.get_many([], prototype=default_buffer_prototype())
]
assert [pair for batch in batches for pair in batch] == []

async def test_exists(self, store: S) -> None:
assert not await store.exists("foo")
await store.set("foo/zarr.json", self.buffer_cls.from_bytes(b"bar"))
Expand Down
Loading