cvxpylayers

Work on adding jax.jit support

Status: Open. PTNobel opened this issue 1 month ago; 5 comments.

@ZedongPeng I couldn't get the backward pass working with MPAX:

  1. The MPAX solver runs its iterations inside jax.lax.while_loop with a dynamic termination condition
  2. JAX cannot differentiate through while_loop in reverse mode (this is a JAX limitation, not a bug)
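A minimal, MPAX-independent reproduction of point 2, assuming a toy Newton iteration for sqrt(a) (the function sqrt_while and its tolerance are illustrative, not MPAX code): forward mode (jax.jvp) passes through a tolerance-terminated jax.lax.while_loop, while reverse mode (jax.grad) raises a ValueError.

```python
import jax
import jax.numpy as jnp
from jax import lax


def sqrt_while(a, tol=1e-6):
    """Newton's method for sqrt(a), stopping once successive iterates agree to within tol."""

    def cond(state):
        x, x_prev = state
        return jnp.abs(x - x_prev) > tol

    def body(state):
        x, _ = state
        return 0.5 * (x + a / x), x

    # State is (current iterate, previous iterate); a + 1.0 makes cond True on entry.
    x, _ = lax.while_loop(cond, body, (a, a + 1.0))
    return x


a = jnp.float32(4.0)

# Forward mode works through while_loop.
value, tangent = jax.jvp(sqrt_while, (a,), (jnp.float32(1.0),))
print(value, tangent)  # approximately 2.0 and 0.25

# Reverse mode does not: JAX cannot transpose a dynamically terminated while_loop.
try:
    jax.grad(sqrt_while)(a)
except ValueError as err:
    print("reverse mode failed:", type(err).__name__)
```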

How are we supposed to differentiate through MPAX?

PTNobel, Dec 11 '25 06:12

Hi @PTNobel. To obtain derivatives through the unrolled iterations, you need to set unroll=True. When jit=True, MPAX uses jax.lax.scan under the hood.
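A sketch of why a fixed-length jax.lax.scan avoids the problem, again using a toy Newton iteration rather than MPAX itself (sqrt_scan and num_iters=20 are illustrative assumptions): because the iteration count is static, jax.grad can differentiate through every step.

```python
import jax
import jax.numpy as jnp
from jax import lax


def sqrt_scan(a, num_iters=20):
    """Newton's method for sqrt(a) with a fixed, static number of iterations."""

    def body(x, _):
        # The carry is the current iterate; no per-step output is collected.
        return 0.5 * (x + a / x), None

    x, _ = lax.scan(body, a, xs=None, length=num_iters)
    return x


a = jnp.float32(4.0)
print(sqrt_scan(a))                     # approximately 2.0
print(jax.grad(sqrt_scan)(a))           # approximately 0.25, i.e. 1 / (2 * sqrt(a))
print(jax.jit(jax.grad(sqrt_scan))(a))  # same value under jit
```

The cost of a static length is that every call runs all num_iters steps, even after the iterate has converged.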

ZedongPeng, Dec 12 '25 15:12

Sorry, where do I set unroll=True? In the solve call?

PTNobel, Dec 12 '25 23:12

When you define the solver, e.g. solver = r2HPDHG(eps_abs=1e-4, eps_rel=1e-4, verbose=True, unroll=True). See:
https://github.com/MIT-Lu-Lab/MPAX/blob/ca1c669fea422c2509fae4bc30e1d79e6ca8977c/mpax/rapdhg.py#L306
https://github.com/MIT-Lu-Lab/MPAX/blob/ca1c669fea422c2509fae4bc30e1d79e6ca8977c/mpax/r2hpdhg.py#L59
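To illustrate what a constructor-level unroll flag can switch between, here is a hypothetical toy solver. It is not MPAX's code: ToySolver, max_iter, and the Newton update are assumptions for illustration. With unroll=False it stops early in a while_loop; with unroll=True it always runs max_iter scan steps and freezes the iterate once converged, so it stays reverse-differentiable.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp
from jax import lax


@dataclass(frozen=True)
class ToySolver:
    """Hypothetical fixed-point solver for sqrt(a); not the MPAX API."""

    eps: float = 1e-6
    max_iter: int = 50
    unroll: bool = False

    def _step(self, x, a):
        # One Newton update toward sqrt(a).
        return 0.5 * (x + a / x)

    def solve(self, a):
        if not self.unroll:
            # Early termination via while_loop: forward-mode differentiable only.
            def cond(state):
                x, x_prev, k = state
                return (jnp.abs(x - x_prev) > self.eps) & (k < self.max_iter)

            def body(state):
                x, _, k = state
                return self._step(x, a), x, k + 1

            x, _, _ = lax.while_loop(cond, body, (a, a + 1.0, jnp.int32(0)))
            return x

        # Unrolled via scan: always max_iter steps, so reverse mode works.
        # Once an update changes x by at most eps, x is frozen with jnp.where.
        def scan_body(x, _):
            x_new = self._step(x, a)
            converged = jnp.abs(x_new - x) <= self.eps
            return jnp.where(converged, x, x_new), None

        x, _ = lax.scan(scan_body, a, xs=None, length=self.max_iter)
        return x


a = jnp.float32(4.0)
early_stop = ToySolver(eps=1e-6, unroll=False)
unrolled = ToySolver(eps=1e-6, unroll=True)
print(early_stop.solve(a), unrolled.solve(a))  # both approximately 2.0
print(jax.grad(unrolled.solve)(a))             # approximately 0.25
```

In this sketch, eps only decides when the iterate is frozen; the unrolled path still executes all max_iter steps, so changing eps does not change its cost or the length of the scan that reverse mode records.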

ZedongPeng, Dec 16 '25 16:12

@ZedongPeng Sorry to bother you again, but I tried adding unroll=True (see the commit), and it then tried to allocate a terabyte or two of memory and crashed. Any idea why? Do I need to relax the tolerances?
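A hedged sketch of one general mechanism, not a diagnosis of this crash: if an unrolled solve is a fixed-length jax.lax.scan, reverse mode stores residuals for every iteration, so memory grows with the number of iterations times the per-step residual size, independent of how early the solve actually converges. A standard JAX mitigation is jax.checkpoint (rematerialization) on the scan body, which keeps only each step's inputs and recomputes intermediates on the backward pass; memory still grows linearly with the iteration count. The toy update and make_solver below are assumptions for illustration, not MPAX internals.

```python
import jax
import jax.numpy as jnp
from jax import lax


def make_solver(num_iters, remat):
    """Hypothetical unrolled solver: damped fixed-point update x <- x - 0.1 * (A @ tanh(x) - b)."""

    def step(x, A, b):
        # Nonlinear update; reverse mode needs residuals (e.g. from tanh) at every step.
        return x - 0.1 * (A @ jnp.tanh(x) - b)

    if remat:
        # Rematerialize: save only the step's inputs and recompute its
        # intermediates during the backward pass.
        step = jax.checkpoint(step)

    def solve(A, b):
        def body(x, _):
            return step(x, A, b), None

        x, _ = lax.scan(body, jnp.zeros_like(b), xs=None, length=num_iters)
        return x

    return solve


A = jnp.array([[3.0, 1.0], [1.0, 2.0]])
b = jnp.array([1.0, 1.0])


def objective(solve):
    # Scalar function of b so that jax.grad applies.
    return lambda b_: jnp.sum(solve(A, b_))


g_plain = jax.grad(objective(make_solver(200, remat=False)))(b)
g_remat = jax.grad(objective(make_solver(200, remat=True)))(b)
print(g_plain, g_remat)  # should agree up to floating-point rounding
```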

PTNobel, Dec 17 '25 01:12

I don't think it is related to tolerance. Do you have a sample script or data file handy that I could test with?

ZedongPeng, Dec 17 '25 03:12

@ZedongPeng All of the tests here failed: https://github.com/cvxpy/cvxpylayers/blob/ptn/jax.jit/tests/test_mpax.py

PTNobel, Dec 18 '25 08:12