DESC
Remove jit method of objective, directly compile methods
Basically, right now we close over `self` when compiling the methods of `ObjectiveFunction`. This means that JAX may bake all the attributes of the objective (e.g., transforms, profiles, fields, equilibrium, etc.) into the compiled function as constants, which likely both slows down compilation and may increase memory usage.
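A minimal JAX sketch of the problem, using a hypothetical `ClosureObjective` class (not DESC code). The jitted lambda reads `self.weights` while tracing, so the array is captured as a constant of the compiled function. One visible side effect of that baking is that later changes to the attribute are silently ignored by the cached compiled function.

```python
import jax
import jax.numpy as jnp


class ClosureObjective:
    """Hypothetical toy objective standing in for one with large array attributes."""

    def __init__(self, weights):
        self.weights = weights  # e.g. transform matrices, profiles, ...
        # jit a closure over self: self.weights is captured at trace time as a constant
        self._compute_jit = jax.jit(lambda x: self.compute(x))

    def compute(self, x):
        return jnp.dot(self.weights, x)


obj = ClosureObjective(jnp.ones((3, 3)))
x = jnp.arange(3.0)
print(obj._compute_jit(x))  # traced and compiled with the original weights baked in
obj.weights = 2 * jnp.ones((3, 3))
print(obj._compute_jit(x))  # cache hit: still uses the old, baked-in weights
print(obj.compute(x))       # the un-jitted method sees the new weights
```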
This changes things so that we instead JIT the methods directly, treating `self` as just another argument. Doing this requires refactoring how the derivatives are handled (they are now local to their respective functions rather than being created separately). There shouldn't be any performance hit, since creating the `Derivative` objects is basically free.
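A minimal sketch of the new pattern, assuming a hypothetical `ToyObjective` class (not DESC's actual `ObjectiveFunction`). The class is registered as a pytree so that `@jax.jit` can decorate its methods directly, following the "make the class a pytree" approach from the JAX FAQ. The Jacobian is built locally inside the jitted method with `jax.jacfwd` rather than being stored on the object; this stands in for DESC's `Derivative` objects.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class ToyObjective:
    """Hypothetical stand-in for an objective; not DESC's ObjectiveFunction."""

    def __init__(self, weights):
        self.weights = weights  # array attribute (e.g. a transform matrix)

    def tree_flatten(self):
        # Array attributes are pytree leaves, so jit treats them as traced inputs
        return (self.weights,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    @jax.jit
    def compute(self, x):
        # self is an ordinary pytree argument here, not a closed-over constant
        return jnp.tanh(self.weights @ x)

    @jax.jit
    def jac(self, x):
        # The derivative is created locally inside the jitted method (cheap),
        # instead of being built separately and stored as an attribute
        return jax.jacfwd(lambda y: self.compute(y))(x)


obj = ToyObjective(jnp.eye(2))
x = jnp.zeros(2)
print(obj.jac(x))  # identity matrix, since d/dx tanh(W x) = W at x = 0
obj.weights = 2 * jnp.eye(2)
print(obj.jac(x))  # 2 * identity: new weights are used, same compiled code
```

Because `self` is flattened into leaves on every call, updated attributes are picked up without recompiling as long as their shapes and dtypes stay the same.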
Resolves #957 Resolves #1191
| benchmark_name | dt(%) | dt(s) | t_new(s) | t_old(s) |
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
test_build_transform_fft_lowres | +0.61 +/- 10.68 | +3.31e-03 +/- 5.76e-02 | 5.43e-01 +/- 4.0e-02 | 5.39e-01 +/- 4.1e-02 |
test_build_transform_fft_midres | +2.68 +/- 9.01 | +1.64e-02 +/- 5.52e-02 | 6.28e-01 +/- 3.8e-02 | 6.12e-01 +/- 4.0e-02 |
test_build_transform_fft_highres | +3.35 +/- 4.50 | +3.34e-02 +/- 4.49e-02 | 1.03e+00 +/- 3.0e-02 | 9.97e-01 +/- 3.3e-02 |
test_equilibrium_init_lowres | +9.11 +/- 8.22 | +3.51e-01 +/- 3.17e-01 | 4.21e+00 +/- 2.9e-01 | 3.86e+00 +/- 1.3e-01 |
test_equilibrium_init_medres | +2.58 +/- 6.03 | +1.19e-01 +/- 2.79e-01 | 4.75e+00 +/- 2.1e-01 | 4.63e+00 +/- 1.8e-01 |
test_equilibrium_init_highres | +3.75 +/- 4.12 | +2.13e-01 +/- 2.33e-01 | 5.88e+00 +/- 1.7e-01 | 5.67e+00 +/- 1.6e-01 |
test_objective_compile_dshape_current | +4.94 +/- 2.51 | +1.87e-01 +/- 9.50e-02 | 3.98e+00 +/- 6.3e-02 | 3.79e+00 +/- 7.1e-02 |
test_objective_compile_atf | -4.31 +/- 2.42 | -3.62e-01 +/- 2.03e-01 | 8.04e+00 +/- 1.4e-01 | 8.40e+00 +/- 1.5e-01 |
-test_objective_compute_dshape_current | +172.30 +/- 4.71 | +2.19e-03 +/- 5.99e-05 | 3.46e-03 +/- 3.7e-05 | 1.27e-03 +/- 4.7e-05 |
-test_objective_compute_atf | +134.12 +/- 4.95 | +5.98e-03 +/- 2.21e-04 | 1.04e-02 +/- 1.7e-04 | 4.46e-03 +/- 1.4e-04 |
test_objective_jac_dshape_current | +4.88 +/- 7.31 | +1.90e-03 +/- 2.84e-03 | 4.08e-02 +/- 2.4e-03 | 3.89e-02 +/- 1.5e-03 |
test_objective_jac_atf | -1.03 +/- 3.34 | -1.98e-02 +/- 6.42e-02 | 1.90e+00 +/- 5.5e-02 | 1.92e+00 +/- 3.3e-02 |
test_perturb_1 | -7.18 +/- 7.10 | -9.60e-01 +/- 9.49e-01 | 1.24e+01 +/- 1.6e-01 | 1.34e+01 +/- 9.4e-01 |
test_perturb_2 | -6.97 +/- 2.94 | -1.30e+00 +/- 5.47e-01 | 1.73e+01 +/- 3.6e-01 | 1.86e+01 +/- 4.1e-01 |
test_proximal_jac_atf | +0.07 +/- 1.40 | +5.68e-03 +/- 1.14e-01 | 8.11e+00 +/- 8.8e-02 | 8.11e+00 +/- 7.2e-02 |
-test_proximal_freeb_compute | +5.08 +/- 1.63 | +8.93e-03 +/- 2.86e-03 | 1.85e-01 +/- 2.1e-03 | 1.76e-01 +/- 2.0e-03 |
test_proximal_freeb_jac | +2.16 +/- 1.33 | +1.57e-01 +/- 9.65e-02 | 7.44e+00 +/- 8.9e-02 | 7.28e+00 +/- 3.8e-02 |
+test_solve_fixed_iter | -71.70 +/- 17.72 | -1.24e+01 +/- 3.07e+00 | 4.91e+00 +/- 2.1e+00 | 1.73e+01 +/- 2.2e+00 |
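The column definitions are not stated by the benchmark bot. The reported numbers appear consistent with `dt(s) = t_new - t_old` and `dt(%) = 100 * dt / t_old`, and the short check below (a sketch under that assumption) recomputes a few rows from the rounded timings in the table above. It reproduces the reported columns to within the rounding of the displayed values.

```python
# Rows copied from the benchmark table: (name, t_new, t_old), in seconds
rows = [
    ("test_objective_compute_dshape_current", 3.46e-03, 1.27e-03),
    ("test_objective_compute_atf", 1.04e-02, 4.46e-03),
    ("test_solve_fixed_iter", 4.91e00, 1.73e01),
]

for name, t_new, t_old in rows:
    dt = t_new - t_old       # assumed definition of dt(s)
    pct = 100 * dt / t_old   # assumed definition of dt(%), relative to the old timing
    print(f"{name}: dt = {dt:+.2e} s ({pct:+.1f}%)")
```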
@YigitElma @dpanici @kianorr @rahulgaur104 we should profile this change memory-wise
Codecov Report
Attention: Patch coverage is 93.53234% with 13 lines in your changes missing coverage. Please review.
Project coverage is 95.42%. Comparing base (1c076fc) to head (a719f8a). Report is 1682 commits behind head on master.
| Files with missing lines | Patch % | Lines |
|---|---|---|
| desc/objectives/objective_funs.py | 91.66% | 11 Missing :warning: |
| desc/objectives/linear_objectives.py | 71.42% | 2 Missing :warning: |
@@ Coverage Diff @@
## master #1043 +/- ##
==========================================
- Coverage 95.42% 95.42% -0.01%
==========================================
Files 87 87
Lines 22341 22423 +82
==========================================
+ Hits 21320 21398 +78
- Misses 1021 1025 +4
| Files with missing lines | Coverage Δ | |
|---|---|---|
| desc/io/optimizable_io.py | 86.30% <100.00%> (+0.08%) | :arrow_up: |
| desc/objectives/utils.py | 100.00% <100.00%> (ø) | |
| desc/optimize/_constraint_wrappers.py | 95.82% <100.00%> (+0.08%) | :arrow_up: |
| desc/objectives/linear_objectives.py | 97.04% <71.42%> (+0.01%) | :arrow_up: |
| desc/objectives/objective_funs.py | 93.97% <91.66%> (+0.02%) | :arrow_up: |
| benchmark_name | dt(%) | dt(s) | t_new(s) | t_old(s) |
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
test_build_transform_fft_lowres | -1.56 +/- 9.77 | -8.12e-03 +/- 5.09e-02 | 5.13e-01 +/- 2.9e-02 | 5.21e-01 +/- 4.2e-02 |
test_build_transform_fft_midres | -0.57 +/- 2.07 | -3.38e-03 +/- 1.22e-02 | 5.89e-01 +/- 9.9e-03 | 5.92e-01 +/- 7.2e-03 |
test_build_transform_fft_highres | +3.10 +/- 3.37 | +3.05e-02 +/- 3.32e-02 | 1.01e+00 +/- 2.1e-02 | 9.84e-01 +/- 2.6e-02 |
test_equilibrium_init_lowres | +3.36 +/- 8.44 | +1.32e-01 +/- 3.32e-01 | 4.07e+00 +/- 1.7e-01 | 3.94e+00 +/- 2.9e-01 |
test_equilibrium_init_medres | -1.05 +/- 4.88 | -4.42e-02 +/- 2.04e-01 | 4.15e+00 +/- 1.0e-01 | 4.19e+00 +/- 1.8e-01 |
test_equilibrium_init_highres | +1.22 +/- 3.98 | +6.80e-02 +/- 2.22e-01 | 5.65e+00 +/- 1.4e-01 | 5.59e+00 +/- 1.7e-01 |
test_objective_compile_dshape_current | +2.97 +/- 4.68 | +1.17e-01 +/- 1.85e-01 | 4.06e+00 +/- 1.7e-01 | 3.95e+00 +/- 6.1e-02 |
test_objective_compile_atf | -5.39 +/- 3.01 | -4.55e-01 +/- 2.54e-01 | 7.98e+00 +/- 7.0e-02 | 8.44e+00 +/- 2.4e-01 |
-test_objective_compute_dshape_current | +177.24 +/- 4.92 | +2.22e-03 +/- 6.17e-05 | 3.48e-03 +/- 5.1e-05 | 1.25e-03 +/- 3.4e-05 |
-test_objective_compute_atf | +139.12 +/- 4.43 | +5.95e-03 +/- 1.89e-04 | 1.02e-02 +/- 1.1e-04 | 4.27e-03 +/- 1.5e-04 |
test_objective_jac_dshape_current | -0.65 +/- 8.29 | -2.51e-04 +/- 3.19e-03 | 3.82e-02 +/- 1.8e-03 | 3.85e-02 +/- 2.6e-03 |
test_objective_jac_atf | +1.18 +/- 4.31 | +2.19e-02 +/- 8.03e-02 | 1.88e+00 +/- 3.7e-02 | 1.86e+00 +/- 7.1e-02 |
test_perturb_1 | -9.89 +/- 5.29 | -1.38e+00 +/- 7.40e-01 | 1.26e+01 +/- 3.2e-01 | 1.40e+01 +/- 6.7e-01 |
test_perturb_2 | -3.33 +/- 2.48 | -6.19e-01 +/- 4.61e-01 | 1.80e+01 +/- 4.5e-01 | 1.86e+01 +/- 1.1e-01 |
test_proximal_jac_atf | +1.05 +/- 1.09 | +7.65e-02 +/- 7.96e-02 | 7.35e+00 +/- 6.6e-02 | 7.28e+00 +/- 4.4e-02 |
-test_proximal_freeb_compute | +3.08 +/- 0.75 | +5.44e-03 +/- 1.33e-03 | 1.82e-01 +/- 9.9e-04 | 1.76e-01 +/- 8.9e-04 |
test_proximal_freeb_jac | -0.08 +/- 0.95 | -5.65e-03 +/- 6.95e-02 | 7.35e+00 +/- 5.6e-02 | 7.35e+00 +/- 4.2e-02 |
+test_solve_fixed_iter | -60.47 +/- 16.46 | -1.08e+01 +/- 2.94e+00 | 7.06e+00 +/- 2.9e+00 | 1.78e+01 +/- 5.7e-01 |
I think the slowdown in the compute benchmarks is because JAX has to flatten/unflatten the `ObjectiveFunction` pytree on each call. It looks like a big slowdown, but in absolute terms it's only a few ms, so it's likely worth it to get a big speedup on the actual solve/optimization. (That said, I think a 60% faster solve is higher than what I'm seeing in practice; it's usually more like 10-20% faster, so it might be unique to the benchmark case.)
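A rough way to look at this per-call overhead in isolation, assuming a hypothetical nested list of 200 small arrays as a stand-in for the objective's pytree (not DESC's real structure). The closure version bakes the arrays in as constants. The argument version flattens the pytree and passes every leaf as a separate input on each call, so its per-call dispatch cost grows with the number of leaves. Absolute timings depend on hardware and JAX version, so no expected numbers are given.

```python
import timeit

import jax
import jax.numpy as jnp

# Hypothetical nested pytree with many small array leaves (10 groups x 20 leaves)
params = [[jnp.full(8, float(i + j)) for j in range(20)] for i in range(10)]


def compute(params, x):
    # Sum of all leaves, scaled by x
    return sum(jnp.sum(leaf) for group in params for leaf in group) * x


closed = jax.jit(lambda x: compute(params, x))  # params captured as constants
as_arg = jax.jit(compute)  # params flattened and passed on every call

x = jnp.array(1.0)
assert jnp.allclose(closed(x), as_arg(params, x))  # also warms up both compilations

n = 200
for name, f in [
    ("closure", lambda: closed(x).block_until_ready()),
    ("pytree arg", lambda: as_arg(params, x).block_until_ready()),
]:
    per_call = timeit.timeit(f, number=n) / n
    print(f"{name}: {per_call * 1e6:.1f} us per call")
```

The trade-off matches the comment above: the argument version pays a small fixed cost on every call, while avoiding recompilation and large embedded constants when the objective's attributes are big.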