optax
Support for CSR-format sparse matrices in optimizers?
Matrix format: https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.BCSR.html#jax.experimental.sparse.BCSR
A related feature in DeepSpeed: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/ops/csrc/adam/fused_adam_frontend.cpp
I believe that if the matrix format is a pytree (which I think it is), then things should work out of the box?
It would be great if you could check whether things actually work out of the box.
It would be even more awesome if you could contribute an example of using these sparse matrix formats 😉