Take `p_newton` out of inner while loop
Resolves #1078
Some performance improvements for the QR decomposition used in optimization, which was first introduced in #1050.
- Take the `p_newton` calculation out of the inner while loop, since it is basically calculating the same QR over and over again.
- ~~Use a proper QR update procedure in `trust_region_step_exact_qr`. That is, we already know the QR decomposition $J = QR$; if we stack a diagonal matrix $\alpha I$ onto $J$, then instead of recomputing the whole QR decomposition there is a more clever way of updating it. There are methods for updating a QR factorization when you add rows. Suppose we have~~
$$ QR = J $$
what we want is
$$ \tilde{Q} \tilde{R} = \begin{pmatrix} J \\ \alpha I \end{pmatrix} $$
The QR update procedure can be implemented in a later PR with Householder matrices, but for now it seems inefficient to implement in JAX: the standard QR is computed by the Fortran LAPACK routines underneath SciPy and JAX, so a custom QR-like update of our own will be slow by comparison.
I have used Givens rotations for zeroing the elements, but maybe Householder reflections are better? Maybe try implementing that.
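For reference, here is a minimal NumPy sketch of that row-append update (my illustration, not the code in this PR or DESC's API): given $R$ from $J = QR$, Givens rotations eliminate the stacked $\alpha I$ block, and the result matches a full refactorization up to row signs.

```python
# Minimal NumPy sketch (illustration only, not this PR's implementation) of
# updating the triangular factor when rows alpha*I are appended below J,
# by rotating the stacked [R; alpha*I] block instead of refactorizing J.
import numpy as np


def qr_append_diag(R, alpha):
    """Return R_tilde with R_tilde^T R_tilde = R^T R + alpha^2 I,
    i.e. the triangular factor of [[J], [alpha * I]] given J = Q R."""
    n = R.shape[0]
    R = R.copy()
    D = alpha * np.eye(n)
    for i in range(n):            # row of the diagonal block to eliminate
        for j in range(i, n):     # fill-in spreads to columns i..n-1
            a, b = R[j, j], D[i, j]
            r = np.hypot(a, b)
            if r == 0.0:
                continue
            c, s = a / r, b / r
            Rj, Di = R[j].copy(), D[i].copy()
            R[j] = c * Rj + s * Di    # rotate the two rows so D[i, j] -> 0
            D[i] = -s * Rj + c * Di
    return R


rng = np.random.default_rng(0)
J, alpha = rng.normal(size=(20, 5)), 0.3
_, R = np.linalg.qr(J)
R_tilde = qr_append_diag(R, alpha)
_, R_ref = np.linalg.qr(np.vstack([J, alpha * np.eye(5)]))
# Rows can differ in sign between methods, so compare absolute values.
assert np.allclose(np.abs(R_tilde), np.abs(R_ref))
```

As noted above, a Python double loop like this would be slow under `jit`, which is why the update is left for a later PR.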
| benchmark_name | dt(%) | dt(s) | t_new(s) | t_old(s) |
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
test_build_transform_fft_lowres | +1.92 +/- 5.56 | +1.05e-02 +/- 3.03e-02 | 5.56e-01 +/- 2.5e-02 | 5.46e-01 +/- 1.7e-02 |
test_build_transform_fft_midres | +0.87 +/- 4.25 | +5.54e-03 +/- 2.72e-02 | 6.44e-01 +/- 1.9e-02 | 6.39e-01 +/- 2.0e-02 |
test_build_transform_fft_highres | +1.76 +/- 2.48 | +1.80e-02 +/- 2.54e-02 | 1.04e+00 +/- 1.9e-02 | 1.02e+00 +/- 1.7e-02 |
test_equilibrium_init_lowres | -0.45 +/- 5.89 | -1.88e-02 +/- 2.43e-01 | 4.11e+00 +/- 1.2e-01 | 4.13e+00 +/- 2.1e-01 |
test_equilibrium_init_medres | -0.85 +/- 4.26 | -3.94e-02 +/- 1.97e-01 | 4.58e+00 +/- 1.4e-01 | 4.62e+00 +/- 1.4e-01 |
test_equilibrium_init_highres | +0.16 +/- 4.19 | +9.31e-03 +/- 2.49e-01 | 5.96e+00 +/- 1.5e-01 | 5.96e+00 +/- 2.0e-01 |
test_objective_compile_dshape_current | +1.86 +/- 2.05 | +7.26e-02 +/- 8.01e-02 | 3.98e+00 +/- 7.5e-02 | 3.91e+00 +/- 2.9e-02 |
test_objective_compile_atf | +2.32 +/- 1.34 | +1.97e-01 +/- 1.13e-01 | 8.66e+00 +/- 9.8e-02 | 8.47e+00 +/- 5.6e-02 |
test_objective_compute_dshape_current | +1.54 +/- 5.01 | +1.94e-05 +/- 6.29e-05 | 1.27e-03 +/- 5.1e-05 | 1.25e-03 +/- 3.7e-05 |
test_objective_compute_atf | +6.33 +/- 6.47 | +2.73e-04 +/- 2.79e-04 | 4.58e-03 +/- 2.3e-04 | 4.31e-03 +/- 1.5e-04 |
test_objective_jac_dshape_current | +0.16 +/- 7.26 | +6.21e-05 +/- 2.86e-03 | 3.94e-02 +/- 1.5e-03 | 3.93e-02 +/- 2.4e-03 |
test_objective_jac_atf | +2.61 +/- 3.05 | +4.93e-02 +/- 5.77e-02 | 1.94e+00 +/- 3.8e-02 | 1.89e+00 +/- 4.3e-02 |
test_perturb_1 | +3.80 +/- 1.90 | +5.35e-01 +/- 2.68e-01 | 1.46e+01 +/- 1.7e-01 | 1.41e+01 +/- 2.1e-01 |
test_perturb_2 | +4.14 +/- 1.76 | +7.98e-01 +/- 3.39e-01 | 2.01e+01 +/- 2.7e-01 | 1.93e+01 +/- 2.1e-01 |
test_proximal_jac_atf | +0.76 +/- 1.02 | +6.20e-02 +/- 8.29e-02 | 8.18e+00 +/- 6.6e-02 | 8.11e+00 +/- 5.0e-02 |
test_proximal_freeb_compute | +1.69 +/- 1.18 | +3.04e-03 +/- 2.13e-03 | 1.83e-01 +/- 1.8e-03 | 1.80e-01 +/- 1.2e-03 |
test_proximal_freeb_jac | +0.01 +/- 1.64 | +9.30e-04 +/- 1.22e-01 | 7.44e+00 +/- 6.8e-02 | 7.44e+00 +/- 1.0e-01 |
test_solve_fixed_iter | -1.98 +/- 16.14 | -3.64e-01 +/- 2.97e+00 | 1.81e+01 +/- 2.1e+00 | 1.84e+01 +/- 2.1e+00 |

| benchmark_name | dt(%) | dt(s) | t_new(s) | t_old(s) |
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
test_build_transform_fft_lowres | -2.08 +/- 9.64 | -1.10e-02 +/- 5.10e-02 | 5.18e-01 +/- 4.1e-02 | 5.29e-01 +/- 3.1e-02 |
test_build_transform_fft_midres | -1.19 +/- 5.78 | -7.22e-03 +/- 3.52e-02 | 6.02e-01 +/- 2.9e-02 | 6.09e-01 +/- 2.0e-02 |
test_build_transform_fft_highres | -2.08 +/- 3.18 | -2.10e-02 +/- 3.21e-02 | 9.89e-01 +/- 1.3e-02 | 1.01e+00 +/- 3.0e-02 |
test_equilibrium_init_lowres | -2.47 +/- 5.04 | -9.46e-02 +/- 1.93e-01 | 3.73e+00 +/- 1.3e-01 | 3.82e+00 +/- 1.4e-01 |
test_equilibrium_init_medres | -1.82 +/- 4.48 | -7.74e-02 +/- 1.91e-01 | 4.18e+00 +/- 9.5e-02 | 4.26e+00 +/- 1.7e-01 |
test_equilibrium_init_highres | -1.33 +/- 2.01 | -7.53e-02 +/- 1.14e-01 | 5.59e+00 +/- 5.3e-02 | 5.66e+00 +/- 1.0e-01 |
test_objective_compile_dshape_current | -1.55 +/- 3.14 | -6.12e-02 +/- 1.24e-01 | 3.89e+00 +/- 2.5e-02 | 3.95e+00 +/- 1.2e-01 |
test_objective_compile_atf | -1.07 +/- 2.79 | -9.00e-02 +/- 2.36e-01 | 8.35e+00 +/- 1.0e-01 | 8.44e+00 +/- 2.1e-01 |
test_objective_compute_dshape_current | -1.31 +/- 3.37 | -1.65e-05 +/- 4.25e-05 | 1.25e-03 +/- 2.5e-05 | 1.26e-03 +/- 3.5e-05 |
test_objective_compute_atf | -0.65 +/- 4.70 | -2.77e-05 +/- 1.99e-04 | 4.21e-03 +/- 1.1e-04 | 4.24e-03 +/- 1.6e-04 |
test_objective_jac_dshape_current | -0.23 +/- 7.42 | -8.59e-05 +/- 2.71e-03 | 3.65e-02 +/- 1.6e-03 | 3.66e-02 +/- 2.2e-03 |
test_objective_jac_atf | -0.22 +/- 2.64 | -4.22e-03 +/- 4.95e-02 | 1.87e+00 +/- 2.8e-02 | 1.88e+00 +/- 4.1e-02 |
test_perturb_1 | -1.39 +/- 1.07 | -1.96e-01 +/- 1.51e-01 | 1.39e+01 +/- 8.6e-02 | 1.41e+01 +/- 1.2e-01 |
test_perturb_2 | -1.95 +/- 1.07 | -3.73e-01 +/- 2.05e-01 | 1.88e+01 +/- 1.3e-01 | 1.92e+01 +/- 1.6e-01 |
test_proximal_jac_atf | +0.76 +/- 0.84 | +5.57e-02 +/- 6.13e-02 | 7.38e+00 +/- 4.8e-02 | 7.33e+00 +/- 3.9e-02 |
test_proximal_freeb_compute | -0.43 +/- 1.32 | -7.79e-04 +/- 2.40e-03 | 1.81e-01 +/- 1.4e-03 | 1.82e-01 +/- 1.9e-03 |
test_proximal_freeb_jac | -0.32 +/- 1.32 | -2.36e-02 +/- 9.73e-02 | 7.36e+00 +/- 8.7e-02 | 7.38e+00 +/- 4.3e-02 |
-test_solve_fixed_iter | +5639.94 +/- 7.00 | +1.04e+03 +/- 1.29e+00 | 1.06e+03 +/- 1.1e+00 | 1.85e+01 +/- 6.4e-01 |
😶🫣🤔
I think the method is correct. The new Q and R matrices are almost the same as the ones found by `q, r = jax.scipy.linalg.qr(A_t)`. The only difference is that some rows of our factors differ in sign, so instead of `Q - Q_our == 0` we have `|Q| - |Q_our| == 0`. The double for loop (even with JAX) is very slow (this was actually known, but I thought it only applied to GPU). I will try to implement Householder reflections, since they don't require nested for loops, just a single loop over columns.
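As a side note, if an exact comparison (rather than comparing absolute values) were ever wanted, the sign ambiguity can be removed by forcing a nonnegative diagonal on R. A small sketch of that normalization (illustration only, not part of this PR):

```python
# Canonicalize a QR factorization so factors from different methods agree:
# flip matching columns of Q and rows of R so that diag(R) >= 0.
import numpy as np


def fix_qr_signs(Q, R):
    s = np.sign(np.diag(R)).astype(float)
    s[s == 0] = 1.0
    return Q * s, s[:, None] * R   # Q @ diag(s), diag(s) @ R
```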
Just take `p_newton` out for this PR. Maybe try Householder later.
Codecov Report
Attention: Patch coverage is 87.50000% with 2 lines in your changes missing coverage. Please review.
Project coverage is 95.42%. Comparing base (`13108f6`) to head (`7f3858d`). Report is 1705 commits behind head on master.
| Files with missing lines | Patch % | Lines |
|---|---|---|
| desc/optimize/aug_lagrangian_ls.py | 75.00% | 2 Missing :warning: |
Additional details and impacted files
@@ Coverage Diff @@
## master #1165 +/- ##
==========================================
- Coverage 95.43% 95.42% -0.02%
==========================================
Files 87 87
Lines 22313 22321 +8
==========================================
+ Hits 21294 21299 +5
- Misses 1019 1022 +3
| Files with missing lines | Coverage Δ | |
|---|---|---|
| desc/optimize/least_squares.py | 99.33% <100.00%> (+0.03%) | :arrow_up: |
| desc/optimize/tr_subproblems.py | 99.44% <ø> (-0.02%) | :arrow_down: |
| desc/optimize/aug_lagrangian_ls.py | 95.67% <75.00%> (-0.85%) | :arrow_down: |
Can you add a test?
Technically I didn't change any logic. The code coverage is lower because previously the QR part only lived in `trust_region_step_exact_qr` (in `tr_subproblems.py`), where it was exercised once (or not; I don't think we have a test for that). Now the same QR part appears in two files (`aug_lagrangian_ls.py` and `least_squares.py`). I guess the only way to test these is to construct an optimization problem with a tall/wide Jacobian for the augmented Lagrangian and least-squares optimizers.
What is the speed up with this change?
It is hard to quantify. This basically saves us from taking the QR of the same matrix multiple times in the inner while loop. For most problems the inner while loop only runs once, so there is no speed-up there, but for problems where it takes several inner iterations there is some speed-up; see the sketch below.
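For illustration, a generic sketch of that restructuring (an assumed Gauss-Newton/trust-region structure with a hypothetical `trust_region_step`, not DESC's actual optimizer code):

```python
# Generic sketch (assumed structure, not DESC's code): factor J once per
# outer iteration and reuse Q, R and p_newton while the trust-region loop
# adjusts the step, instead of refactorizing on every inner iteration.
import numpy as np


def newton_step(J, f):
    Q, R = np.linalg.qr(J)                   # one QR per outer iteration
    p_newton = -np.linalg.solve(R, Q.T @ f)  # Gauss-Newton step
    return Q, R, p_newton


# inside the outer optimization loop (conceptually):
#   Q, R, p_newton = newton_step(J, f)       # hoisted out of the inner loop
#   while not step_accepted:
#       step = trust_region_step(p_newton, Q, R, trust_radius)  # reuses QR
#       ...  # evaluate step, shrink trust_radius if needed, no new QR
```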
@ddudt Ok, I think the total time saved can be estimated as (total nfev - total iterations) * (time a single QR takes), and this is 0 for our benchmark case. But usually when I run optimizations there are 10-15 more function evaluations than total iterations (which is equal to the total number of Jacobian evaluations).
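As a purely illustrative back-of-envelope (all numbers below are hypothetical, not measured):

```python
# Hypothetical numbers, for illustration only (not measured values).
nfev, niters = 40, 25   # function evaluations vs. accepted iterations
t_qr = 0.8              # assumed seconds for one QR of the Jacobian
print(f"estimated time saved: {(nfev - niters) * t_qr:.1f} s")  # ~12.0 s
```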