
Dask reshape bug for arrays with fully chunked leading axes

Open • ravwojdyla opened this issue 4 years ago • 1 comment

Long story short, the problematic line in our code is https://github.com/pystatgen/sgkit/blob/41827f3fd116d59ab4dc8b119a15ad5f3be730b9/sgkit/stats/regenie.py#L364, and https://github.com/dask/dask/pull/6748 is a special-case optimisation, described as:

When the slow-moving (early) axes in .reshape are all size 1

Our YP[i] happens to fall into that category, and as far as I understand, https://github.com/dask/dask/pull/6748 introduced a bug. It might be hard to see in the regenie code, so here's a distilled reproduction:

> # c4038add1a087ba3a82207a557cdcad9753b689d is the commit from https://github.com/dask/dask/pull/6748
> dask git:(c4038add) git checkout c4038add1a087ba3a82207a557cdcad9753b689d
HEAD is now at c4038add Avoid rechunking in reshape with chunksize=1 (#6748)
> dask git:(c4038add) python
>>> import numpy as np
>>> import dask.array as da
>>> a = da.from_array(np.arange(24).reshape(2, 3, 4), chunks=((1, 1), (1, 1, 1), (4,)))
>>> a.reshape(6,4)
dask.array<reshape, shape=(6, 4), dtype=int64, chunksize=(1, 4), chunktype=numpy.ndarray>
>>> # merging dimensions at the front works fine; now let's try merging the last two (which is our use case)
>>> a.reshape(2,12)
dask.array<reshape, shape=(2, 24), dtype=int64, chunksize=(1, 4), chunktype=numpy.ndarray>
>>> # NOTICE: the shape is (2, 24), NOT (2, 12)! Now let's try to compute this:
>>> a.reshape(2,12).compute()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/rav/projects/dask/dask/base.py", line 167, in compute
    (result,) = compute(self, traverse=False, **kwargs)
  File "/Users/rav/projects/dask/dask/base.py", line 454, in compute
    results = schedule(dsk, keys, **kwargs)
  File "/Users/rav/projects/dask/dask/threaded.py", line 76, in get
    results = get_async(
  File "/Users/rav/projects/dask/dask/local.py", line 503, in get_async
    return nested_get(result, state["cache"])
  File "/Users/rav/projects/dask/dask/local.py", line 299, in nested_get
    return tuple([nested_get(i, coll) for i in ind])
  File "/Users/rav/projects/dask/dask/local.py", line 299, in <listcomp>
    return tuple([nested_get(i, coll) for i in ind])
  File "/Users/rav/projects/dask/dask/local.py", line 299, in nested_get
    return tuple([nested_get(i, coll) for i in ind])
  File "/Users/rav/projects/dask/dask/local.py", line 299, in <listcomp>
    return tuple([nested_get(i, coll) for i in ind])
  File "/Users/rav/projects/dask/dask/local.py", line 299, in nested_get
    return tuple([nested_get(i, coll) for i in ind])
  File "/Users/rav/projects/dask/dask/local.py", line 299, in <listcomp>
    return tuple([nested_get(i, coll) for i in ind])
  File "/Users/rav/projects/dask/dask/local.py", line 301, in nested_get
    return coll[ind]
KeyError: ('reshape-aa29de8b0f6be5be25495836ed047c4a', 1, 0)

Notice the invalid shape after a.reshape(2,12).
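When the early axes of a merged group contain only size-1 chunks, every input block maps onto one contiguous run of the merged output axis, which is presumably why the optimisation can skip rechunking. The sketch below is a hypothetical illustration of that chunk arithmetic, not dask's implementation; the name `merged_axis_chunks` is made up. For every chunk layout used in this issue, the merged axis adds up to 3 * 4 = 12, so a correct `a.reshape(2, 12)` must report a second dimension of 12.

```python
import math


def merged_axis_chunks(group_chunks):
    """Hypothetical sketch (not dask's code) of the chunk arithmetic.

    group_chunks: per-axis chunk tuples for the consecutive input axes that
    reshape merges into one output axis, slowest-moving axis first.
    Assumes every axis except the last has only size-1 chunks.
    """
    *early, last = group_chunks
    if not all(c == 1 for axis in early for c in axis):
        raise ValueError("early axes are not all size-1 chunks")
    # Each combination of early-axis indices selects one contiguous run of
    # the merged axis, split exactly like the last (fastest-moving) axis.
    n_runs = math.prod(len(axis) for axis in early)
    return tuple(last) * n_runs


# a.reshape(2, 12) merges input axes 1 and 2 of shape (2, 3, 4)
print(merged_axis_chunks([(1, 1, 1), (4,)]))         # (4, 4, 4)
print(sum(merged_axis_chunks([(1, 1, 1), (2, 2)])))  # 12
# a.reshape(6, 4) merges input axes 0 and 1 instead
print(merged_axis_chunks([(1, 1), (1, 1, 1)]))       # (1, 1, 1, 1, 1, 1)
```

The last example is consistent with the chunksize of (1, 4) reported for `a.reshape(6, 4)` on the PR commit.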

The same code works fine without https://github.com/dask/dask/pull/6748:

> dask git:(c4038add) git checkout HEAD~1
Previous HEAD position was c4038add Avoid rechunking in reshape with chunksize=1 (#6748)
HEAD is now at 94bdd4e3 Try to make categoricals work on join (#6205)
> dask git:(94bdd4e3) python
>>> import numpy as np
>>> import dask.array as da
>>> a = da.from_array(np.arange(24).reshape(2, 3, 4), chunks=((1, 1), (1, 1, 1), (4,)))
>>> a.reshape(6,4)
dask.array<reshape, shape=(6, 4), dtype=int64, chunksize=(3, 4), chunktype=numpy.ndarray>
>>> a.reshape(2,12)
dask.array<reshape, shape=(2, 12), dtype=int64, chunksize=(1, 4), chunktype=numpy.ndarray>
>>> a.reshape(2,12).compute()
array([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
       [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]])

This issue seems to depend on the chunking of the array a. In the case above (and in our regenie case), the leading axes are completely chunked, i.e. every chunk along them has size 1. See:

>>> a = da.from_array(np.arange(24).reshape(2, 3, 4), chunks=((1, 1), (1, 1, 1), (4,)))
>>> a.reshape(2,12)
dask.array<reshape, shape=(2, 24), dtype=int64, chunksize=(1, 4), chunktype=numpy.ndarray>
>>> # BAD

>>> a = da.from_array(np.arange(24).reshape(2, 3, 4), chunks=((1, 1), (1, 1, 1), (2,2)))
>>> a.reshape(2,12)
dask.array<reshape, shape=(2, 24), dtype=int64, chunksize=(1, 2), chunktype=numpy.ndarray>
>>> # BAD

>>> a = da.from_array(np.arange(24).reshape(2, 3, 4), chunks=((1, 1), (1, 1, 1), (1, 1, 1, 1)))
>>> a.reshape(2,12)
dask.array<reshape, shape=(2, 24), dtype=int64, chunksize=(1, 1), chunktype=numpy.ndarray>
>>> # BAD

>>> a = da.from_array(np.arange(24).reshape(2, 3, 4), chunks=((1, 1), (1, 2), (2, 2)))
>>> a.reshape(2,12)
dask.array<reshape, shape=(2, 12), dtype=int64, chunksize=(1, 4), chunktype=numpy.ndarray>
>>> # OK
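The pattern can be checked mechanically. The helper below is hypothetical (not part of dask or sgkit): it takes chunks in the same per-axis format as the `chunks=` argument above and reports whether every axis of the merged group except the last consists only of size-1 chunks, which is one reading of the PR's "slow-moving (early) axes are all size 1" condition. For the (2, 3, 4) to (2, 12) reshape, it returns True for exactly the three BAD layouts.

```python
def early_axes_all_size_one(chunks, merged_axes):
    """Hypothetical helper, not dask API.

    chunks: tuple of per-axis chunk tuples, e.g. ((1, 1), (1, 1, 1), (4,)).
    merged_axes: consecutive input axes that reshape collapses into one
    output axis, e.g. (1, 2) for (2, 3, 4) -> (2, 12).
    """
    # Only the axes before the last (fastest-moving) one are inspected.
    return all(all(c == 1 for c in chunks[ax]) for ax in merged_axes[:-1])


layouts = {
    "BAD 1": ((1, 1), (1, 1, 1), (4,)),
    "BAD 2": ((1, 1), (1, 1, 1), (2, 2)),
    "BAD 3": ((1, 1), (1, 1, 1), (1, 1, 1, 1)),
    "OK": ((1, 1), (1, 2), (2, 2)),
}

for name, chunks in layouts.items():
    # (2, 3, 4) -> (2, 12) merges input axes 1 and 2
    print(name, early_axes_all_size_one(chunks, (1, 2)))
```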

Btw, this is a good example of how valuable Eric's asserts are: this code is already pretty hard to debug, so imagine if it had just failed at compute() with a cryptic KeyError: ('reshape-aa29de8b0f6be5be25495836ed047c4a', 1, 0). So a big +1 to https://github.com/pystatgen/sgkit/issues/267
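A minimal sketch of such a fail-fast check, assuming only that the array exposes `.shape` and `.reshape` (NumPy and dask arrays both do). `checked_reshape` is a hypothetical name, not sgkit's or dask's API, and it is not the actual set of asserts discussed in #267. For a dask array, `.shape` comes from graph metadata, so the bad (2, 24) shape above would be reported at reshape time instead of as a KeyError during compute().

```python
import numpy as np


def checked_reshape(x, *shape):
    """Reshape x and fail fast if the declared result shape is wrong.

    Hypothetical helper. Accepts checked_reshape(x, 2, 12) or
    checked_reshape(x, (2, 12)); -1 placeholders are not supported.
    """
    if len(shape) == 1 and isinstance(shape[0], (tuple, list)):
        shape = tuple(shape[0])
    result = x.reshape(shape)
    # For dask arrays .shape is graph metadata, so this check
    # runs before any compute().
    assert result.shape == tuple(shape), (
        f"reshape to {tuple(shape)} produced shape {result.shape}"
    )
    return result


a = np.arange(24).reshape(2, 3, 4)
print(checked_reshape(a, 2, 12).shape)   # (2, 12)
print(checked_reshape(a, (6, 4)).shape)  # (6, 4)
```

Python drops `assert` statements when run with `-O`; raising an exception explicitly would make the check unconditional.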

Originally posted by @ravwojdyla in https://github.com/pystatgen/sgkit/issues/430#issuecomment-772868301

ravwojdyla, Feb 04 '21 17:02

This has now been fixed in Dask, so it should be possible to check whether the fix resolves the issue here.

tomwhite, Apr 06 '21 13:04