Take `p_newton` out of inner while loop

Open YigitElma opened this issue 1 year ago • 6 comments

Resolves #1078

Some performance improvements for the QR decomposition used in optimization, which was first introduced in #1050.

  • Take the p_newton calculation out of the inner while loop, since it computes the same QR factorization on every pass
  • ~Use a proper QR update procedure for the falsefun in trust_region_step_exact_qr. That is, we already know the QR decomposition J=QR; if we stack a diagonal matrix αI below J, then instead of computing the whole QR decomposition again, we can update the existing one. There are methods for updating a QR factorization when rows are added. Suppose we have~

$$ QR = J $$

what we want is

$$ \tilde{Q} \tilde{R} = \begin{pmatrix} J \\ \alpha I \end{pmatrix} $$

The QR update procedure can be implemented in a later PR with Householder matrices, but for now it seems inefficient to implement in JAX: both SciPy and JAX compute QR with the Fortran package LAPACK, so our custom QR-like routine would be slower.
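To illustrate why an update can avoid refactoring the full stacked matrix: since $J = QR$, the stacked matrix factors as $\mathrm{blockdiag}(Q, I)$ times $\begin{pmatrix} R \\ \alpha I \end{pmatrix}$, so only the small $2n \times n$ matrix needs to be triangularized. Below is a minimal NumPy sketch of that idea, assuming a tall, full-rank J and $\alpha \neq 0$. It is not the Givens or Householder update discussed in this PR, and the helper name `stacked_r_from_r` is hypothetical, not part of DESC.

```python
import numpy as np


def stacked_r_from_r(R, alpha):
    """Return an R factor of [J; alpha*I] given only R from J = QR.

    Hypothetical helper for illustration only, not DESC code.
    """
    n = R.shape[1]
    # [J; alpha*I] = blockdiag(Q, I) @ [R; alpha*I], so an R factor of the
    # small (2n x n) matrix is also an R factor of the full stacked matrix.
    small = np.vstack([R, alpha * np.eye(n)])
    return np.linalg.qr(small, mode="r")


rng = np.random.default_rng(0)
J = rng.standard_normal((50, 5))
alpha = 0.3

Q, R = np.linalg.qr(J)  # computed once
R_update = stacked_r_from_r(R, alpha)
R_direct = np.linalg.qr(np.vstack([J, alpha * np.eye(5)]), mode="r")

# R factors are unique only up to the sign of each row, so compare R^T R,
# which must equal J^T J + alpha^2 I in both cases.
gram = J.T @ J + alpha**2 * np.eye(5)
ok_update = np.allclose(R_update.T @ R_update, gram)
ok_direct = np.allclose(R_direct.T @ R_direct, gram)
```

The small factorization costs O(n³) instead of O(mn²) for the full stacked matrix, which matters when J is much taller than it is wide. Whether this beats a single LAPACK call in practice would have to be measured.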

YigitElma avatar Aug 06 '24 04:08 YigitElma

I have used Givens rotations for zeroing the elements but maybe Householder reflections are better? Maybe try implementing that.

YigitElma avatar Aug 06 '24 16:08 YigitElma

|             benchmark_name             |         dt(%)          |         dt(s)          |        t_new(s)        |        t_old(s)        |
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
| test_build_transform_fft_lowres         |     +1.92 +/- 5.56     | +1.05e-02 +/- 3.03e-02 |  5.56e-01 +/- 2.5e-02  |  5.46e-01 +/- 1.7e-02  |
| test_build_transform_fft_midres         |     +0.87 +/- 4.25     | +5.54e-03 +/- 2.72e-02 |  6.44e-01 +/- 1.9e-02  |  6.39e-01 +/- 2.0e-02  |
| test_build_transform_fft_highres        |     +1.76 +/- 2.48     | +1.80e-02 +/- 2.54e-02 |  1.04e+00 +/- 1.9e-02  |  1.02e+00 +/- 1.7e-02  |
| test_equilibrium_init_lowres            |     -0.45 +/- 5.89     | -1.88e-02 +/- 2.43e-01 |  4.11e+00 +/- 1.2e-01  |  4.13e+00 +/- 2.1e-01  |
| test_equilibrium_init_medres            |     -0.85 +/- 4.26     | -3.94e-02 +/- 1.97e-01 |  4.58e+00 +/- 1.4e-01  |  4.62e+00 +/- 1.4e-01  |
| test_equilibrium_init_highres           |     +0.16 +/- 4.19     | +9.31e-03 +/- 2.49e-01 |  5.96e+00 +/- 1.5e-01  |  5.96e+00 +/- 2.0e-01  |
| test_objective_compile_dshape_current   |     +1.86 +/- 2.05     | +7.26e-02 +/- 8.01e-02 |  3.98e+00 +/- 7.5e-02  |  3.91e+00 +/- 2.9e-02  |
| test_objective_compile_atf              |     +2.32 +/- 1.34     | +1.97e-01 +/- 1.13e-01 |  8.66e+00 +/- 9.8e-02  |  8.47e+00 +/- 5.6e-02  |
| test_objective_compute_dshape_current   |     +1.54 +/- 5.01     | +1.94e-05 +/- 6.29e-05 |  1.27e-03 +/- 5.1e-05  |  1.25e-03 +/- 3.7e-05  |
| test_objective_compute_atf              |     +6.33 +/- 6.47     | +2.73e-04 +/- 2.79e-04 |  4.58e-03 +/- 2.3e-04  |  4.31e-03 +/- 1.5e-04  |
| test_objective_jac_dshape_current       |     +0.16 +/- 7.26     | +6.21e-05 +/- 2.86e-03 |  3.94e-02 +/- 1.5e-03  |  3.93e-02 +/- 2.4e-03  |
| test_objective_jac_atf                  |     +2.61 +/- 3.05     | +4.93e-02 +/- 5.77e-02 |  1.94e+00 +/- 3.8e-02  |  1.89e+00 +/- 4.3e-02  |
| test_perturb_1                          |     +3.80 +/- 1.90     | +5.35e-01 +/- 2.68e-01 |  1.46e+01 +/- 1.7e-01  |  1.41e+01 +/- 2.1e-01  |
| test_perturb_2                          |     +4.14 +/- 1.76     | +7.98e-01 +/- 3.39e-01 |  2.01e+01 +/- 2.7e-01  |  1.93e+01 +/- 2.1e-01  |
| test_proximal_jac_atf                   |     +0.76 +/- 1.02     | +6.20e-02 +/- 8.29e-02 |  8.18e+00 +/- 6.6e-02  |  8.11e+00 +/- 5.0e-02  |
| test_proximal_freeb_compute             |     +1.69 +/- 1.18     | +3.04e-03 +/- 2.13e-03 |  1.83e-01 +/- 1.8e-03  |  1.80e-01 +/- 1.2e-03  |
| test_proximal_freeb_jac                 |     +0.01 +/- 1.64     | +9.30e-04 +/- 1.22e-01 |  7.44e+00 +/- 6.8e-02  |  7.44e+00 +/- 1.0e-01  |
| test_solve_fixed_iter                   |     -1.98 +/- 16.14    | -3.64e-01 +/- 2.97e+00 |  1.81e+01 +/- 2.1e+00  |  1.84e+01 +/- 2.1e+00  |

github-actions[bot] avatar Aug 06 '24 23:08 github-actions[bot]

|             benchmark_name             |         dt(%)          |         dt(s)          |        t_new(s)        |        t_old(s)        |
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
| test_build_transform_fft_lowres         |     -2.08 +/- 9.64     | -1.10e-02 +/- 5.10e-02 |  5.18e-01 +/- 4.1e-02  |  5.29e-01 +/- 3.1e-02  |
| test_build_transform_fft_midres         |     -1.19 +/- 5.78     | -7.22e-03 +/- 3.52e-02 |  6.02e-01 +/- 2.9e-02  |  6.09e-01 +/- 2.0e-02  |
| test_build_transform_fft_highres        |     -2.08 +/- 3.18     | -2.10e-02 +/- 3.21e-02 |  9.89e-01 +/- 1.3e-02  |  1.01e+00 +/- 3.0e-02  |
| test_equilibrium_init_lowres            |     -2.47 +/- 5.04     | -9.46e-02 +/- 1.93e-01 |  3.73e+00 +/- 1.3e-01  |  3.82e+00 +/- 1.4e-01  |
| test_equilibrium_init_medres            |     -1.82 +/- 4.48     | -7.74e-02 +/- 1.91e-01 |  4.18e+00 +/- 9.5e-02  |  4.26e+00 +/- 1.7e-01  |
| test_equilibrium_init_highres           |     -1.33 +/- 2.01     | -7.53e-02 +/- 1.14e-01 |  5.59e+00 +/- 5.3e-02  |  5.66e+00 +/- 1.0e-01  |
| test_objective_compile_dshape_current   |     -1.55 +/- 3.14     | -6.12e-02 +/- 1.24e-01 |  3.89e+00 +/- 2.5e-02  |  3.95e+00 +/- 1.2e-01  |
| test_objective_compile_atf              |     -1.07 +/- 2.79     | -9.00e-02 +/- 2.36e-01 |  8.35e+00 +/- 1.0e-01  |  8.44e+00 +/- 2.1e-01  |
| test_objective_compute_dshape_current   |     -1.31 +/- 3.37     | -1.65e-05 +/- 4.25e-05 |  1.25e-03 +/- 2.5e-05  |  1.26e-03 +/- 3.5e-05  |
| test_objective_compute_atf              |     -0.65 +/- 4.70     | -2.77e-05 +/- 1.99e-04 |  4.21e-03 +/- 1.1e-04  |  4.24e-03 +/- 1.6e-04  |
| test_objective_jac_dshape_current       |     -0.23 +/- 7.42     | -8.59e-05 +/- 2.71e-03 |  3.65e-02 +/- 1.6e-03  |  3.66e-02 +/- 2.2e-03  |
| test_objective_jac_atf                  |     -0.22 +/- 2.64     | -4.22e-03 +/- 4.95e-02 |  1.87e+00 +/- 2.8e-02  |  1.88e+00 +/- 4.1e-02  |
| test_perturb_1                          |     -1.39 +/- 1.07     | -1.96e-01 +/- 1.51e-01 |  1.39e+01 +/- 8.6e-02  |  1.41e+01 +/- 1.2e-01  |
| test_perturb_2                          |     -1.95 +/- 1.07     | -3.73e-01 +/- 2.05e-01 |  1.88e+01 +/- 1.3e-01  |  1.92e+01 +/- 1.6e-01  |
| test_proximal_jac_atf                   |     +0.76 +/- 0.84     | +5.57e-02 +/- 6.13e-02 |  7.38e+00 +/- 4.8e-02  |  7.33e+00 +/- 3.9e-02  |
| test_proximal_freeb_compute             |     -0.43 +/- 1.32     | -7.79e-04 +/- 2.40e-03 |  1.81e-01 +/- 1.4e-03  |  1.82e-01 +/- 1.9e-03  |
| test_proximal_freeb_jac                 |     -0.32 +/- 1.32     | -2.36e-02 +/- 9.73e-02 |  7.36e+00 +/- 8.7e-02  |  7.38e+00 +/- 4.3e-02  |
| test_solve_fixed_iter                   |   +5639.94 +/- 7.00    | +1.04e+03 +/- 1.29e+00 |  1.06e+03 +/- 1.1e+00  |  1.85e+01 +/- 6.4e-01  |

😶🫣🤔

YigitElma avatar Aug 07 '24 01:08 YigitElma

I think the method is correct. The new Q and R matrices are almost the same as the ones found by `q, r = jax.scipy.linalg.qr(A_t)`. The only difference is that some rows from our method differ in sign, so instead of Q-Q_our==0 we have |Q|-|Q_our|==0. The double for loop (even with JAX) is very slow (this was actually known, but I thought it only applied to GPU). I will try to implement Householder reflections, since they don't require nested for loops, only a single loop over columns.
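The sign difference is expected: a QR factorization is unique only up to flipping the sign of a column of Q together with the matching row of R. One way to compare two implementations directly is to normalize both so that the diagonal of R is non-negative. A minimal NumPy sketch follows, assuming a full-rank matrix; `normalize_qr_signs` is a hypothetical helper, and the flipped factorization merely mimics what a hand-written routine might return.

```python
import numpy as np


def normalize_qr_signs(Q, R):
    """Flip signs so that diag(R) >= 0; the product Q @ R is unchanged."""
    s = np.sign(np.diag(R))
    s[s == 0] = 1.0
    # Q * s scales column j of Q by s[j]; s[:, None] * R scales row i of R.
    return Q * s, s[:, None] * R


rng = np.random.default_rng(1)
A = rng.standard_normal((8, 4))
Q1, R1 = np.linalg.qr(A)

# Build a second, equally valid factorization by flipping some signs.
flip = np.array([1.0, -1.0, 1.0, -1.0])
Q2, R2 = Q1 * flip, flip[:, None] * R1

Q1n, R1n = normalize_qr_signs(Q1, R1)
Q2n, R2n = normalize_qr_signs(Q2, R2)

same_after_normalizing = np.allclose(Q1n, Q2n) and np.allclose(R1n, R2n)
```

After normalization the two factorizations can be checked with a plain `allclose` instead of comparing absolute values.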

YigitElma avatar Aug 08 '24 00:08 YigitElma

Just take p_newton out for this PR. Maybe try Householder later

YigitElma avatar Aug 15 '24 05:08 YigitElma

Codecov Report

Attention: Patch coverage is 87.50000% with 2 lines in your changes missing coverage. Please review.

Project coverage is 95.42%. Comparing base (13108f6) to head (7f3858d). Report is 1705 commits behind head on master.

Files with missing lines Patch % Lines
desc/optimize/aug_lagrangian_ls.py 75.00% 2 Missing :warning:
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1165      +/-   ##
==========================================
- Coverage   95.43%   95.42%   -0.02%     
==========================================
  Files          87       87              
  Lines       22313    22321       +8     
==========================================
+ Hits        21294    21299       +5     
- Misses       1019     1022       +3     
Files with missing lines Coverage Δ
desc/optimize/least_squares.py 99.33% <100.00%> (+0.03%) :arrow_up:
desc/optimize/tr_subproblems.py 99.44% <ø> (-0.02%) :arrow_down:
desc/optimize/aug_lagrangian_ls.py 95.67% <75.00%> (-0.85%) :arrow_down:

... and 3 files with indirect coverage changes

codecov[bot] avatar Aug 15 '24 18:08 codecov[bot]

Can you add a test?

unalmis avatar Aug 20 '24 15:08 unalmis

> Can you add a test?

Technically I didn't change any logic. The code coverage is lower because previously the QR part was only in trust_region_step_exact_subproblem, where it was tested once (or not; I guess we don't have a test for that). Now the same QR part appears in 2 files (aug_lagrangian_ls.py and least_squares.py). I guess the only way to test these is to construct an optimization problem with a tall/wide Jacobian for the augmented Lagrangian and least squares optimizers.

YigitElma avatar Aug 20 '24 17:08 YigitElma

> What is the speed up with this change?

It is hard to quantify. This change saves us from taking the QR of the same matrix multiple times in the inner while loop. For most problems the inner while loop runs only once, so there is no speedup there, but for problems that need multiple inner iterations there is some speedup.
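A toy sketch of the hoisting follows, in NumPy rather than JAX. The acceptance rule (shrink the radius until the cost decreases) and the scaled Newton step are simplifications for illustration, not DESC's actual trust region logic, and all names here are hypothetical.

```python
import numpy as np

counter = {"qr": 0}  # counts QR factorizations


def newton_step(J, f):
    """Gauss-Newton step minimizing ||J p + f|| via QR (full-rank J)."""
    counter["qr"] += 1
    Q, R = np.linalg.qr(J)
    return -np.linalg.solve(R, Q.T @ f)


def inner_loop(J, f, cost, radius, hoist):
    """Shrink the radius until the step reduces the cost (toy rule)."""
    p_newton = newton_step(J, f) if hoist else None
    cost0 = cost(np.zeros(J.shape[1]))
    while True:
        if not hoist:
            # J and f do not change inside this loop, so this is redundant.
            p_newton = newton_step(J, f)
        norm = np.linalg.norm(p_newton)
        p = p_newton if norm <= radius else p_newton * (radius / norm)
        if cost(p) < cost0:
            return p
        radius *= 0.25


# 1D residual arctan(x) at x0 = 3, where the full Newton step overshoots.
x0 = 3.0
J = np.array([[1.0 / (1.0 + x0**2)]])
f = np.array([np.arctan(x0)])
cost = lambda p: 0.5 * np.arctan(x0 + p[0]) ** 2

counter["qr"] = 0
p_old = inner_loop(J, f, cost, radius=10.0, hoist=False)
calls_old = counter["qr"]

counter["qr"] = 0
p_new = inner_loop(J, f, cost, radius=10.0, hoist=True)
calls_new = counter["qr"]
```

Both variants return the same step; the hoisted one performs a single QR regardless of how many times the radius is reduced.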

YigitElma avatar Aug 20 '24 18:08 YigitElma

@ddudt OK, I think the total time saved can be estimated as (total nfev - total iterations) * (time a single QR takes), which is 0 for our benchmark case. But usually when I run optimizations, there are 10 to 15 more function evaluations than total iterations (which is equivalent to the total number of Jacobian iterations).
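As a worked example of that estimate, here is a small sketch. The function name and the numbers are hypothetical, chosen only to show the arithmetic, and are not measurements from DESC.

```python
def qr_time_saved(nfev, nit, t_qr):
    """Estimated seconds saved: one redundant QR per extra function evaluation."""
    return (nfev - nit) * t_qr


# Benchmark-like case: every iteration is accepted on the first try.
saved_benchmark = qr_time_saved(nfev=50, nit=50, t_qr=0.5)

# Hypothetical run with 12 more function evaluations than iterations,
# assuming a single QR takes 0.5 s.
saved_typical = qr_time_saved(nfev=112, nit=100, t_qr=0.5)
```

With these made-up inputs the first case saves nothing and the second saves 6 seconds in total, so the benefit scales with both the number of rejected steps and the cost of one QR.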

YigitElma avatar Aug 22 '24 21:08 YigitElma