
Support for CSR-format sparse matrices in optimizers?

Open · MoFHeka opened this issue 1 year ago · 1 comment

Matrix format: https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.BCSR.html#jax.experimental.sparse.BCSR

Related feature in DeepSpeed: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/ops/csrc/adam/fused_adam_frontend.cpp

MoFHeka, Jun 28 '24 09:06

I believe that if the matrix format is a pytree (which I think it is), things should work out of the box.

It would be great if you could check whether things work out of the box or not.

It would be even more awesome if you could contribute an example using these sparse matrix formats 😉

fabianp, Jun 28 '24 11:06