jaxopt
Unnecessary recompilation of _while_loop_lax
Commit ed1febef279de5e1120af70dff04261b61175383 claims that jitting `_while_loop_lax`
is redundant. However, this change also seems to prevent the compiled result from being cached, causing recompilation of loops like https://github.com/google/jaxopt/blob/58bac0ac375732dce358cf85583ae7fe3632b8cf/jaxopt/_src/base.py#L314. Reverting ed1febef279de5e1120af70dff04261b61175383 drastically reduces compilation times for my use case, so it probably makes sense to address this caching issue.
CC @froystig
This is mostly relevant when reusing the same solver (I've been using BFGS), as reinstantiation creates new references to `_cond_fun` and `_body_fun`, which are static arguments in https://github.com/google/jaxopt/blob/58bac0ac375732dce358cf85583ae7fe3632b8cf/jaxopt/_src/loop.py#L80, I think.
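For reference, here's a minimal sketch (not jaxopt code) of why fresh function references defeat the cache: `jax.jit` keys its static arguments by hash/equality, and plain Python functions compare by identity, so a newly created closure always forces a re-trace.

```python
import jax
import jax.numpy as jnp
from functools import partial

n_traces = 0  # incremented only when jit actually re-traces the body

@partial(jax.jit, static_argnums=0)
def run_loop(body_fun, x):
    global n_traces
    n_traces += 1
    return body_fun(x)

fixed_body = lambda x: x + 1.0
run_loop(fixed_body, jnp.zeros(()))
run_loop(fixed_body, jnp.zeros(()))          # same object: cache hit
assert n_traces == 1

run_loop(lambda x: x + 1.0, jnp.zeros(()))   # fresh lambda: re-trace
run_loop(lambda x: x + 1.0, jnp.zeros(()))   # another fresh lambda: re-trace
assert n_traces == 3
```

Reinstantiating a solver recreates its bound methods, which behaves exactly like the fresh lambdas above.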
Do you have a minimal example that reproduces the slowdown, and would you mind posting it here if so?
import time
import jax
import jax.numpy as jnp
import jaxopt

def rosenbrock(x):
    return jnp.sum(100. * jnp.diff(x) ** 2 + (1. - x[:-1]) ** 2)

solver = jaxopt.BFGS(rosenbrock)
x0 = jnp.zeros(2)

_time = time.time()
sol = solver.run(x0)
_time = time.time() - _time
print(f'Total {_time} s')
print(sol.params)

jax.config.update('jax_log_compiles', True)

_time = time.time()
sol2 = solver.run(x0)
_time = time.time() - _time
print(f'Total2 {_time} s')
With `jax.jit`:
Total 1.3411040306091309 s
[1. 1.]
Total2 0.007392406463623047 s
Original library (jax 0.4.23, jaxopt 0.8.2):
Total 1.3537604808807373 s
Finished tracing + transforming while for pjit in 0.0003256797790527344 sec
Compiling while for with global shapes and types [ShapedArray(int32[], weak_type=True), ShapedArray(float32[2]), ShapedArray(int32[], weak_type=True), ShapedArray(float32[]), ShapedArray(float32[2]), ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[2,2]), ShapedArray(int32[]), ShapedArray(int32[]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
Finished jaxpr to MLIR module conversion jit(while) in 0.1349470615386963 sec
Finished XLA compilation of jit(while) in 0.3022458553314209 sec
[1. 1.]
Total2 0.45375514030456543 s
More generally, what do you think about jaxopt caching all solvers, so recompilation would be reduced automatically when not using nested functions?
What do you propose, more concretely?
At the moment I essentially use jaxopt with https://github.com/google/jaxopt/commit/ed1febef279de5e1120af70dff04261b61175383 reverted and make sure that I cache my solvers:
from functools import lru_cache

@lru_cache
def get_solver(solver, *args, **kwargs):
    return solver(*args, **kwargs)
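For illustration, with a stand-in `Solver` class (hypothetical, standing in for e.g. `jaxopt.BFGS`), the cached factory returns the same instance for identical arguments, which keeps the method references stable:

```python
from functools import lru_cache

class Solver:
    """Stand-in for a jaxopt solver class (hypothetical)."""
    def __init__(self, fun, maxiter=100):
        self.fun = fun
        self.maxiter = maxiter

@lru_cache
def get_solver(solver, *args, **kwargs):
    # All arguments must be hashable; classes and top-level
    # functions are, which is why nested closures break this.
    return solver(*args, **kwargs)

def loss(x):
    return x * x

a = get_solver(Solver, loss, maxiter=50)
b = get_solver(Solver, loss, maxiter=50)
assert a is b  # same instance, so its bound methods stay cache-stable
```

Note this only works when the objective is a top-level function; a closure defined inside another function would be a new object on every call and defeat the cache.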
My suggestion would be to
- implement caching of `_while_loop_lax` cleanly, without relying on `jax.jit` (I haven't dug deep enough into jax yet to know the best way to do this), and
- make it easy for the user to reuse solvers, or at least document that reusing solvers will reduce / avoid recompilation, so users benefit from this change.

Another topic would be advising against nested functions, which I've seen in a lot of non-official examples.
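To make the first point concrete, one possible shape of such a fix (purely a sketch under my assumptions, not a proposed patch) would be to memoize the jitted loop per `(cond_fun, body_fun)` pair:

```python
import jax
import jax.numpy as jnp
from functools import lru_cache

@lru_cache(maxsize=None)
def _cached_while_loop(cond_fun, body_fun):
    # Compiled once per (cond_fun, body_fun) pair, then reused.
    return jax.jit(lambda init: jax.lax.while_loop(cond_fun, body_fun, init))

def while_loop(cond_fun, body_fun, init_val):
    return _cached_while_loop(cond_fun, body_fun)(init_val)

cond = lambda x: x < 5.0
body = lambda x: x + 1.0
out = while_loop(cond, body, jnp.zeros(()))
assert float(out) == 5.0
```

Of course this only helps if callers pass the same function objects each time, which loops back to the second point about reusing solvers.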
Would it be possible to detect that you're inside a jitted context, rather than accepting the Boolean `jit` parameter? That way, only the user would ever call `jit`, and would totally control caching and compilation.
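For what it's worth, a rough sketch of how that detection could look (my assumption about the mechanism, not an existing jaxopt API): check whether an input is a `jax.core.Tracer`, which it is when the call happens under an enclosing `jit`.

```python
import jax
import jax.numpy as jnp

def inside_jit(x):
    # True while x is being traced, i.e. under an enclosing jax.jit.
    return isinstance(x, jax.core.Tracer)

def maybe_jitted_step(x):
    if inside_jit(x):
        return x + 1.0                    # already traced; don't re-jit
    return jax.jit(lambda y: y + 1.0)(x)  # top-level call: jit ourselves

assert not inside_jit(jnp.ones(()))
assert float(maybe_jitted_step(jnp.ones(()))) == 2.0
assert float(jax.jit(maybe_jitted_step)(jnp.ones(()))) == 2.0
```

The downside is that the decision then depends on inspecting arguments, so the library would have to pick one representative input to test.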