Dask missing linalg functions from array API standard
Currently, dask is missing some functions in the linear algebra module, and the fft module is apparently noncompliant. I haven't double-checked this list, but a list of unavailable functions is here: https://data-apis.org/array-api-compat/supported-array-libraries.html#dask. There's already some work on implementing the missing functions here, but no issue yet (at least none that I can find).
What we can do at the array-api-compat level is rather limited. We could probably add implementations for `matrix_rank` and `pinv`, since both are rather thin layers on top of `svd`; we could maybe add `matrix_norm` as a wrapper of `dask.array.linalg.norm`; stretching it further, `det` and `slogdet` should be doable on top of `dask.array.linalg.lu`.
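A minimal sketch of the `svd`-based idea, assuming a single 2-D floating-point matrix chunked along one axis only (which `dask.array.linalg.svd` requires for its tall-and-skinny algorithm) and the array API default `rtol = max(M, N) * eps`. The function names mirror the spec; batched (stacked) inputs are not handled.

```python
import numpy as np
import dask.array as da


def _default_rtol(x):
    # Array API default: max(M, N) * machine epsilon of the input dtype.
    return max(x.shape[-2:]) * np.finfo(x.dtype).eps


def svdvals(x):
    # da.linalg.svd returns (u, s, v); only the singular values are kept.
    _, s, _ = da.linalg.svd(x)
    return s


def matrix_rank(x, rtol=None):
    rtol = _default_rtol(x) if rtol is None else rtol
    s = svdvals(x)
    # Count singular values above rtol * (largest singular value).
    return (s > rtol * s.max()).sum()


def pinv(x, rtol=None):
    rtol = _default_rtol(x) if rtol is None else rtol
    u, s, vh = da.linalg.svd(x)
    keep = s > rtol * s.max()
    # Invert only the singular values above the cutoff; the inner `where`
    # avoids dividing by zero for the discarded ones.
    s_inv = da.where(keep, 1 / da.where(keep, s, 1), 0)
    # pinv(x) = V @ diag(s_inv) @ U^H
    return (vh.conj().T * s_inv) @ u.conj().T
```

Everything stays lazy; for example `matrix_rank(da.from_array(a, chunks=(2, 3)))` builds a graph and only evaluates on `.compute()`.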
That said, it seems reasonable to start with dask.array itself. Ideally, these are added to dask in an Array API compatible form. When they are available in dask, and if for some reason they deviate from the spec, then we could happily paper over the differences here in array_api_compat.dask.array.
Would you be interested in working with the dask maintainers on these topics, @purepani?
I'll be happy to work with them if possible, but I will be busy writing a manuscript for the next few weeks, so I may not have time for actual implementation work in the immediate future. I have already opened an issue here, and I'm happy to do some light work and discussion until I'm able to be more active in July.
Dask can easily wrap any NumPy function with `da.map_blocks`. That, however, requires knowing exactly along which axes the function is not embarrassingly parallel, and rechunking to a single chunk along those axes beforehand.
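As a rough illustration of the `map_blocks` route, here is a sketch of `eigvalsh` for a stack of symmetric/Hermitian matrices. It assumes a floating-point input and that the last two axes are the only ones NumPy needs whole: only those two are rechunked to a single chunk, while leading batch axes keep their chunks, so each task handles one batch of complete matrices.

```python
import numpy as np
import dask.array as da


def eigvalsh(x):
    # Core axes (the last two) must each be a single chunk; -1 means
    # "one chunk spanning the whole axis". Batch axes keep their chunks.
    x = x.rechunk({x.ndim - 2: -1, x.ndim - 1: -1})
    return da.map_blocks(
        np.linalg.eigvalsh,
        x,
        drop_axis=x.ndim - 1,          # (..., n, n) -> (..., n)
        dtype=np.finfo(x.dtype).dtype,  # eigenvalues are real, even for complex input
    )
```

The same template (rechunk the core axes, then `map_blocks`) applies to other missing functions once their core axes are known; only `drop_axis`/`new_axis` and the output dtype change.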
array-api-compat is already doing exactly that with sort and argsort; it would be straightforward to extend it to everything else that's missing.
SciPy goes one step further for Cython functions that lack documentation about which axes they reduce along: for such functions it blindly rechunks the array to a single chunk along all axes before handing the function to Dask. This is currently done in `scipy.cluster.hierarchy`, encapsulated in `xpx.lazy_apply` calls.
This provides functional support, but at the cost of slowness and `KilledWorker` failures if your data doesn't fit on a single worker. It was deemed a tolerable crutch until Dask implements native support.
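The SciPy-style fallback can be sketched with a hypothetical helper, `apply_on_single_chunk` (this is not the `xpx.lazy_apply` API, whose signature is not reproduced here). It rechunks to one chunk along every axis and runs an opaque NumPy function on that single block through `dask.delayed`. The caller has to state the output shape and dtype, because Dask cannot infer them from a black box; `np.linalg.matrix_power` merely stands in for an undocumented Cython routine.

```python
import dask
import dask.array as da
import numpy as np


def apply_on_single_chunk(func, x, *, shape, dtype):
    # Hypothetical helper: one chunk along *all* axes, so the whole input
    # must fit in a single worker's memory when the graph is computed.
    (block,) = x.rechunk(-1).to_delayed().ravel()  # the only remaining chunk
    out = dask.delayed(func)(block)                 # still lazy
    return da.from_delayed(out, shape=shape, dtype=dtype)


# Usage: an "opaque" function applied to a 3x3 array split into 9 chunks.
a = np.arange(9.0).reshape(3, 3)
x = da.from_array(a, chunks=1)
y = apply_on_single_chunk(lambda m: np.linalg.matrix_power(m, 3), x,
                          shape=(3, 3), dtype=a.dtype)
```

This trades all parallelism for correctness, which is why it is only a stopgap.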
Is it reasonable to use the pattern in `sort` for everything in the linalg extension, if I'm interpreting that correctly?
Comparing what Dask implements natively (https://docs.dask.org/en/latest/array-api.html) with the Array API linear algebra extension (https://data-apis.org/array-api/2024.12/extensions/linear_algebra_functions.html), here are the missing functions:
| function | assessment |
|---|---|
| cross(x1, x2, /, *, axis=-1) | |
| det(x, /) | base off linalg.lu() |
| diagonal(x, /, *, offset=0) | dask.array.diagonal |
| eigh(x, /) | |
| eigvalsh(x, /) | |
| matmul(x1, x2, /) | dask.array.matmul |
| matrix_norm | base off linalg.norm |
| matrix_power(x, n, /) | |
| matrix_rank | base off linalg.svd |
| matrix_transpose(x, /) | base off dask.array.transpose |
| outer(x1, x2, /) | dask.array.outer |
| pinv(x, /, *, rtol=None) | base off linalg.svd |
| slogdet(x, /) | base off linalg.lu |
| svdvals(x, /) | base off linalg.svd |
| tensordot | dask.array.tensordot |
| trace | dask.array.trace |
| vecdot(x1, x2, /, *, axis=-1) | |
| vector_norm | base off linalg.norm |
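For several rows the "bridge" is mostly about axis conventions. A sketch using only existing `dask.array` functions, assuming the array API semantics of operating on the last two axes (or on `axis` for `vecdot`); broadcasting checks and dtype promotion rules are not enforced here.

```python
import dask.array as da


def matrix_transpose(x):
    # Swap the last two axes and leave any leading (batch) axes in place.
    if x.ndim < 2:
        raise ValueError("matrix_transpose requires an array with ndim >= 2")
    axes = tuple(range(x.ndim - 2)) + (x.ndim - 1, x.ndim - 2)
    return da.transpose(x, axes=axes)


def trace(x, offset=0):
    # Array API linalg.trace sums diagonals over the last two axes,
    # whereas dask.array.trace mirrors NumPy (axis1=0, axis2=1 by default).
    return da.trace(x, offset=offset, axis1=x.ndim - 2, axis2=x.ndim - 1)


def vecdot(x1, x2, axis=-1):
    # Array API definition: sum over `axis` of conj(x1) * x2.
    return da.sum(da.conj(x1) * x2, axis=axis)
```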
There seem to be three kinds of functions:
1. those with a direct replacement in the main namespace (`tensordot`, `matmul`, `outer`, `trace`): these just need an alias, potentially with a small bridge for array API compatibility;
2. those which can be implemented using other functions from `dask.array.linalg` (`matrix_rank`, `slogdet`);
3. those without native Dask support (eigenvalue solvers).
Of these,
1. Can incubate in array-api-compat first, and get submitted to Dask once incubated.
2. Should preferably start in dask itself, and move to array-api-compat.dask if hard pressed.
3. Should definitely live in dask itself.
My reasoning for 2. and 3. is that Dask devs are way better equipped for maintaining dask-specific details, while here we have essentially a bus factor of 1 on @crusaderky. For things like eigenvalue solvers, we definitely should not add something which silently materializes the array. I've no idea if there's a way to do a properly distributed eigenvalue solver in a generic way. If dask devs know how or are happy with a partial solution, great, we'll follow of course.
> Is it reasonable to use the pattern in `sort` for everything in the linalg extension, if I'm interpreting that correctly?
Yes, as a temporary crutch.
> For things like eigenvalue solvers, we definitely should not add something which silently materializes the array.
The crutch I proposed with `da.map_blocks` doesn't silently materialize the array. It silently rechunks it.
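To make the distinction concrete: `rechunk` only rewrites the task graph, so nothing is computed, but it changes how large the biggest chunk becomes, and that is what a single worker must hold once the graph runs. A small sketch with illustrative sizes, comparing rechunking only the reduced axis against rechunking every axis:

```python
import math

import dask.array as da


def largest_block_nbytes(a):
    # Memory needed by the biggest single chunk, i.e. what one task must hold.
    return math.prod(max(c) for c in a.chunks) * a.dtype.itemsize


x = da.ones((1000, 1000), chunks=(100, 100))
core_only = x.rechunk({1: -1})  # single chunk along the reduced (last) axis only
whole = x.rechunk(-1)           # single chunk along every axis

# Nothing has been computed yet: all three are still lazy dask arrays.
for name, a in [("original", x), ("core axis", core_only), ("all axes", whole)]:
    print(name, a.numblocks, largest_block_nbytes(a))
# original (10, 10) 80000
# core axis (10, 1) 800000
# all axes (1, 1) 8000000
```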
In that case, I can quickly implement the `map_blocks` crutch for all the missing functions in categories 2 and 3 of @ev-br's list later today, and then work with dask to gradually replace them.
I'll also implement the simple ones in category 1. I have a paper to write, so it'll be quite late, but I'll try to get to it today.