Work on adding jax.jit support
@ZedongPeng I couldn't get the backward pass working with MPAX:
- The MPAX solver runs its iterations in a jax.lax.while_loop with a dynamic termination condition
- JAX cannot differentiate through while_loop in reverse mode (this is a JAX limitation, not a bug)
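The limitation in the second bullet can be reproduced without MPAX. The sketch below is a hypothetical stand-in, not MPAX code: a Newton iteration for sqrt(x) inside jax.lax.while_loop, stopped by a convergence test so its trip count is only known at run time. Forward mode (jax.jvp) goes through, while reverse mode (jax.grad) raises a ValueError.

```python
import jax
import jax.numpy as jnp


def newton_sqrt_while(x, tol=1e-6):
    """Hypothetical stand-in for a solver loop: Newton's method for sqrt(x),
    stopped by a convergence test, so the number of iterations is dynamic."""
    y0 = jnp.asarray(x, dtype=jnp.float32)

    def cond(state):
        y_prev, y = state
        return jnp.abs(y - y_prev) > tol  # keep iterating until the step is tiny

    def body(state):
        _, y = state
        return y, 0.5 * (y + x / y)  # one Newton step

    _, y = jax.lax.while_loop(cond, body, (y0 + 1.0, y0))
    return y


x = 2.0

# Forward mode (JVP) through while_loop is supported.
value, slope = jax.jvp(newton_sqrt_while, (x,), (1.0,))
print(value, slope)  # approximately sqrt(2) and 1 / (2 * sqrt(2))

# Reverse mode (VJP / grad) through while_loop is not.
try:
    jax.grad(newton_sqrt_while)(x)
    reverse_mode_ok = True
except ValueError as err:
    reverse_mode_ok = False
    print("jax.grad failed:", err)
```

JAX's error message itself points to lax.scan, or fori_loop with static start/stop, as the reverse-differentiable alternatives.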
How are we supposed to differentiate MPAX?
Hi @PTNobel. To obtain derivatives by differentiating through the unrolled iterations, you need to set unroll=True. When jit=True, MPAX uses jax.lax.scan under the hood.
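The idea behind unrolling can be sketched without MPAX: replace the convergence test with a fixed, static number of iterations inside jax.lax.scan, and reverse mode works. The Newton loop and NUM_STEPS below are hypothetical illustrations, not MPAX internals.

```python
import jax
import jax.numpy as jnp

NUM_STEPS = 20  # hypothetical fixed iteration budget replacing the convergence test


def newton_sqrt_scan(x):
    """Newton's method for sqrt(x), unrolled for a static number of steps."""
    def body(y, _):
        return 0.5 * (y + x / y), None  # one Newton step; no per-step output

    y0 = jnp.asarray(x, dtype=jnp.float32)
    y, _ = jax.lax.scan(body, y0, xs=None, length=NUM_STEPS)
    return y


x = 2.0
value = newton_sqrt_scan(x)
dvalue_dx = jax.grad(newton_sqrt_scan)(x)  # reverse mode now works
print(value, dvalue_dx)  # approximately sqrt(2) and 1 / (2 * sqrt(2))
```

The price of this approach is that reverse mode keeps intermediate values from each of the NUM_STEPS iterations for the backward pass, so memory grows with the iteration count.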
Sorry, where do I have to set unroll=True? In the solve call?
When you define the solver, e.g. solver = r2HPDHG(eps_abs=1e-4, eps_rel=1e-4, verbose=True, unroll=True).
https://github.com/MIT-Lu-Lab/MPAX/blob/ca1c669fea422c2509fae4bc30e1d79e6ca8977c/mpax/rapdhg.py#L306
https://github.com/MIT-Lu-Lab/MPAX/blob/ca1c669fea422c2509fae4bc30e1d79e6ca8977c/mpax/r2hpdhg.py#L59
@ZedongPeng Sorry to bother you again, but I tried adding unroll=True (see the commit), and it started trying to allocate a terabyte or two of memory and then crashed. Any idea why? Do I need to relax the tolerances?
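As a rough, hypothetical back-of-the-envelope check (the cause is not established here): reverse mode through an unrolled loop typically keeps residuals from every iteration, so if each of K iterations saves r float32 values, those residuals alone take about 4·K·r bytes. The sizes below are made up for illustration and are not measured from the cvxpylayers tests.

```python
def residual_bytes(num_iters: int, values_per_iter: int, bytes_per_value: int = 4) -> int:
    """Bytes needed to keep per-iteration residuals for the backward pass
    (float32 by default); ignores all other memory use."""
    return num_iters * values_per_iter * bytes_per_value


# Hypothetical sizes: 100,000 unrolled iterations, 2.5 million float32 values saved per iteration.
total = residual_bytes(100_000, 2_500_000)
print(total / 1e12, "TB")  # 1.0 TB
```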
I don't think it's related to the tolerances. Do you have a sample script or data file handy that I could test with?
@ZedongPeng All of the tests here failed: https://github.com/cvxpy/cvxpylayers/blob/ptn/jax.jit/tests/test_mpax.py