Remove jit method of objective, directly compile methods

Open f0uriest opened this issue 1 year ago • 4 comments

Right now we close over self when compiling the methods of ObjectiveFunction. This means that JAX may bake all of the objective's attributes (i.e., transforms, profiles, fields, equilibrium, etc.) into the compiled function, which likely slows down compilation and may increase memory usage.

This PR changes things so that we instead JIT the methods directly, treating self as just another argument. Doing this requires refactoring how the derivatives are handled: they are now created locally inside their respective functions rather than separately. There shouldn't be any performance hit, since creating the Derivative objects is basically free.
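As a rough illustration of the difference between the two patterns, here is a minimal sketch. `ToyObjective` and its `weights` attribute are hypothetical stand-ins, not DESC's actual `ObjectiveFunction`, and the sketch assumes the class can be registered as a JAX pytree with its array attributes as children.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class ToyObjective:
    """Hypothetical stand-in for an objective; not DESC's ObjectiveFunction."""

    def __init__(self, weights):
        self.weights = weights  # array attribute, e.g. a transform matrix

    # Pytree protocol: array attributes are children, so JAX can trace them.
    def tree_flatten(self):
        return (self.weights,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    @jax.jit
    def compute(self, x):
        # New pattern: self is an ordinary pytree argument, so self.weights
        # is passed to the compiled function as an input.
        return jnp.sum((self.weights @ x) ** 2)


obj = ToyObjective(jnp.eye(3))
x = jnp.arange(3.0)

# Old pattern: the jitted function closes over obj, so obj.weights is
# captured at trace time instead of being passed in as an argument.
compute_closed = jax.jit(lambda x: jnp.sum((obj.weights @ x) ** 2))

value_closed = compute_closed(x)
value_direct = obj.compute(x)
```

Both calls return the same value; the difference is only in whether the attributes are captured by the closure or passed in as traced inputs.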

Resolves #957 Resolves #1191

f0uriest avatar Jun 04 '24 20:06 f0uriest

|             benchmark_name             |         dt(%)          |         dt(s)          |        t_new(s)        |        t_old(s)        | 
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
 test_build_transform_fft_lowres         |     +0.61 +/- 10.68    | +3.31e-03 +/- 5.76e-02 |  5.43e-01 +/- 4.0e-02  |  5.39e-01 +/- 4.1e-02  |
 test_build_transform_fft_midres         |     +2.68 +/- 9.01     | +1.64e-02 +/- 5.52e-02 |  6.28e-01 +/- 3.8e-02  |  6.12e-01 +/- 4.0e-02  |
 test_build_transform_fft_highres        |     +3.35 +/- 4.50     | +3.34e-02 +/- 4.49e-02 |  1.03e+00 +/- 3.0e-02  |  9.97e-01 +/- 3.3e-02  |
 test_equilibrium_init_lowres            |     +9.11 +/- 8.22     | +3.51e-01 +/- 3.17e-01 |  4.21e+00 +/- 2.9e-01  |  3.86e+00 +/- 1.3e-01  |
 test_equilibrium_init_medres            |     +2.58 +/- 6.03     | +1.19e-01 +/- 2.79e-01 |  4.75e+00 +/- 2.1e-01  |  4.63e+00 +/- 1.8e-01  |
 test_equilibrium_init_highres           |     +3.75 +/- 4.12     | +2.13e-01 +/- 2.33e-01 |  5.88e+00 +/- 1.7e-01  |  5.67e+00 +/- 1.6e-01  |
 test_objective_compile_dshape_current   |     +4.94 +/- 2.51     | +1.87e-01 +/- 9.50e-02 |  3.98e+00 +/- 6.3e-02  |  3.79e+00 +/- 7.1e-02  |
 test_objective_compile_atf              |     -4.31 +/- 2.42     | -3.62e-01 +/- 2.03e-01 |  8.04e+00 +/- 1.4e-01  |  8.40e+00 +/- 1.5e-01  |
-test_objective_compute_dshape_current   |    +172.30 +/- 4.71    | +2.19e-03 +/- 5.99e-05 |  3.46e-03 +/- 3.7e-05  |  1.27e-03 +/- 4.7e-05  |
-test_objective_compute_atf              |    +134.12 +/- 4.95    | +5.98e-03 +/- 2.21e-04 |  1.04e-02 +/- 1.7e-04  |  4.46e-03 +/- 1.4e-04  |
 test_objective_jac_dshape_current       |     +4.88 +/- 7.31     | +1.90e-03 +/- 2.84e-03 |  4.08e-02 +/- 2.4e-03  |  3.89e-02 +/- 1.5e-03  |
 test_objective_jac_atf                  |     -1.03 +/- 3.34     | -1.98e-02 +/- 6.42e-02 |  1.90e+00 +/- 5.5e-02  |  1.92e+00 +/- 3.3e-02  |
 test_perturb_1                          |     -7.18 +/- 7.10     | -9.60e-01 +/- 9.49e-01 |  1.24e+01 +/- 1.6e-01  |  1.34e+01 +/- 9.4e-01  |
 test_perturb_2                          |     -6.97 +/- 2.94     | -1.30e+00 +/- 5.47e-01 |  1.73e+01 +/- 3.6e-01  |  1.86e+01 +/- 4.1e-01  |
 test_proximal_jac_atf                   |     +0.07 +/- 1.40     | +5.68e-03 +/- 1.14e-01 |  8.11e+00 +/- 8.8e-02  |  8.11e+00 +/- 7.2e-02  |
-test_proximal_freeb_compute             |     +5.08 +/- 1.63     | +8.93e-03 +/- 2.86e-03 |  1.85e-01 +/- 2.1e-03  |  1.76e-01 +/- 2.0e-03  |
 test_proximal_freeb_jac                 |     +2.16 +/- 1.33     | +1.57e-01 +/- 9.65e-02 |  7.44e+00 +/- 8.9e-02  |  7.28e+00 +/- 3.8e-02  |
+test_solve_fixed_iter                   |    -71.70 +/- 17.72    | -1.24e+01 +/- 3.07e+00 |  4.91e+00 +/- 2.1e+00  |  1.73e+01 +/- 2.2e+00  |

github-actions[bot] avatar Jun 04 '24 21:06 github-actions[bot]

@YigitElma @dpanici @kianorr @rahulgaur104 we should profile this change memory-wise

dpanici avatar Jun 25 '24 20:06 dpanici

Codecov Report

Attention: Patch coverage is 93.53234% with 13 lines in your changes missing coverage. Please review.

Project coverage is 95.42%. Comparing base (1c076fc) to head (a719f8a). Report is 1682 commits behind head on master.

Files with missing lines Patch % Lines
desc/objectives/objective_funs.py 91.66% 11 Missing :warning:
desc/objectives/linear_objectives.py 71.42% 2 Missing :warning:
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1043      +/-   ##
==========================================
- Coverage   95.42%   95.42%   -0.01%     
==========================================
  Files          87       87              
  Lines       22341    22423      +82     
==========================================
+ Hits        21320    21398      +78     
- Misses       1021     1025       +4     
Files with missing lines Coverage Δ
desc/io/optimizable_io.py 86.30% <100.00%> (+0.08%) :arrow_up:
desc/objectives/utils.py 100.00% <100.00%> (ø)
desc/optimize/_constraint_wrappers.py 95.82% <100.00%> (+0.08%) :arrow_up:
desc/objectives/linear_objectives.py 97.04% <71.42%> (+0.01%) :arrow_up:
desc/objectives/objective_funs.py 93.97% <91.66%> (+0.02%) :arrow_up:

... and 5 files with indirect coverage changes


codecov[bot] avatar Aug 15 '24 04:08 codecov[bot]

|             benchmark_name             |         dt(%)          |         dt(s)          |        t_new(s)        |        t_old(s)        | 
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
 test_build_transform_fft_lowres         |     -1.56 +/- 9.77     | -8.12e-03 +/- 5.09e-02 |  5.13e-01 +/- 2.9e-02  |  5.21e-01 +/- 4.2e-02  |
 test_build_transform_fft_midres         |     -0.57 +/- 2.07     | -3.38e-03 +/- 1.22e-02 |  5.89e-01 +/- 9.9e-03  |  5.92e-01 +/- 7.2e-03  |
 test_build_transform_fft_highres        |     +3.10 +/- 3.37     | +3.05e-02 +/- 3.32e-02 |  1.01e+00 +/- 2.1e-02  |  9.84e-01 +/- 2.6e-02  |
 test_equilibrium_init_lowres            |     +3.36 +/- 8.44     | +1.32e-01 +/- 3.32e-01 |  4.07e+00 +/- 1.7e-01  |  3.94e+00 +/- 2.9e-01  |
 test_equilibrium_init_medres            |     -1.05 +/- 4.88     | -4.42e-02 +/- 2.04e-01 |  4.15e+00 +/- 1.0e-01  |  4.19e+00 +/- 1.8e-01  |
 test_equilibrium_init_highres           |     +1.22 +/- 3.98     | +6.80e-02 +/- 2.22e-01 |  5.65e+00 +/- 1.4e-01  |  5.59e+00 +/- 1.7e-01  |
 test_objective_compile_dshape_current   |     +2.97 +/- 4.68     | +1.17e-01 +/- 1.85e-01 |  4.06e+00 +/- 1.7e-01  |  3.95e+00 +/- 6.1e-02  |
 test_objective_compile_atf              |     -5.39 +/- 3.01     | -4.55e-01 +/- 2.54e-01 |  7.98e+00 +/- 7.0e-02  |  8.44e+00 +/- 2.4e-01  |
-test_objective_compute_dshape_current   |    +177.24 +/- 4.92    | +2.22e-03 +/- 6.17e-05 |  3.48e-03 +/- 5.1e-05  |  1.25e-03 +/- 3.4e-05  |
-test_objective_compute_atf              |    +139.12 +/- 4.43    | +5.95e-03 +/- 1.89e-04 |  1.02e-02 +/- 1.1e-04  |  4.27e-03 +/- 1.5e-04  |
 test_objective_jac_dshape_current       |     -0.65 +/- 8.29     | -2.51e-04 +/- 3.19e-03 |  3.82e-02 +/- 1.8e-03  |  3.85e-02 +/- 2.6e-03  |
 test_objective_jac_atf                  |     +1.18 +/- 4.31     | +2.19e-02 +/- 8.03e-02 |  1.88e+00 +/- 3.7e-02  |  1.86e+00 +/- 7.1e-02  |
 test_perturb_1                          |     -9.89 +/- 5.29     | -1.38e+00 +/- 7.40e-01 |  1.26e+01 +/- 3.2e-01  |  1.40e+01 +/- 6.7e-01  |
 test_perturb_2                          |     -3.33 +/- 2.48     | -6.19e-01 +/- 4.61e-01 |  1.80e+01 +/- 4.5e-01  |  1.86e+01 +/- 1.1e-01  |
 test_proximal_jac_atf                   |     +1.05 +/- 1.09     | +7.65e-02 +/- 7.96e-02 |  7.35e+00 +/- 6.6e-02  |  7.28e+00 +/- 4.4e-02  |
-test_proximal_freeb_compute             |     +3.08 +/- 0.75     | +5.44e-03 +/- 1.33e-03 |  1.82e-01 +/- 9.9e-04  |  1.76e-01 +/- 8.9e-04  |
 test_proximal_freeb_jac                 |     -0.08 +/- 0.95     | -5.65e-03 +/- 6.95e-02 |  7.35e+00 +/- 5.6e-02  |  7.35e+00 +/- 4.2e-02  |
+test_solve_fixed_iter                   |    -60.47 +/- 16.46    | -1.08e+01 +/- 2.94e+00 |  7.06e+00 +/- 2.9e+00  |  1.78e+01 +/- 5.7e-01  |

I think the slowdown in the compute benchmarks is because it has to flatten/unflatten the ObjectiveFunction pytree at each call. It looks like a big slowdown, but in absolute time it's only a few ms, so it's likely worth it to get a big speedup on actual solve/optimization. That said, I think the 60% faster solve is higher than what I'm seeing in practice; it's usually more like 10-20% faster, so it might be unique to the benchmark case.
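To make the per-call overhead concrete, here is a minimal sketch that counts how often JAX invokes the flatten hook of a pytree argument. `CountingObjective` and its class-level counter are hypothetical and exist only for this illustration; the sketch does not measure DESC itself, and it only tracks flattening, not unflattening.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class CountingObjective:
    """Hypothetical toy class that counts how often JAX flattens it."""

    flatten_calls = 0  # class-level counter, for illustration only

    def __init__(self, weights):
        self.weights = weights

    def tree_flatten(self):
        CountingObjective.flatten_calls += 1
        return (self.weights,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    @jax.jit
    def compute(self, x):
        return jnp.sum((self.weights @ x) ** 2)


obj = CountingObjective(jnp.eye(3))
x = jnp.arange(3.0)

obj.compute(x).block_until_ready()  # first call traces and compiles
calls_after_compile = CountingObjective.flatten_calls

obj.compute(x).block_until_ready()  # second call reuses the compiled function
calls_after_cached = CountingObjective.flatten_calls
```

The counter keeps growing on the second call even though nothing is recompiled, which is consistent with a fixed Python-side cost per call that dominates when the computation itself only takes a millisecond or so.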

f0uriest avatar Aug 15 '24 16:08 f0uriest