Sparse+dense mixed arrays
Previously `admm` would rechunk along the columns into a single chunk and then pass delayed numpy arrays to the `local_update` function. If the chunks along the columns were of different types, such as a numpy array and a sparse array, they would be inefficiently coerced to a single type.
Now we pass a list of numpy arrays to the `local_update` function. If this list has more than one element, we construct a local dask.array so that operations like `dot` do the right thing and call two different local dot functions, one for each type.
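Roughly, the local array looks something like the following sketch (the shapes and the use of the `sparse` package are just for illustration, and holding non-numpy chunks in a dask.array relies on the dask change linked below):

```python
import numpy as np
import dask.array as da
import sparse  # illustrative choice of sparse array type

# Two column blocks of the same design matrix: one dense, one sparse.
dense_block = np.random.random((10, 3))
sparse_block = sparse.random((10, 4), density=0.1)

# Wrap each in-memory block as a single-chunk dask array and concatenate
# along the columns, so each type stays in its own chunk.
X = da.concatenate(
    [da.from_array(dense_block, chunks=dense_block.shape),
     da.from_array(sparse_block, chunks=sparse_block.shape)],
    axis=1)

beta = np.random.random(7)

# dot dispatches per chunk: a dense dot for the numpy block and a sparse
# dot for the sparse block, with the partial results added together.
Xbeta = X.dot(beta).compute()
```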
This currently depends on https://github.com/dask/dask/pull/2272, though I may be able to avoid this dependency.
There is a non-trivial cost to using dask.array within the function given to `scipy.optimize`; in particular, graph generation costs appear to be non-trivial. I can reduce these somewhat, but I'm also curious whether it is possible to apply `f` and `fprime` individually to all chunks of the input to `local_update`. In this case each chunk corresponds to a block of columns. Looking at the methods in `families.py`, it looks like it might be possible to evaluate `f` on each block and add the results together, and to evaluate `fprime` on each block and concatenate the results.
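Concretely, the decomposition I have in mind looks something like the sketch below (the `f(beta, X, y)` call signature and the slicing of `beta` by column block are assumptions for illustration, not the actual code):

```python
import numpy as np

def blockwise_f(f, beta, X_blocks, y):
    # Hypothetical: evaluate f on each column block with the matching
    # slice of beta, then add the partial values together.
    bounds = np.cumsum([0] + [Xi.shape[1] for Xi in X_blocks])
    return sum(f(beta[i:j], Xi, y)
               for Xi, i, j in zip(X_blocks, bounds, bounds[1:]))

def blockwise_fprime(fprime, beta, X_blocks, y):
    # Hypothetical: evaluate fprime on each column block, then concatenate
    # the partial gradients back into a single vector.
    bounds = np.cumsum([0] + [Xi.shape[1] for Xi in X_blocks])
    return np.concatenate([fprime(beta[i:j], Xi, y)
                           for Xi, i, j in zip(X_blocks, bounds, bounds[1:])])
```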
@moody-marlin is this generally possible?
(please let me know if this description was not clear)
Hmmm, I'm not sure how possible this is; it definitely won't be as easy as evaluating `f` and `fprime` on each block and combining them, though. For example, with `X1` and `X2` as column blocks and `beta1` and `beta2` the matching pieces of `beta`, `((y - X1.dot(beta1) - X2.dot(beta2)) ** 2).sum()` doesn't split as a simple sum over the chunks.
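A quick numerical check makes the point (the shapes and random data here are just for illustration):

```python
import numpy as np

rng = np.random.RandomState(0)
y = rng.random_sample(10)
X1, X2 = rng.random_sample((10, 3)), rng.random_sample((10, 4))
beta1, beta2 = rng.random_sample(3), rng.random_sample(4)

# Loss evaluated on the full design matrix (both column blocks at once).
full = ((y - X1.dot(beta1) - X2.dot(beta2)) ** 2).sum()

# Naive per-block losses: the cross term between X1.dot(beta1) and
# X2.dot(beta2), plus the duplicated y**2 term, mean this does not
# equal the full loss.
per_block = ((y - X1.dot(beta1)) ** 2).sum() + ((y - X2.dot(beta2)) ** 2).sum()

print(full, per_block)  # two different numbers
```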
There might be fancier ways of combining the results, or possibly even altering ADMM to take this into account, but it will require some non-trivial thinking.
> There might be fancier ways of combining the results
The fancy way here is already handled by dask.array; I was just hoping to avoid having to recreate graphs every time. I can probably be clever here, though. I'll give it a shot.
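For context, the overhead I want to avoid looks roughly like the sketch below: building a dask.array inside the objective means every call from the optimizer constructs and schedules a fresh graph. The plain least-squares objective and the `fmin_l_bfgs_b` call are stand-ins for illustration, not the actual `local_update` code.

```python
import numpy as np
import dask.array as da
from scipy.optimize import fmin_l_bfgs_b

X_blocks = [np.random.random((100, 3)), np.random.random((100, 4))]
y = np.random.random(100)

def make_X():
    # Each call rebuilds the dask graph for the concatenated column blocks.
    return da.concatenate(
        [da.from_array(b, chunks=b.shape) for b in X_blocks], axis=1)

def objective(beta):
    # Called on every optimizer iteration: graph construction happens anew.
    Xbeta = make_X().dot(beta).compute()
    return ((y - Xbeta) ** 2).sum()

def gradient(beta):
    X = make_X()
    r = y - X.dot(beta).compute()
    return -2 * X.T.dot(r).compute()

beta0 = np.zeros(7)
beta, loss, info = fmin_l_bfgs_b(objective, beta0, fprime=gradient)
```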
OK, I've pushed a solution that, I think, avoids most graph construction costs. However, my algorithm is failing to converge. @moody-marlin if you find yourself with a free 15 minutes, can you take a look at `local_update()` and see if I'm doing anything wonky? I suspect that there is something dumb going on. When I print out the result and gradients I find that it doesn't seem to converge (the gradients stay large).