jaxopt

Unnecessary recompilation of _while_loop_lax

Status: Open. hrdl-github opened this issue 1 year ago (8 comments).

Commit ed1febef279de5e1120af70dff04261b61175383 claims that jitting `_while_loop_lax` is redundant. However, the change also seems to prevent the compiled loop from being cached, so loops such as https://github.com/google/jaxopt/blob/58bac0ac375732dce358cf85583ae7fe3632b8cf/jaxopt/_src/base.py#L314 are recompiled when the solver is run again. Reverting ed1febef279de5e1120af70dff04261b61175383 drastically reduces compilation times for my use case, so it probably makes sense to address this caching issue.

hrdl-github, Jan 03 '24 22:01

CC @froystig

mblondel, Jan 03 '24 23:01

This mostly matters when reusing the same solver instance (I've been using BFGS), since reinstantiating a solver creates new references to `_cond_fun` and `_body_fun`, which, I think, are static arguments in https://github.com/google/jaxopt/blob/58bac0ac375732dce358cf85583ae7fe3632b8cf/jaxopt/_src/loop.py#L80.
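To illustrate why the identity of the loop functions matters, here is a minimal sketch. It assumes the jitted loop marks `cond_fun` and `body_fun` as static arguments, which `jax.jit` keys on by hash and equality (object identity for plain functions). `while_loop_jit` and `make_funs` are hypothetical names, not jaxopt code; the Python counter only increments when JAX traces the function.

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_count = 0

@partial(jax.jit, static_argnums=(0, 1))
def while_loop_jit(cond_fun, body_fun, init_val):
    global trace_count
    trace_count += 1  # Python side effect: runs only when JAX (re)traces
    return jax.lax.while_loop(cond_fun, body_fun, init_val)

def make_funs():
    # Returns fresh function objects on every call, like reinstantiating a solver.
    def cond_fun(x):
        return x < 10.0
    def body_fun(x):
        return x + 1.0
    return cond_fun, body_fun

cond_a, body_a = make_funs()
first = while_loop_jit(cond_a, body_a, jnp.float32(0.0))   # traces and compiles
second = while_loop_jit(cond_a, body_a, jnp.float32(0.0))  # same objects: cache hit

cond_b, body_b = make_funs()
third = while_loop_jit(cond_b, body_b, jnp.float32(0.0))   # new objects: retrace
print(trace_count)  # 2
```

Reusing one solver instance keeps the same function objects, so only the first call pays for tracing and compilation.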

hrdl-github, Jan 04 '24 00:01

Do you have a minimal example that reproduces the slowdown, and would you mind posting it here if so?

froystig, Jan 04 '24 02:01

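```python
import time

import jax
import jax.numpy as jnp
import jaxopt

# Rosenbrock-like objective; its minimum is at [1., 1.].
def rosenbrock(x):
    return jnp.sum(100. * jnp.diff(x) ** 2 + (1. - x[:-1]) ** 2)

solver = jaxopt.BFGS(rosenbrock)
x0 = jnp.zeros(2)

# First run: includes tracing and compilation.
_time = time.time()
sol = solver.run(x0)
_time = time.time() - _time
print(f'Total {_time} s')

# Log every compilation triggered from here on.
jax.config.update('jax_log_compiles', True)

# Second run with the same solver instance: should not need to recompile.
_time = time.time()
sol2 = solver.run(x0)
print(sol2.params)  # printing also waits for the result to be computed
_time = time.time() - _time
print(f'Total2 {_time} s')
```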

With `jax.jit` applied to `_while_loop_lax` (i.e. ed1febef279de5e1120af70dff04261b61175383 reverted):

```
Total 1.3411040306091309 s
[1. 1.]
Total2 0.007392406463623047 s
```

Original library (jax 0.4.23, jaxopt 0.8.2):

```
Total 1.3537604808807373 s
Finished tracing + transforming while for pjit in 0.0003256797790527344 sec
Compiling while for with global shapes and types [ShapedArray(int32[], weak_type=True), ShapedArray(float32[2]), ShapedArray(int32[], weak_type=True), ShapedArray(float32[]), ShapedArray(float32[2]), ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[2,2]), ShapedArray(int32[]), ShapedArray(int32[]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
Finished jaxpr to MLIR module conversion jit(while) in 0.1349470615386963 sec
Finished XLA compilation of jit(while) in 0.3022458553314209 sec
[1. 1.]
Total2 0.45375514030456543 s
```

hrdl-github, Jan 04 '24 17:01

More generally, what do you think about jaxopt caching all solvers, so recompilation would be reduced automatically when not using nested functions?

hrdl-github, Jan 09 '24 14:01

What do you propose, more concretely?

mblondel, Jan 09 '24 15:01

At the moment I essentially use jaxopt with https://github.com/google/jaxopt/commit/ed1febef279de5e1120af70dff04261b61175383 reverted and make sure that I cache my solvers:

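```python
from functools import lru_cache

# Return the same solver instance for identical (hashable) arguments,
# so repeated runs can reuse already-compiled loops.
@lru_cache
def get_solver(solver, *args, **kwargs):
    return solver(*args, **kwargs)
```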
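A short usage sketch of this caching pattern follows. `ToySolver` is a hypothetical stand-in for a class such as `jaxopt.BFGS`, used so the example runs without jaxopt. Identical arguments return the very same instance; any change in the arguments builds a new one; and since every argument becomes part of the cache key, unhashable values such as lists are rejected.

```python
from functools import lru_cache

class ToySolver:
    """Hypothetical stand-in for a solver class such as jaxopt.BFGS."""
    def __init__(self, fun, maxiter=100):
        self.fun = fun
        self.maxiter = maxiter

@lru_cache(maxsize=None)
def get_solver(solver, *args, **kwargs):
    # All arguments form the cache key, so they must be hashable.
    return solver(*args, **kwargs)

def objective(x):
    return (x - 1.0) ** 2

a = get_solver(ToySolver, objective, maxiter=50)
b = get_solver(ToySolver, objective, maxiter=50)   # cache hit: same instance
c = get_solver(ToySolver, objective, maxiter=200)  # different key: new instance
print(a is b, a is c)  # True False

try:
    get_solver(ToySolver, objective, maxiter=[50])
except TypeError as err:
    print("unhashable argument:", err)
```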

My suggestion would be to

  1. Implement caching of `_while_loop_lax` cleanly, without relying on `jax.jit`. I haven't dug deep enough into JAX yet to know the best way to do this.
  2. Make it easy for users to reuse solvers, or at least document that reusing a solver reduces or avoids recompilation, so that users benefit from this change. A related point would be to advise against nested functions, which I've seen in many unofficial examples.
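The concern about nested functions can be sketched as follows, reusing the solver-caching idea. `ToySolver`, `fit_nested` and `fit_top` are hypothetical names. A function defined inside another function is a new object on every call, so any cache keyed on it misses. A module-level objective that receives its data as an argument keeps the same identity; jaxopt solvers forward extra arguments of `run` to the objective, which is what makes this style practical.

```python
from functools import lru_cache

class ToySolver:
    """Hypothetical stand-in for a solver class such as jaxopt.BFGS."""
    def __init__(self, fun):
        self.fun = fun

@lru_cache(maxsize=None)
def get_solver(solver, *args, **kwargs):
    return solver(*args, **kwargs)

def fit_nested(target):
    # Nested function: a new object on every call, so the cache key changes.
    def objective(x):
        return (x - target) ** 2
    return get_solver(ToySolver, objective)

def objective_top(x, target):
    # Module-level function: the same object every time; data is passed in.
    return (x - target) ** 2

def fit_top(target):
    # `target` would be passed to the solver's run(...) instead of captured.
    return get_solver(ToySolver, objective_top)

print(fit_nested(1.0) is fit_nested(1.0))  # False: cache miss each time
print(fit_top(1.0) is fit_top(2.0))        # True: one cached solver
```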

hrdl-github, Jan 09 '24 15:01

Would it be possible to detect that you're inside a jitted context rather than accepting the Boolean jit parameter? That way, only the user would ever call jit, and would totally control caching and compilation.
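One possible way to detect a traced context is to check whether an input is a JAX tracer. The sketch below makes that assumption. `is_traced` is a hypothetical helper, not a jaxopt or JAX API. It also cannot tell `jit` apart from `grad` or `vmap`, since all of them pass tracers.

```python
import jax
import jax.numpy as jnp

def is_traced(x):
    """Hypothetical helper: True if x is an abstract tracer, not a concrete array."""
    # Tracers appear under jit, but also under grad, vmap, and other transformations.
    return isinstance(x, jax.core.Tracer)

seen = []

def step(x):
    seen.append(is_traced(x))  # Python side effect: runs eagerly or at trace time
    return x * 2.0

step(jnp.ones(3))           # eager call: x is a concrete array
jax.jit(step)(jnp.ones(3))  # jitted call: x is a tracer while tracing
print(seen)  # [False, True]
```

A solver could use such a check to skip its internal `jax.jit` when the caller is already tracing, which leaves caching and compilation to the caller's `jit`.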

NeilGirdhar, Jan 11 '24 11:01