skglm
POC - Jax implementation of AndersonCD solver
Follow-up of #149
This PR implements the AndersonCD solver with JAX on GPU. It proceeds as follows:
- [x] CD solver using Jax
- [x] Working sets
- [x] Anderson acceleration
- [ ] use autodiff
- [x] benchmarks against CPU AndersonCD
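For readers less familiar with the Anderson step, here is a minimal JAX sketch of one extrapolation. It assumes the solver keeps the K + 1 most recent iterates in a fixed-size buffer, so `K` is static and the function compiles once. `anderson_extrapolate` is a hypothetical name, not skglm's API. Solving the constrained least-squares problem through its KKT system is a choice made for this sketch, not necessarily how this PR or skglm does it.

```python
import jax
import jax.numpy as jnp


@jax.jit
def anderson_extrapolate(iterates):
    """Anderson extrapolation from the last K + 1 iterates (rows of `iterates`).

    The weights c solve  min_c ||U c||^2  s.t.  sum(c) = 1,  where the K
    columns of U are successive differences of the iterates. Solving the KKT
    system keeps the problem well posed even when U^T U is singular.
    """
    U = jnp.diff(iterates, axis=0).T  # shape (n_features, K)
    K = U.shape[1]  # static: a fixed buffer size means a single compilation
    dtype = iterates.dtype
    ones = jnp.ones((K, 1), dtype=dtype)
    # KKT matrix [[U^T U, 1], [1^T, 0]] of the equality-constrained problem.
    kkt = jnp.block([[U.T @ U, ones],
                     [ones.T, jnp.zeros((1, 1), dtype=dtype)]])
    rhs = jnp.zeros(K + 1, dtype=dtype).at[K].set(1.0)
    c = jnp.linalg.solve(kkt, rhs)[:K]  # drop the Lagrange multiplier
    return c @ iterates[1:]  # affine combination of the K latest iterates


# Usage: linear fixed-point iteration x <- T x + b, whose fixed point is (10, 2).
T = jnp.array([0.9, 0.5])  # diagonal of T
b = jnp.ones(2)
xs = [jnp.zeros(2)]
for _ in range(3):
    xs.append(T * xs[-1] + b)
x_extra = anderson_extrapolate(jnp.stack(xs))  # uses K = 3 differences
```

On this 2-dimensional linear iteration, three differences are enough for the extrapolation to land on the fixed point, while the last plain iterate is still at (2.71, 1.75).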
JAX triggers a new JIT compilation of a function whenever the shapes of its arguments change. I opened an issue on google/jax, and it turns out this is inherent to how the XLA compiler works.
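The behaviour is easy to observe: Python side effects inside a `jax.jit`-decorated function run only while JAX traces it, so a counter reveals each new compilation. `sum_of_squares` and `n_traces` are illustrative names.

```python
import jax
import jax.numpy as jnp

n_traces = 0  # incremented only while JAX traces (and then compiles) the function


@jax.jit
def sum_of_squares(x):
    global n_traces
    n_traces += 1  # Python side effect: runs at trace time, not on cached calls
    return jnp.sum(x ** 2)


sum_of_squares(jnp.ones(3))   # first call with shape (3,): trace + compile
sum_of_squares(jnp.zeros(3))  # same shape and dtype: cached executable reused
sum_of_squares(jnp.ones(5))   # new shape (5,): traced and compiled again
print(n_traces)  # 2
```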
This is a limiting factor for the current design: the most costly functions, `gradient`/`subdiff_dist` and `cd_epoch`, take inputs, namely `grad_ws` and `ws`, whose shapes change along the iterations. As a result, most of the time is wasted recompiling functions.
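To make the shape dependence concrete, here is a hypothetical Lasso `cd_epoch` restricted to a working set `ws`. It assumes a quadratic datafit with an L1 penalty and keeps `Xw` in sync with `w`; the signature is an assumption for illustration, not the one used in this PR.

```python
import jax
import jax.numpy as jnp
from jax import lax


def soft_threshold(x, tau):
    """Proximal operator of tau * |.|."""
    return jnp.sign(x) * jnp.maximum(jnp.abs(x) - tau, 0.0)


@jax.jit
def cd_epoch(X, y, w, Xw, alpha, ws):
    """One cyclic CD pass for the Lasso, restricted to the features in `ws`.

    Objective: ||y - X w||^2 / (2 n_samples) + alpha * ||w||_1.
    """
    n_samples = X.shape[0]
    lipschitz = jnp.sum(X ** 2, axis=0) / n_samples  # L_j = ||X_j||^2 / n

    def update(i, carry):
        w, Xw = carry
        j = ws[i]
        grad_j = X[:, j] @ (Xw - y) / n_samples
        new_wj = soft_threshold(w[j] - grad_j / lipschitz[j], alpha / lipschitz[j])
        Xw = Xw + (new_wj - w[j]) * X[:, j]  # keep Xw in sync with w
        return w.at[j].set(new_wj), Xw

    # `ws` is an argument of a jitted function: each new working-set size is a
    # new input shape, hence a new trace and XLA compilation.
    return lax.fori_loop(0, ws.shape[0], update, (w, Xw))


key_X, key_y = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_X, (20, 5))
y = jax.random.normal(key_y, (20,))
w, Xw = cd_epoch(X, y, jnp.zeros(5), jnp.zeros(20), 0.1, jnp.array([0, 2, 4]))
```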
To bypass this, I'm considering tweaking the design to freeze the arrays' shapes across iterations, hence avoiding recompilation. I'm open to other suggestions.
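One possible way to freeze shapes, sketched under the assumption that the working set never exceeds `n_features`: pad `ws` to a fixed length and pass the number of active entries as a traced scalar, so `lax.fori_loop` runs a run-time number of iterations over a fixed-shape array. `pad_ws` and `cd_epoch_padded` are hypothetical names, and the Lasso update is only an example datafit/penalty.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax

n_traces = 0  # incremented each time cd_epoch_padded is traced/compiled


def soft_threshold(x, tau):
    return jnp.sign(x) * jnp.maximum(jnp.abs(x) - tau, 0.0)


def pad_ws(ws, size):
    """Hypothetical helper: pad working-set indices to a fixed length `size`."""
    padded = np.zeros(size, dtype=np.int32)
    padded[: len(ws)] = ws
    return padded, len(ws)


@jax.jit
def cd_epoch_padded(X, y, w, Xw, alpha, ws_padded, n_active):
    """Lasso CD pass over ws_padded[:n_active]; all input shapes are fixed."""
    global n_traces
    n_traces += 1  # runs only while tracing
    n_samples = X.shape[0]
    lipschitz = jnp.sum(X ** 2, axis=0) / n_samples

    def update(i, carry):
        w, Xw = carry
        j = ws_padded[i]
        grad_j = X[:, j] @ (Xw - y) / n_samples
        new_wj = soft_threshold(w[j] - grad_j / lipschitz[j], alpha / lipschitz[j])
        Xw = Xw + (new_wj - w[j]) * X[:, j]
        return w.at[j].set(new_wj), Xw

    # n_active is traced, so fori_loop lowers to a while_loop whose trip count
    # is known only at run time; the padded tail is never visited.
    return lax.fori_loop(0, n_active, update, (w, Xw))


key_X, key_y = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_X, (20, 5))
y = jax.random.normal(key_y, (20,))
w, Xw = jnp.zeros(5), jnp.zeros(20)
for ws in ([0], [0, 2], [0, 2, 4]):  # growing working sets
    ws_padded, n_active = pad_ws(ws, X.shape[1])
    w, Xw = cd_epoch_padded(X, y, w, Xw, 0.1, ws_padded, n_active)
print(n_traces)  # 1
```

A traced trip count makes `fori_loop` a `while_loop`, which JAX cannot differentiate in reverse mode; that trade-off may matter for the autodiff item still open in the checklist.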