
Add asynchronous load method

TomNicholas opened this issue 8 months ago

Adds a .load_async() method to Variable, which works by plumbing an async get_duck_array code path all the way down until it finally reaches the async methods that zarr v3 exposes.

Needs a lot of refactoring before it could be merged, but it works.
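The delegation chain described above can be pictured with stand-in classes. This is only an illustrative sketch, not xarray's actual internals: `BackendArraySketch` and `LazilyIndexedSketch` are hypothetical names standing in for the real `ZarrArrayWrapper` and lazy indexing layers.

```python
import asyncio


class BackendArraySketch:
    """Stand-in for a backend array with a native async read path."""

    def __init__(self, data):
        self._data = data

    async def async_getitem(self, key):
        # In the real PR this bottoms out in zarr v3's async store methods.
        await asyncio.sleep(0)
        return self._data[key]


class LazilyIndexedSketch:
    """Stand-in for a lazy indexing layer that forwards async requests down."""

    def __init__(self, array, key):
        self.array, self.key = array, key

    async def async_get_duck_array(self):
        # Each layer awaits the layer below it instead of calling it directly.
        return await self.array.async_getitem(self.key)


async def main():
    lazy = LazilyIndexedSketch(BackendArraySketch([10, 20, 30]), slice(1, 3))
    return await lazy.async_get_duck_array()


result = asyncio.run(main())
print(result)  # [20, 30]
```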

  • [x] Closes #10326
  • [x] Tests added
  • [x] User visible changes (including notable bug fixes) are documented in whats-new.rst
  • [x] New functions/methods are listed in api.rst

API:

  • [x] Variable.load_async
  • [x] DataArray.load_async
  • [x] Dataset.load_async
  • [ ] DataTree.load_async
  • [ ] load_dataset?
  • [ ] load_dataarray?

TomNicholas avatar May 16 '25 16:05 TomNicholas

These failing tests from the CI do not fail when I run them locally, which is interesting.

FAILED xarray/tests/test_backends.py::TestH5NetCDFViaDaskData::test_outer_indexing_reversed - ValueError: dimensions ('t', 'y', 'x') must have the same length as the number of data dimensions, ndim=4
FAILED xarray/tests/test_backends.py::TestNetCDF4ViaDaskData::test_outer_indexing_reversed - ValueError: dimensions ('t', 'y', 'x') must have the same length as the number of data dimensions, ndim=4
FAILED xarray/tests/test_backends.py::TestDask::test_outer_indexing_reversed - ValueError: dimensions ('t', 'y', 'x') must have the same length as the number of data dimensions, ndim=4
= 3 failed, 18235 passed, 1269 skipped, 77 xfailed, 15 xpassed, 2555 warnings in 487.15s (0:08:07) =
Error: Process completed with exit code 1.

TomNicholas avatar May 19 '25 02:05 TomNicholas

There is something funky going on when using .sel

# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "arraylake",
#     "yappi",
#     "zarr==3.0.8",
#     "xarray",
#     "icechunk"
# ]
#
# [tool.uv.sources]
# xarray = { git = "https://github.com/TomNicholas/xarray", rev = "async.load" }
# ///

import asyncio
from collections.abc import Iterable
from typing import TypeVar

import numpy as np

import xarray as xr

import zarr
from zarr.abc.store import ByteRequest, Store
from zarr.core.buffer import Buffer, BufferPrototype
from zarr.storage._wrapper import WrapperStore

T_Store = TypeVar("T_Store", bound=Store)


class LatencyStore(WrapperStore[T_Store]):
    """Works the same way as the zarr LoggingStore"""

    latency: float

    def __init__(
        self,
        store: T_Store,
        latency: float = 0.0,
    ) -> None:
        """
        Store wrapper that adds artificial latency to each get call.

        Parameters
        ----------
        store : Store
            Store to wrap
        latency : float
            Amount of artificial latency to add to each get call, in seconds.
        """
        super().__init__(store)
        self.latency = latency

    def __str__(self) -> str:
        return f"latency-{self._store}"

    def __repr__(self) -> str:
        return f"LatencyStore({self._store.__class__.__name__}, '{self._store}', latency={self.latency})"

    async def get(
        self,
        key: str,
        prototype: BufferPrototype,
        byte_range: ByteRequest | None = None,
    ) -> Buffer | None:
        await asyncio.sleep(self.latency)
        return await self._store.get(
            key=key, prototype=prototype, byte_range=byte_range
        )

    async def get_partial_values(
        self,
        prototype: BufferPrototype,
        key_ranges: Iterable[tuple[str, ByteRequest | None]],
    ) -> list[Buffer | None]:
        await asyncio.sleep(self.latency)
        return await self._store.get_partial_values(
            prototype=prototype, key_ranges=key_ranges
        )


memorystore = zarr.storage.MemoryStore({})

shape = 5
X = np.arange(5) * 10
ds = xr.Dataset(
    {
        "data": xr.DataArray(
            np.zeros(shape),
            coords={"x": X},
        )
    }
)

ds.to_zarr(memorystore)


latencystore = LatencyStore(memorystore, latency=0.1)
ds = xr.open_zarr(latencystore, zarr_format=3, consolidated=False, chunks=None)

# no problem for any of these
asyncio.run(ds["data"][0].load_async())
asyncio.run(ds["data"].sel(x=10).load_async())
asyncio.run(ds["data"].sel(x=11, method="nearest").load_async())

# also fine
ds["data"].sel(x=[30, 40]).load()

# broken!
asyncio.run(ds["data"].sel(x=[30, 40]).load_async())

Running that script with uv run gives:

Traceback (most recent call last):
  File "/Users/ian/tmp/async_error.py", line 109, in <module>
    asyncio.run(ds["data"].sel(x=[30, 40]).load_async())
  File "/Users/ian/miniforge3/envs/test/lib/python3.12/asyncio/runners.py", line 195, in run
    return runner.run(main)
           ^^^^^^^^^^^^^^^^
  File "/Users/ian/miniforge3/envs/test/lib/python3.12/asyncio/runners.py", line 118, in run
    return self._loop.run_until_complete(task)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ian/miniforge3/envs/test/lib/python3.12/asyncio/base_events.py", line 691, in run_until_complete
    return future.result()
           ^^^^^^^^^^^^^^^
  File "/Users/ian/.cache/uv/environments-v2/async-error-29817fa21dae3c0f/lib/python3.12/site-packages/xarray/core/dataarray.py", line 1165, in load_async
    ds = await temp_ds.load_async(**kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ian/.cache/uv/environments-v2/async-error-29817fa21dae3c0f/lib/python3.12/site-packages/xarray/core/dataset.py", line 578, in load_async
    await asyncio.gather(*coros)
  File "/Users/ian/.cache/uv/environments-v2/async-error-29817fa21dae3c0f/lib/python3.12/site-packages/xarray/core/variable.py", line 963, in load_async
    self._data = await async_to_duck_array(self._data, **kwargs)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ian/.cache/uv/environments-v2/async-error-29817fa21dae3c0f/lib/python3.12/site-packages/xarray/namedarray/pycompat.py", line 168, in async_to_duck_array
    return await data.async_get_duck_array()  # type: ignore[no-untyped-call, no-any-return]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ian/.cache/uv/environments-v2/async-error-29817fa21dae3c0f/lib/python3.12/site-packages/xarray/core/indexing.py", line 875, in async_get_duck_array
    await self._async_ensure_cached()
  File "/Users/ian/.cache/uv/environments-v2/async-error-29817fa21dae3c0f/lib/python3.12/site-packages/xarray/core/indexing.py", line 867, in _async_ensure_cached
    duck_array = await self.array.async_get_duck_array()
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ian/.cache/uv/environments-v2/async-error-29817fa21dae3c0f/lib/python3.12/site-packages/xarray/core/indexing.py", line 821, in async_get_duck_array
    return await self.array.async_get_duck_array()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ian/.cache/uv/environments-v2/async-error-29817fa21dae3c0f/lib/python3.12/site-packages/xarray/core/indexing.py", line 674, in async_get_duck_array
    array = await self.array.async_getitem(self.key)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ian/.cache/uv/environments-v2/async-error-29817fa21dae3c0f/lib/python3.12/site-packages/xarray/backends/zarr.py", line 248, in async_getitem
    return await indexing.async_explicit_indexing_adapter(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ian/.cache/uv/environments-v2/async-error-29817fa21dae3c0f/lib/python3.12/site-packages/xarray/core/indexing.py", line 1068, in async_explicit_indexing_adapter
    result = await raw_indexing_method(raw_key.tuple)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: object numpy.ndarray can't be used in 'await' expression

ianhi avatar May 22 '25 21:05 ianhi

Notes to self:

  • [ ] Try to consolidate indexing tests with those in test_variable.py, potentially by defining a subclass of Variable that only implements async methods
  • [x] Use create_test_data, write to a zarr (memory)store, and open lazily - this will help test decoding machinery.
  • [x] Raise an informative error if you try to do o/v-indexing with a version of zarr that's too old? Or just fall back to blocking in that case...
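The fall-back-to-blocking idea in the last item could look roughly like this. The helper name and structure are hypothetical, not the PR's actual code; the real implementation would live inside xarray's backend adapters.

```python
import asyncio


async def async_getitem_with_fallback(array, key):
    # Use the native async path when a new-enough zarr array provides one;
    # otherwise run the blocking read in a worker thread so the event loop
    # stays responsive.
    if hasattr(array, "async_getitem"):
        return await array.async_getitem(key)
    return await asyncio.to_thread(array.__getitem__, key)


class SyncOnlyArray:
    """Stands in for an array from an older zarr without async methods."""

    def __getitem__(self, key):
        return [1, 2, 3][key]


result = asyncio.run(async_getitem_with_fallback(SyncOnlyArray(), 1))
print(result)  # 2
```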

TomNicholas avatar May 30 '25 16:05 TomNicholas

The changes needed in zarr-python have just been merged upstream (but not yet released).

TomNicholas avatar Jul 30 '25 18:07 TomNicholas

Upstream-dev tests are passing. But all the tests need to pass, because otherwise these new methods will error on zarr-python<=3.1.1.

TomNicholas avatar Jul 31 '25 09:07 TomNicholas

These test failures are actually non-deterministic. The test tries to do vectorized indexing, and expects an error about vectorized indexing not being supported to be raised. But the test sometimes fails because an error about orthogonal indexing not being supported is raised instead.

What seems to be happening is that for the exact same indexer, sometimes the indexing call goes through the vectorized indexing codepath first and sometimes it goes through the orthogonal indexing codepath first. I think in both cases it gives the same result, but the order of execution can differ.

This script replicates the behaviour on this branch. If you run it repeatedly you will find that the behaviour changes between runs, as the error raised is inconsistent.

#!/usr/bin/env python3
"""
Standalone reproducer for the flaky async test behavior.
"""

import asyncio
import xarray as xr
import zarr
from xarray.tests.test_dataset import create_test_data


async def test_flaky_behavior():
    """Reproduce the exact test scenario that shows flaky behavior."""
    
    # Create zarr store with format 3
    memorystore = zarr.storage.MemoryStore({})
    ds = create_test_data()
    ds.to_zarr(memorystore, zarr_format=3, consolidated=False)
    
    # Open the dataset
    ds = xr.open_zarr(memorystore, consolidated=False, chunks=None)
    
    # Create the exact same indexer as the failing test
    indexer = {
        "dim1": xr.Variable(data=[2, 3], dims="points"),
        "dim2": xr.Variable(data=[1, 3], dims="points"),
    }
    
    # Apply isel and try load_async
    try:
        await ds.isel(**indexer).load_async()
        print("ERROR: Should have raised NotImplementedError!")
    except NotImplementedError as e:
        error_msg = str(e)
        if "vectorized async indexing" in error_msg:
            print("VECTORIZED")
        elif "orthogonal async indexing" in error_msg:
            print("ORTHOGONAL")  
        else:
            print(f"OTHER: {error_msg}")


if __name__ == "__main__":
    asyncio.run(test_flaky_behavior())

This other script replicates similar behaviour on main. To reveal the use of different codepaths this second script requires inserting debugging print statements. If you run this repeatedly you will see the order of the print statements changes between runs.

#!/usr/bin/env python3
"""
Test sync .load() to see which indexing codepath is taken.
"""

import xarray as xr
import zarr
from xarray.tests.test_dataset import create_test_data


def test_sync_load():
    """Test with sync .load() instead of .load_async()"""
    
    # Create zarr store with format 3
    memorystore = zarr.storage.MemoryStore({})
    ds = create_test_data()
    ds.to_zarr(memorystore, zarr_format=3, consolidated=False)
    
    # Open the dataset
    ds = xr.open_zarr(memorystore, consolidated=False, chunks=None)
    
    # Create the exact same indexer as the failing test
    indexer = {
        "dim1": xr.Variable(data=[2, 3], dims="points"),
        "dim2": xr.Variable(data=[1, 3], dims="points"),
    }
    
    # Apply isel and load (sync)
    result = ds.isel(**indexer).load()
    print("SYNC_LOAD_COMPLETED")


if __name__ == "__main__":
    test_sync_load()
# These debugging print statements need to be added to xarray's ZarrArrayWrapper:
class ZarrArrayWrapper:
    def __getitem__(self, key):
        array = self._array
        if isinstance(key, indexing.BasicIndexer):
            print(f"DEBUG: SYNC BasicIndexer: {key}")
            method = self._getitem
        elif isinstance(key, indexing.VectorizedIndexer):
            print(f"DEBUG: SYNC VectorizedIndexer: {key}")
            method = self._vindex
        elif isinstance(key, indexing.OuterIndexer):
            print(f"DEBUG: SYNC OuterIndexer: {key}")
            method = self._oindex

I think this is somehow to do with variable or indexer ordering not being deterministic - which could be due to use of dicts internally perhaps?

I can hide this weirdness by simply changing my test to be happy with either error. But I don't know if this is indicative of a bug that needs to be fixed.

TomNicholas avatar Aug 11 '25 16:08 TomNicholas

which could be due to use of dicts internally perhaps?

dict iteration order has been deterministic (insertion order) since Python 3.7; what you're looking for is set.
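A quick illustration of that distinction:

```python
# Dict iteration order is the insertion order, guaranteed since Python 3.7,
# so dicts alone can't produce run-to-run nondeterminism.
d = {"dim1": 1, "dim2": 2, "dim3": 3}
assert list(d) == ["dim1", "dim2", "dim3"]

# Set iteration order depends on element hashes; string hashes are randomized
# per process (PYTHONHASHSEED), so this order can change between runs.
s = {"dim1", "dim2", "dim3"}
print(list(s))
```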

Either way, the decision on whether or not to use basic, orthogonal, or vectorized indexing depends on the types of indexers you pass in. According to https://github.com/pydata/xarray/blob/54ac2fe225dee813b7d0bf729af56027841996d5/xarray/core/variable.py#L661-L672 the presence of two variable indexers with a single, common dimension should go into _broadcast_indexes_vectorized, which should not return outer indexers.
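For reference, the two modes give genuinely different results; plain NumPy shows the distinction (xarray's OuterIndexer and VectorizedIndexer implement these semantics over backend arrays):

```python
import numpy as np

arr = np.arange(16).reshape(4, 4)
rows, cols = [2, 3], [1, 3]

# Vectorized (pointwise) indexing: the indexers broadcast against each other,
# picking out exactly the points (2, 1) and (3, 3).
vectorized = arr[rows, cols]

# Orthogonal (outer) indexing: each indexer selects independently along its
# axis, producing the full 2x2 cross product of rows and columns.
orthogonal = arr[np.ix_(rows, cols)]

print(vectorized.tolist())  # [9, 15]
print(orthogonal.tolist())  # [[9, 11], [13, 15]]
```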

keewis avatar Aug 11 '25 16:08 keewis

the decision on whether or not to use basic, orthogonal, or vectorized indexing depends on the types of indexers you pass in.

I'm passing exactly the same indexers every time.

the presence of two variable indexers with a single, common dimension should go into _broadcast_indexes_vectorized, which should not return outer indexers.

It should, but apparently it doesn't always! If you run either of those scripts, you will see OuterIndexers are being created.

TomNicholas avatar Aug 11 '25 16:08 TomNicholas

changing my test to be happy with either error

As I thought, with this change applied (in https://github.com/pydata/xarray/pull/10327/commits/a7918e4f82aa7972bb448dcbf4110c14a1c0e930) now everything seems to be passing. (I don't think the warnings causing readthedocs or the upstream mypy failures are anything to do with this PR)

TomNicholas avatar Aug 11 '25 16:08 TomNicholas

I think I figured out why: create_test_data creates a dataset that has three data variables, two of which do not have both indexed dims. Thus, if these variables are indexed first you get the orthogonal index error (indexing along one dim is always basic or orthogonal indexing), while if the other variable is indexed first you get the vectorized index error.

keewis avatar Aug 11 '25 17:08 keewis

I think I figured out why: create_test_data creates a dataset that has three data variables, two of which do not have both indexed dims. Thus, if these variables are indexed first you get the orthogonal index error (indexing along one dim is always basic or orthogonal indexing), while if the other variable is indexed first you get the vectorized index error.

Riiiiiight, thank you.

So actually there's another way for me to dodge this problem in my test: just index into a single Variable instead of into a Dataset. Then there can't be a race condition between variables.
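The race between variables can be reproduced with asyncio.gather alone: whichever coroutine fails first determines which error surfaces, regardless of argument order. This is a toy sketch, not xarray code; the delays are contrived to make the outcome deterministic here, whereas in the real test the completion order is up to the scheduler.

```python
import asyncio


async def fail_after(delay, msg):
    # Stands in for one variable's load_async hitting an unsupported codepath.
    await asyncio.sleep(delay)
    raise NotImplementedError(msg)


async def main():
    # "vectorized" is listed first, but the faster task raises first, so its
    # error is the one gather propagates to the caller.
    try:
        await asyncio.gather(
            fail_after(0.05, "vectorized async indexing"),
            fail_after(0.01, "orthogonal async indexing"),
        )
    except NotImplementedError as e:
        return str(e)


message = asyncio.run(main())
print(message)  # orthogonal async indexing
```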

TomNicholas avatar Aug 11 '25 17:08 TomNicholas

you can also use a single-variable dataset, but yeah, that would eliminate the issue

keewis avatar Aug 11 '25 17:08 keewis

It would be good to wire this up to test_backends somehow, even if just for a few indexing + roundtrip tests. Do you see a path there?

@dcherian I added a test of loading to test_backends.py, which was worthwhile, as it found a bug with IndexVariables (fixed in https://github.com/pydata/xarray/pull/10327/commits/b4a5a90c995fc63e3e82efb30afa823a351ad0bd).

There are some other indexing-related tests in test_backends.py that I could use for async tests, but they would all be effectively duplicated onto the Zarr test class to avoid having them run on the NetCDF test class too. I also just don't really understand what this indexing test does - how is that meant to load values at all?

TomNicholas avatar Aug 12 '25 11:08 TomNicholas

Try to consolidate indexing tests with those in test_variable.py, potentially by defining a subclass of Variable that only implements async methods

@dcherian I tried to implement this suggestion of yours, aiming for similar lazy indexing tests that aren't coupled to zarr. My idea was to do basically the same indexing test as I already have, but create an MyAsyncBackendArray(BackendArray) instead of using the ZarrArrayWrapper.

However, it seems that nowhere else in xarray's tests do we take this approach of defining a custom BackendArray subclass. The tests in test_variable.py::TestBackendIndexing call LazilyIndexedArray, but never BackendArray. When trying to subclass BackendArray I hit weird indexing errors I didn't understand, so I decided to just use LazilyIndexedArray instead, like the existing tests. But now my new test is very similar to the existing tests, and I feel it is not sufficient to replace the zarr-dependent tests I already wrote. 😕

TomNicholas avatar Aug 12 '25 12:08 TomNicholas

I also removed test_async.py in favour of moving those tests (which are all zarr-dependent) into test_backends.py.

TomNicholas avatar Aug 12 '25 13:08 TomNicholas

how is that meant to load values at all?

assert_identical calls load to be able to compare values, and load works in-place, i.e. afterwards actual is also in-memory
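That in-place behaviour can be pictured with a toy lazy array (illustrative only; `LazyArraySketch` is a made-up stand-in, not xarray's actual lazy indexing classes):

```python
class LazyArraySketch:
    """Toy stand-in for a lazily-indexed array that caches on load."""

    def __init__(self, fetch):
        self._fetch = fetch
        self._cache = None

    @property
    def in_memory(self):
        return self._cache is not None

    def load(self):
        # load mutates self, so every holder of this object sees the
        # in-memory data afterwards, just like assert_identical's side
        # effect on `actual`.
        if self._cache is None:
            self._cache = self._fetch()
        return self._cache


arr = LazyArraySketch(lambda: [1, 2, 3])
assert not arr.in_memory
arr.load()
assert arr.in_memory  # the very same object is now in-memory
```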

keewis avatar Aug 12 '25 14:08 keewis