POC - Jax implementation of AndersonCD solver

Open Badr-MOUFAD opened this issue 1 year ago • 1 comments

Follow up of #149

This implements the AndersonCD solver using Jax on GPU. It proceeds as follows:

  • [x] CD solver using Jax
  • [x] Working sets
  • [x] Anderson acceleration
  • [ ] use autodiff
  • [x] benchmarks against CPU AndersonCD

Badr-MOUFAD avatar Apr 27 '23 08:04 Badr-MOUFAD

Jax triggers another jit-compilation of a function whenever its arguments change shape. I opened an issue on google/jax, and it turns out to be inherent to how the XLA compiler works.
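A minimal sketch of this behavior, not taken from the solver itself: the function `sum_of_squares` and the counter `trace_count` are hypothetical names used only for illustration. The Python side effect inside the jitted function runs only while Jax traces it, so the counter records how many times the function is traced (and then compiled).

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, incremented only during tracing


@jax.jit
def sum_of_squares(x):
    global trace_count
    # Python side effects run at trace time, not on cached calls.
    trace_count += 1
    return jnp.sum(x ** 2)


# Five calls, but only three distinct input shapes: (3,), (5,), (8,).
for size in (3, 3, 5, 5, 8):
    sum_of_squares(jnp.ones(size))

print(trace_count)  # 3: one trace per distinct shape
```

Calls that reuse a shape hit the compilation cache, while every new shape pays the tracing and compilation cost again.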

This is a limiting factor for the current design, because the expensive functions, gradient/subdiff_dist and cd_epoch, take inputs, namely grad_ws and ws, whose shapes change across iterations. As a result, most of the runtime is spent recompiling these functions.

To bypass that, I'm thinking of tweaking the design to freeze the arrays' shapes across iterations and hence avoid the recompilation. I'm open to other suggestions.
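One possible way to freeze the shapes, sketched here under assumptions and not necessarily the design that will be adopted: pad the working set to a fixed capacity and carry a boolean mask marking the valid entries. The names `grad_on_ws`, `pad_ws`, `capacity`, and the least-squares style gradient `X_ws.T @ residual / n_samples` are all hypothetical and chosen only for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np

n_traces = 0  # hypothetical counter, incremented only during tracing


@jax.jit
def grad_on_ws(X, residual, ws_padded, ws_mask):
    global n_traces
    n_traces += 1
    # Gather the working-set columns; the result always has `capacity` columns.
    X_ws = X[:, ws_padded]
    grad = X_ws.T @ residual / X.shape[0]
    # Zero out the padded slots so they do not affect later computations.
    return jnp.where(ws_mask, grad, 0.0)


def pad_ws(ws, capacity):
    """Pad the index array `ws` with zeros up to `capacity` and build its mask."""
    ws = np.asarray(ws)
    padded = np.zeros(capacity, dtype=np.int32)
    padded[: len(ws)] = ws
    mask = np.arange(capacity) < len(ws)
    return jnp.asarray(padded), jnp.asarray(mask)


rng = np.random.default_rng(0)
X = jnp.asarray(rng.standard_normal((20, 10)), dtype=jnp.float32)
residual = jnp.asarray(rng.standard_normal(20), dtype=jnp.float32)

capacity = 8  # fixed upper bound on the working-set size (assumption)
for ws in ([0, 3], [0, 3, 7], [1, 2, 4, 5, 9]):
    grad = grad_on_ws(X, residual, *pad_ws(ws, capacity))

print(n_traces)  # 1: the working set grew, but the array shapes did not
```

The trade-off is some wasted computation on the padded slots. If the working set can outgrow the capacity, growing the capacity geometrically (for example by doubling) would keep the number of recompilations small.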

Badr-MOUFAD avatar Apr 27 '23 09:04 Badr-MOUFAD