
Unable to use dask.scatter with apply_ufunc


What is your issue?

I am trying to scatter a large array and pass it as a keyword argument to a function applied using apply_ufunc, but that is currently not working. The same function works if I provide the actual array, but if I provide the Future linked to the scattered data, the task fails.

Here is a minimal example to reproduce this issue

import dask.array as da
import numpy as np
import xarray as xr
from dask.distributed import Client

client = Client()  # stand-in for the cluster connection; the original report used a dask_gateway client

data = xr.DataArray(
    data=da.random.random((15, 15, 20)),
    coords={'x': range(15), 'y': range(15), 'z': range(20)},
    dims=('x', 'y', 'z'),
)

test = np.full((20,), 30)
test_future = client.scatter(test, broadcast=True)

def _copy_test(d, test=None):
    return test


new_data_actual = xr.apply_ufunc(
    _copy_test,
    data, 
    input_core_dims=[['z']],
    output_core_dims=[['new_z']],
    vectorize=True,
    dask='parallelized',
    output_dtypes="float64",
    kwargs={'test':test},
    dask_gufunc_kwargs = {'output_sizes':{'new_z':20}}
)

new_data_future = xr.apply_ufunc(
    _copy_test,
    data, 
    input_core_dims=[['z']],
    output_core_dims=[['new_z']],
    vectorize=True,
    dask='parallelized',
    output_dtypes="float64",
    kwargs={'test':test_future},
    dask_gufunc_kwargs = {'output_sizes':{'new_z':20}}
)

data[0, 0].compute()
#[0.3034994 , 0.08172002, 0.34731092, ...]

new_data_actual[0, 0].compute()
#[30.0, 30.0, 30.0, ...]

new_data_future[0,0].compute()
#KilledWorker

I tried different variations of this, from explicitly calling test.result() to changing the way the Future was passed, but nothing worked. I also tried raising exceptions within the function and various ways of printing information, but that did not work either. This last point makes me think that when passing a Future, execution never actually reaches the scope of that function.

Am I trying to do something completely silly, or is this unexpected behavior?

alessioarena avatar Jul 18 '22 07:07 alessioarena

This is still an issue. I noticed that the documentation of map_blocks states: "kwargs (mapping) – Passed verbatim to func after unpacking. xarray objects, if any, will not be subset to blocks. Passing dask collections in kwargs is not allowed."

Is this the case for apply_ufunc as well? If so, it is not documented. Is there another recommended way to pass data to workers without clogging the scheduler for this application?

alessioarena avatar Sep 28 '22 02:09 alessioarena

I think I may have narrowed the problem down to a limitation in dask when using dask_gateway.

When a Future is passed to a worker, the worker will try to unpickle that Future and, as part of that, the Client object that was used to create it.

Unfortunately, in a dask_gateway context the client sits behind a gateway connection that the worker does not understand, since it normally does not have to deal with the gateway at all. In my case I do not get any error message, just the task failing and retrying over and over, but with some fiddling I managed to reproduce the same error as in this post: https://stackoverflow.com/questions/70775315/scattering-data-to-dask-cluster-workers-unknown-address-scheme-gateway

alessioarena avatar Oct 02 '22 01:10 alessioarena

This is still an issue. I noticed that the documentation of map_blocks states: "kwargs (mapping) – Passed verbatim to func after unpacking. xarray objects, if any, will not be subset to blocks. Passing dask collections in kwargs is not allowed."

Is this the case for apply_ufunc as well?

test_future is not a dask collection. It's a distributed.Future, which points to an arbitrary, opaque data blob that xarray has no way of knowing about.
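
For illustration, the distinction can be checked directly with the names from the original snippet (a minimal sketch, assuming data and test_future from above are in scope):

import dask
import distributed

# the chunked DataArray is a dask collection: it carries its own task graph
print(dask.is_dask_collection(data))         # True

# the scattered Future is not a collection, just a handle to remote data
print(dask.is_dask_collection(test_future))  # False
print(isinstance(test_future, distributed.Future))  # True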

FWIW, I could reproduce the issue, where the future in the kwargs is not resolved to the data it points to as one would expect. Minimal reproducer:

import distributed
import xarray

client = distributed.Client(processes=False)
x = xarray.DataArray([1, 2]).chunk()
test_future = client.scatter("Hello World")


def f(d, test):
    print(test)
    return d


y = xarray.apply_ufunc(
    f,
    x,
    dask='parallelized',
    output_dtypes="float64",
    kwargs={'test':test_future},
)
y.compute()

Expected print output: Hello World
Actual print output: <Future: finished, type: str, key: str-b012273bcde56eadf364cd3ce9b4ca26>

crusaderky avatar Oct 17 '22 11:10 crusaderky

I can add that this problem is compounded on a dask_gateway system, where the task just fails.

With apply_ufunc I never received an error, but in a similar context I obtained something very similar to https://github.com/dask/dask-gateway/issues/404.

My interpretation is that the Future is resolved at the worker (or, in the case of apply_ufunc, at a thread of that worker) and embeds a reference to the Client object. The Client, however, uses a gateway connection that the worker does not understand, since it is generally the scheduler that deals with those.

alessioarena avatar Oct 17 '22 11:10 alessioarena

Having said the above, your design is... contrived.

There isn't, as of today, a straightforward way to scatter a local dask collection (persist() will push the whole thing through the scheduler and likely send it out of memory).

Workaround:

test = np.full((20,), 30)
a = da.from_array(test)
# scatter the contents of the array's task graph to every worker;
# this returns a dict mapping the same keys to Futures
dsk = client.scatter(dict(a.dask), broadcast=True)
# rebuild the dask array on top of the scattered graph
a = da.Array(dsk, name=a.name, chunks=a.chunks, dtype=a.dtype, meta=a._meta, shape=a.shape)
# wrap it in a DataArray so apply_ufunc can handle alignment and broadcasting
a_x = xarray.DataArray(a, dims=["new_z"])

Once you have a_x, you just pass it to the args (not kwargs) of apply_ufunc.

crusaderky avatar Oct 17 '22 12:10 crusaderky

I'm not sure I understand the code above.

In my case I have an array of approximately 300k elements that each and every function call needs access to. I can pass it as a kwarg in its numpy form, but once I scale the calculation up across a large dataset (many large chunks), that array gets replicated for every task, pushing the scheduler out of memory.

That is why I tried to send the dataset to the cluster beforehand using scatter, but I cannot resolve the Future at the workers.

alessioarena avatar Oct 17 '22 12:10 alessioarena

new_data_future = xr.apply_ufunc(
    _copy_test,
    data, 
    a_x,
    ...
)

instead of using kwargs.
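
Spelled out against the original example, the full call might look roughly like this (a sketch only; the core dims for a_x are an assumption based on the snippet above):

new_data_scattered = xr.apply_ufunc(
    _copy_test,
    data,
    a_x,
    input_core_dims=[['z'], ['new_z']],  # one list of core dims per positional input
    output_core_dims=[['new_z']],
    vectorize=True,
    dask='parallelized',
    output_dtypes="float64",
    # output_sizes should no longer be needed: 'new_z' is already sized by a_x
)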

I've opened https://github.com/dask/distributed/issues/7140 to simplify this. With it implemented, my snippet

test = np.full((20,), 30)
a = da.from_array(test)
dsk = client.scatter(dict(a.dask), broadcast=True)
a = da.Array(dsk, name=a.name, chunks=a.chunks, dtype=a.dtype, meta=a._meta, shape=a.shape)
a_x = xarray.DataArray(a, dims=["new_z"])

would become

test = np.full((20,), 30)
a_x = xarray.DataArray(test, dims=["new_z"]).chunk()
a_x = client.scatter(a_x)

crusaderky avatar Oct 17 '22 12:10 crusaderky

I will try that. I still find it weird that I need to wrap a numpy object into a dask/xarray object to be able to send it to the workers, when dask.scatter is made for exactly that purpose.

Thanks for opening that issue. I do feel there is a need to revisit the functionality and role of scatter, particularly around dynamic clusters.

Having a better look at your initial comment, that may still work if you call the Future.result() method inside the applied function. That should, in theory, retrieve the data associated with that Future, in this case "Hello World". However, in a dask_gateway setup that will fail.
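
For example, something along these lines inside the applied function (a sketch; whether the worker can actually resolve the Future depends on the deployment, and under dask_gateway it fails as described above):

import distributed

def _copy_test(d, test=None):
    # if the kwarg arrives as a Future, try to resolve it on the worker
    if isinstance(test, distributed.Future):
        test = test.result()
    return test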

alessioarena avatar Oct 17 '22 12:10 alessioarena

Not sure there's anything actionable here

dcherian avatar Dec 19 '23 05:12 dcherian

I think this thread is related to my problem, but I'm not 100% sure.

I have a single xarray dataset (holding multiple dataarrays) which I want to load into worker memory across a dask cluster, and then I do a bunch of different computations on the same data.

I guess it's up to dask to work out how it wants to distribute the chunks across worker memory, but one scheme I can imagine is that each worker loads N_chunks / N_workers chunks for each dataarray in the dataset, e.g. if there are 5 dataarrays in the dataset and each dataarray has 20 chunks, then with 10 workers each worker would load 2 chunks from each dataarray into memory.

Is this, or something like it, what a simple ds.persist() achieves? When I do this I get a warning from dask:

/home/ec2-user/miniforge3/envs/py311/lib/python3.11/site-packages/distributed/client.py:3162: UserWarning: Sending large graph of size 39.75 MiB.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.

(Note: I have not done anything more than xr.open_mfdataset with chunking.)

The data loading seems pretty slow, so I'm wondering if I should be heeding this warning and using scatter...?
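
For context, the workflow is roughly this (a sketch; paths and chunk sizes are placeholders):

import xarray as xr
from dask.distributed import Client

client = Client()  # connection to the cluster

# open many files lazily as a single chunked dataset
ds = xr.open_mfdataset("data/*.nc", chunks={"time": 100})

# persist() loads the chunks into distributed worker memory,
# keeping the dataset lazy-looking but backed by in-memory data
ds = ds.persist()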

ollie-bell avatar Jan 09 '24 18:01 ollie-bell