
Reducing compilation times

Open • vroulet opened this issue 1 year ago • 1 comment

The plan of this PR is to reduce the number of times the objective function is compiled. I started by adding a test in common_tests that counts how many times the objective is compiled for each solver; the results are below. I will first focus on a simple solver such as gradient descent to see whether the number of compilations can be reduced, and then extend the approach to each solver until the test passes with a single compilation of the objective.

AndersonWrapper                       : objective compiled 3 times
BFGS with zoom linesearch             : objective compiled 8 times
ArmijoSGD                             : objective compiled 5 times
GaussNewton                           : objective compiled 26 times
GradientDescent                       : objective compiled 3 times
LBFGS with zoom linesearch            : objective compiled 8 times
LBFGS with hager-zhang linesearch     : objective compiled 75 times
LBFGS with backtracking linesearch    : objective compiled 5 times
LevenbergMarquardt                    : objective compiled 30 times
NonlinearCG with zoom linesearch      : objective compiled 8 times
PolyakSGD                             : objective compiled 3 times
OptaxSolver                           : objective compiled 3 times
AndersonAcceleration                  : objective compiled 3 times
Broyden with backtracking linesearch  : objective compiled 11 times
Bisection                             : objective compiled 4 times
BlockCoordinateDescent                : objective compiled 5 times
LBFGSB with zoom linesearch           : objective compiled 8 times
ProjectedGradient                     : objective compiled 3 times
ProximalGradient                      : objective compiled 3 times
MirrorDescent                         : objective compiled 2 times
BacktrackingLineSearch                : objective compiled 5 times
HagerZhangLineSearch                  : objective compiled 40 times
ZoomLineSearch                        : objective compiled 6 times
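A counting test of this kind could look roughly like the sketch below. It is not the actual test in common_tests: `count_traces` and `toy_gradient_descent` are hypothetical names, and the toy loop stands in for a real jaxopt solver. It relies on the fact that, under `jax.jit`, the Python body of the objective only runs while JAX traces it, so counting calls that receive a Tracer counts traces (compilations) rather than executions.

```python
import jax
import jax.numpy as jnp


def count_traces(fun):
    """Hypothetical helper: wrap `fun` and count calls that receive a Tracer.

    Under jax.jit the Python body of `fun` only runs while JAX traces it,
    so this counts traces (i.e. compilations), not executions.
    """
    counter = {"traces": 0}

    def wrapped(x):
        # Tracer classes have names such as DynamicJaxprTracer or JVPTracer.
        if type(x).__name__.endswith("Tracer"):
            counter["traces"] += 1
        return fun(x)

    return wrapped, counter


def toy_gradient_descent(fun, x0, stepsize=0.1, maxiter=50):
    """Hypothetical stand-in for a solver: fixed-step gradient descent."""
    grad_fun = jax.grad(fun)

    def body(_, x):
        return x - stepsize * grad_fun(x)

    return jax.lax.fori_loop(0, maxiter, body, x0)


objective, counter = count_traces(lambda x: jnp.sum((x - 1.0) ** 2))
solve = jax.jit(lambda x0: toy_gradient_descent(objective, x0))

result = solve(jnp.zeros(3))
first_run = counter["traces"]
solve(jnp.zeros(3))  # same shape and dtype: the compiled code is reused
print(first_run, counter["traces"])
```

With this harness, the goal stated above would amount to asserting that the count is exactly 1 for each solver.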

vroulet • Aug 31 '23 21:08

I'm now printing the type of the objective's input each time the objective's Python code runs:

  • If the function is jitted:
    • At compilation time, this type must be a <Something>Tracer.
    • Afterwards (for inputs with the same shapes and dtypes), nothing is printed, since the cached compiled code is reused.
  • Otherwise, if the function is not jitted:
    • No compilation happens, so the printed type is not a <Something>Tracer.
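A minimal sketch of this check, assuming a toy quadratic objective (not one from the test suite): the input type is recorded each time the Python body runs, once eagerly and then through `jax.jit`. The class names (for example `ArrayImpl` or `DynamicJaxprTracer`) depend on the JAX version, so the sketch only checks whether the name ends in `Tracer`.

```python
import jax
import jax.numpy as jnp

seen_types = []


def objective(x):
    # Record and print the input type each time this Python body runs.
    seen_types.append(type(x).__name__)
    print(type(x))
    return jnp.sum(x ** 2)


x = jnp.arange(3.0)
objective(x)               # not jitted: runs eagerly on a concrete array
jitted = jax.jit(objective)
jitted(x)                  # first call: traced, so x is a ...Tracer
jitted(x)                  # cached: the Python body does not run, nothing printed
```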

The results are below. The current implementation is not as bad as previously claimed: OptaxSolver, for example, compiles the objective only once (during the update). The call to the objective in the init function (at least in OptaxSolver and PolyakSGD, and probably in other solvers too) does not seem to incur an additional compilation.
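One plausible contributor to counts above one, offered here as an assumption rather than something the measurements establish, is that every separately jitted transformation of the objective (for example, a jitted value and a jitted gradient) traces it on its own, and a new input shape or dtype triggers a retrace. The sketch below uses a toy objective to illustrate this, together with the observation above that an eager call, such as one from an un-jitted init, does not add a trace.

```python
import jax
import jax.numpy as jnp

trace_log = []


def objective(x):
    # Log the input type only when the body runs on a Tracer (i.e. is traced).
    if type(x).__name__.endswith("Tracer"):
        trace_log.append(type(x).__name__)
    return jnp.sum((x - 1.0) ** 2)


value_fn = jax.jit(objective)
grad_fn = jax.jit(jax.grad(objective))

x = jnp.zeros(3)
objective(x)            # eager call (e.g. an un-jitted init): not a trace
value_fn(x)             # 1st trace: jitted value
grad_fn(x)              # 2nd trace: the jitted gradient is a separate function
value_fn(x)             # cache hit: no new trace
grad_fn(x)              # cache hit: no new trace
value_fn(jnp.zeros(4))  # new input shape: 3rd trace
print(len(trace_log))   # 3
```

Under this reading, reducing the count would mean routing value and gradient evaluations through a single traced function (for instance `jax.value_and_grad`) and keeping input shapes and dtypes stable across calls.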

AndersonWrapper                       : objective compiled 2 times
BFGS with zoom linesearch             : objective compiled 7 times
ArmijoSGD                             : objective compiled 3 times
GaussNewton                           : objective compiled 25 times
GradientDescent                       : objective compiled 2 times
LBFGS with zoom linesearch            : objective compiled 7 times
LBFGS with hager-zhang linesearch     : objective compiled 38 times
LBFGS with backtracking linesearch    : objective compiled 3 times
LevenbergMarquardt                    : objective compiled 29 times
NonlinearCG with zoom linesearch      : objective compiled 7 times
PolyakSGD                             : objective compiled 1 time
OptaxSolver                           : objective compiled 1 time
AndersonAcceleration                  : objective compiled 2 times
Broyden with backtracking linesearch  : objective compiled 6 times
Bisection                             : objective compiled 2 times
BlockCoordinateDescent                : objective compiled 5 times
LBFGSB with zoom linesearch           : objective compiled 7 times
ProjectedGradient                     : objective compiled 2 times
ProximalGradient                      : objective compiled 2 times
MirrorDescent                         : objective compiled 2 times
BacktrackingLineSearch                : objective compiled 3 times
HagerZhangLineSearch                  : objective compiled 40 times
ZoomLineSearch                        : objective compiled 6 times

vroulet • Sep 04 '23 17:09