
Fix `nan` in reverse mode gradient caused by `rotation_matrix`

Open · dpanici opened this issue 11 months ago · 8 comments

Resolves #1456 by adding a different method of computing `rotation_matrix` (from here).

  • [x] Do we still want to keep the old `rotation_matrix`?
  • [x] I am a bit confused about why this test passes on master even though this seems to be the cause of #1456.

dpanici commented on Dec 11 '24
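For reference, a common closed form for the rotation that takes one direction onto another is R = I + K + K²/(1 + c), where K is the cross-product matrix of v = a × b and c = a · b for unit vectors a and b. We assume, without confirmation from the linked source, that the new vector-to-vector method is of this kind; note that no `arccos` appears anywhere. The NumPy sketch below uses hypothetical function names (`skew`, `rotation_a_to_b`), not DESC's API.

```python
import numpy as np

def skew(v):
    """Cross-product matrix: skew(v) @ w == np.cross(v, w)."""
    return np.array([[0.0, -v[2], v[1]],
                     [v[2], 0.0, -v[0]],
                     [-v[1], v[0], 0.0]])

def rotation_a_to_b(a, b):
    """Rotation matrix taking the direction of a onto the direction of b.

    Uses R = I + K + K @ K / (1 + c), with K = skew(a x b) and c = a . b
    after normalizing a and b. Undefined when a and b are antiparallel (c = -1).
    """
    a = np.asarray(a, dtype=float) / np.linalg.norm(a)
    b = np.asarray(b, dtype=float) / np.linalg.norm(b)
    v = np.cross(a, b)
    c = np.dot(a, b)
    K = skew(v)
    return np.eye(3) + K + K @ K / (1.0 + c)

R = rotation_a_to_b([1, 0, 0], [0, 1, 0])
print(R @ np.array([1.0, 0.0, 0.0]))  # the x-axis is mapped onto the y-axis
print(np.linalg.det(R))               # 1 for a proper rotation
```

The price of avoiding `arccos` is the division by 1 + c: the formula breaks down for exactly antiparallel inputs, and as 1 + c approaches 0, rounding error in that denominator is amplified.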

```diff
|             benchmark_name             |         dt(%)          |         dt(s)          |        t_new(s)        |        t_old(s)        | 
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
 test_build_transform_fft_midres         |     +1.02 +/- 4.69     | +7.10e-03 +/- 3.26e-02 |  7.02e-01 +/- 3.1e-02  |  6.95e-01 +/- 1.0e-02  |
 test_build_transform_fft_highres        |     +3.91 +/- 4.21     | +3.68e-02 +/- 3.97e-02 |  9.80e-01 +/- 3.7e-02  |  9.43e-01 +/- 1.4e-02  |
 test_equilibrium_init_lowres            |     +4.41 +/- 4.67     | +1.98e-01 +/- 2.10e-01 |  4.69e+00 +/- 2.0e-01  |  4.50e+00 +/- 4.6e-02  |
 test_objective_compile_atf              |     +2.98 +/- 3.81     | +1.85e-01 +/- 2.37e-01 |  6.40e+00 +/- 1.9e-01  |  6.22e+00 +/- 1.4e-01  |
 test_objective_compute_atf              |     +0.18 +/- 12.36    | +1.62e-05 +/- 1.14e-03 |  9.20e-03 +/- 1.1e-03  |  9.19e-03 +/- 1.8e-04  |
 test_objective_jac_atf                  |     +2.29 +/- 2.63     | +3.44e-02 +/- 3.96e-02 |  1.54e+00 +/- 2.6e-02  |  1.51e+00 +/- 3.0e-02  |
 test_perturb_1                          |     +3.49 +/- 1.75     | +5.11e-01 +/- 2.56e-01 |  1.51e+01 +/- 2.4e-01  |  1.46e+01 +/- 1.0e-01  |
 test_proximal_jac_atf                   |     +1.26 +/- 1.19     | +6.68e-02 +/- 6.35e-02 |  5.38e+00 +/- 5.4e-02  |  5.32e+00 +/- 3.4e-02  |
 test_proximal_freeb_compute             |     -1.17 +/- 2.21     | -2.06e-03 +/- 3.91e-03 |  1.74e-01 +/- 2.8e-03  |  1.77e-01 +/- 2.8e-03  |
 test_solve_fixed_iter                   |     +0.72 +/- 3.66     | +2.16e-01 +/- 1.10e+00 |  3.02e+01 +/- 9.0e-01  |  3.00e+01 +/- 6.3e-01  |
 test_objective_compute_ripple           |     -0.07 +/- 1.38     | -1.81e-03 +/- 3.70e-02 |  2.69e+00 +/- 2.8e-02  |  2.69e+00 +/- 2.4e-02  |
 test_objective_grad_ripple              |     +0.23 +/- 2.62     | +1.08e-02 +/- 1.23e-01 |  4.70e+00 +/- 1.1e-01  |  4.69e+00 +/- 4.8e-02  |
 test_build_transform_fft_lowres         |     +0.50 +/- 2.52     | +2.92e-03 +/- 1.47e-02 |  5.88e-01 +/- 1.2e-02  |  5.85e-01 +/- 8.3e-03  |
 test_equilibrium_init_medres            |     +0.79 +/- 2.48     | +3.98e-02 +/- 1.25e-01 |  5.11e+00 +/- 9.7e-02  |  5.07e+00 +/- 7.9e-02  |
 test_equilibrium_init_highres           |     +1.53 +/- 4.14     | +8.80e-02 +/- 2.38e-01 |  5.82e+00 +/- 2.2e-01  |  5.73e+00 +/- 8.7e-02  |
 test_objective_compile_dshape_current   |     +0.64 +/- 1.63     | +2.28e-02 +/- 5.84e-02 |  3.61e+00 +/- 4.5e-02  |  3.58e+00 +/- 3.8e-02  |
 test_objective_compute_dshape_current   |     +1.61 +/- 1.57     | +5.51e-05 +/- 5.35e-05 |  3.47e-03 +/- 4.5e-05  |  3.42e-03 +/- 2.9e-05  |
 test_objective_jac_dshape_current       |     +0.32 +/- 12.41    | +1.05e-04 +/- 4.04e-03 |  3.27e-02 +/- 2.4e-03  |  3.26e-02 +/- 3.2e-03  |
 test_perturb_2                          |     -0.03 +/- 2.25     | -5.41e-03 +/- 4.15e-01 |  1.85e+01 +/- 3.4e-01  |  1.85e+01 +/- 2.4e-01  |
 test_proximal_jac_atf_with_eq_update    |     +0.01 +/- 0.91     | +1.59e-03 +/- 1.17e-01 |  1.29e+01 +/- 8.2e-02  |  1.29e+01 +/- 8.3e-02  |
 test_proximal_freeb_jac                 |     +3.11 +/- 10.41    | +1.45e-01 +/- 4.84e-01 |  4.79e+00 +/- 3.4e-01  |  4.65e+00 +/- 3.5e-01  |
 test_solve_fixed_iter_compiled          |     +1.17 +/- 1.22     | +2.07e-01 +/- 2.16e-01 |  1.79e+01 +/- 1.8e-01  |  1.77e+01 +/- 1.2e-01  |
 test_LinearConstraintProjection_build   |     +1.61 +/- 1.75     | +1.43e-01 +/- 1.55e-01 |  8.97e+00 +/- 7.0e-02  |  8.83e+00 +/- 1.4e-01  |
 test_objective_compute_ripple_spline    |     -0.81 +/- 5.08     | -2.48e-03 +/- 1.55e-02 |  3.03e-01 +/- 4.7e-03  |  3.06e-01 +/- 1.5e-02  |
 test_objective_grad_ripple_spline       |     -0.31 +/- 1.50     | -2.87e-03 +/- 1.39e-02 |  9.26e-01 +/- 1.3e-02  |  9.29e-01 +/- 5.0e-03  |
```

github-actions[bot] commented on Dec 11 '24

That existing test passes because the issue isn't with the rotation matrix itself, but with the `arccos` that is needed to compute the angle.

f0uriest commented on Dec 12 '24
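To see why `arccos`, rather than the rotation matrix, produces the `nan`: the angle between two vectors is the `arccos` of their normalized dot product, and d/dx arccos(x) = -1/sqrt(1 - x²) is infinite at x = ±1. When the vectors are exactly aligned, the gradient of the normalized dot product with respect to the inputs is zero (it is at its maximum), so reverse mode ends up multiplying 0 by inf. The NumPy sketch below evaluates that chain rule by hand, without an autodiff library; all function names are hypothetical.

```python
import numpy as np

def arccos_grad(x):
    """d/dx arccos(x) = -1 / sqrt(1 - x**2); infinite at x = +/-1."""
    with np.errstate(divide="ignore"):
        return -1.0 / np.sqrt(1.0 - x**2)

def cos_angle_grad_b(a, b):
    """Gradient of cos(angle) = (a . b) / (|a| |b|) with respect to b."""
    na, nb = np.linalg.norm(a), np.linalg.norm(b)
    return a / (na * nb) - np.dot(a, b) * b / (na * nb**3)

def angle_grad_b(a, b):
    """Chain rule for angle = arccos(cos(angle)), in the order reverse mode applies it."""
    cos_angle = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    with np.errstate(invalid="ignore"):
        return cos_angle_grad_b(a, b) * arccos_grad(cos_angle)

z = np.array([0.0, 0.0, 1.0])
print(angle_grad_b(z, np.array([0.0, 0.6, 0.8])))  # finite
print(angle_grad_b(z, z))  # every component is nan: 0 * (-inf)
```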

  • Safe `arccos`
  • Check the sign of the dot product of the Z-axis and the normal to assign the correct sign
  • Check that the near-Z-axis case works as well (normal almost, but not exactly, aligned with the Z-axis)

dpanici commented on Dec 16 '24
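One way to make `arccos` "safe" is sketched below, assuming the approach is to clip its argument slightly inside [-1, 1]. The `eps` value is an arbitrary choice, not one taken from the PR, and the function names are hypothetical. Because the clip is flat outside its interval, the derivative there is 0 instead of -inf, so the chain rule gives 0 × 0 = 0 rather than 0 × inf = nan.

```python
import numpy as np

def safe_arccos(x, eps=1e-12):
    """arccos with its argument clipped into [-1 + eps, 1 - eps]."""
    return np.arccos(np.clip(x, -1.0 + eps, 1.0 - eps))

def safe_arccos_grad(x, eps=1e-12):
    """Derivative of safe_arccos.

    Outside the clipping interval the clip is flat, so the derivative is 0;
    inside it is the usual -1 / sqrt(1 - x**2). It is evaluated at the
    clipped value, so both branches of np.where are finite.
    """
    xc = np.clip(x, -1.0 + eps, 1.0 - eps)
    inside = np.abs(x) < 1.0 - eps
    return np.where(inside, -1.0 / np.sqrt(1.0 - xc**2), 0.0)

print(safe_arccos(0.5), safe_arccos_grad(0.5))  # pi/3 and -1/sqrt(0.75)
print(safe_arccos(1.0), safe_arccos_grad(1.0))  # tiny angle, zero derivative
```

The trade-off is a small bias: the smallest angle this can return is about sqrt(2 · eps), roughly 1.4e-6 rad for eps = 1e-12, and the gradient is exactly zero for aligned vectors.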

Codecov Report

:white_check_mark: All modified and coverable lines are covered by tests. :white_check_mark: Project coverage is 95.78%. Comparing base (7439edc) to head (a597972). :warning: Report is 182 commits behind head on master.

Additional details and impacted files
```diff
@@           Coverage Diff           @@
##           master    #1457   +/-   ##
=======================================
  Coverage   95.77%   95.78%           
=======================================
  Files         101      101           
  Lines       26982    26995   +13     
=======================================
+ Hits        25843    25857   +14     
+ Misses       1139     1138    -1     
```
| Files with missing lines | Coverage Δ |
| --- | --- |
| desc/compute/_curve.py | 100.00% <100.00%> (ø) |
| desc/geometry/curve.py | 96.03% <100.00%> (+0.01%) :arrow_up: |
| desc/utils.py | 92.37% <100.00%> (+0.06%) :arrow_up: |

... and 2 files with indirect coverage changes


codecov[bot] commented on Feb 04 '25


> If the original rotation matrix function is now unused I'm fine removing it

It is used in some of our coils for the `rotate` method; right now this PR does not change any of the public-facing ways of doing this. I guess I could change just how the rotation matrix is computed internally, but I am pretty sure I would have to use the same formula as the one implemented now anyway to compute the correct `a`/`b` inputs to pass to the new function.

I can just make sure the old function correctly handles reflections and keep it for now.

dpanici commented on Feb 05 '25
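A quick way to check whether a 3×3 matrix is a proper rotation or contains a reflection: orthogonal matrices satisfy RᵀR = I, and among them rotations have det(R) = +1 while reflections have det(R) = -1. This is the property the `np.linalg.det` printouts in this thread are probing. A minimal sketch with a hypothetical helper name and an arbitrary tolerance:

```python
import numpy as np

def classify_orthogonal(R, tol=1e-10):
    """Classify a 3x3 matrix as 'rotation', 'reflection', or 'not orthogonal'.

    Orthogonal matrices satisfy R.T @ R == I; proper rotations have
    det(R) = +1 and improper ones (containing a reflection) det(R) = -1.
    """
    R = np.asarray(R, dtype=float)
    if not np.allclose(R.T @ R, np.eye(3), rtol=0.0, atol=tol):
        return "not orthogonal"
    return "rotation" if np.linalg.det(R) > 0 else "reflection"

rot_z_90 = np.array([[0.0, -1.0, 0.0],
                     [1.0, 0.0, 0.0],
                     [0.0, 0.0, 1.0]])
mirror_z = np.diag([1.0, 1.0, -1.0])

print(classify_orthogonal(rot_z_90))       # rotation
print(classify_orthogonal(mirror_z))       # reflection
print(classify_orthogonal(2 * np.eye(3)))  # not orthogonal
```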

This PR now works fine, but it seems the new method I implemented is noisier for nearly antiparallel vectors. It would be nice if we can make a safe `arccos` work; otherwise I will try the other approach Rory mentioned. But this is probably good enough for now.

```python
from desc.compute.geom_utils import rotation_matrix_vector_vector, rotation_matrix
from desc.utils import safenormalize
import numpy as np

# nearly antiparallel vectors: a = +y, b ~ -y with a tiny z-component
a = [0, 1, 0]
b = np.array([0, -1, 1e-6])
aa = np.array([1, 1, 1])  # test vector to rotate

print("\nnew way (vector->vector rotation)")
A = rotation_matrix_vector_vector(a, b)
print(A)
print(np.linalg.det(A))  # should be 1 for a proper rotation
print(A @ aa)

print("\nold way (axis-angle rotation)")
axis = np.cross(a, b)
angle = np.arccos(np.dot(a, safenormalize(b)))
A = rotation_matrix(axis, angle)
print(A)
print(np.linalg.det(A))
print(A @ aa)
```

which yields:

```
new way (vector->vector rotation)
[[ 1.00000000e+00  0.00000000e+00  0.00000000e+00]
 [ 0.00000000e+00 -9.99822215e-01 -1.00000000e-06]
 [ 0.00000000e+00  1.00000000e-06 -9.99822215e-01]]
0.9996444608857149
[ 1.         -0.99982321 -0.99982121]

old way (axis-angle rotation)
[[ 1.00000000e+00  0.00000000e+00  0.00000000e+00]
 [ 0.00000000e+00 -1.00000000e+00 -1.00004445e-06]
 [-0.00000000e+00  1.00004445e-06 -1.00000000e+00]]
1.0
[ 1.       -1.000001 -0.999999]
```

dpanici commented on Feb 05 '25
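One standard way to get an angle that stays accurate near both 0 and π is atan2(|a × b|, a · b), fed into Rodrigues' axis-angle formula. This is not necessarily the approach Rory mentioned, which the thread does not describe. Because sin and cos of the angle are computed directly rather than through a 1/(1 + c) factor, the nearly antiparallel case from the output above stays orthogonal to rounding precision. A NumPy sketch with hypothetical function names:

```python
import numpy as np

def axis_angle_matrix(axis, angle):
    """Rodrigues' formula: rotation by `angle` about the direction of `axis`."""
    k = np.asarray(axis, dtype=float)
    k = k / np.linalg.norm(k)  # undefined (nan) if axis is the zero vector
    K = np.array([[0.0, -k[2], k[1]],
                  [k[2], 0.0, -k[0]],
                  [-k[1], k[0], 0.0]])
    return np.eye(3) + np.sin(angle) * K + (1.0 - np.cos(angle)) * (K @ K)

def rotation_a_to_b_atan2(a, b):
    """Rotate the direction of a onto b: axis = a x b, angle = atan2(|a x b|, a . b).

    Both atan2 arguments scale by |a| |b|, so a and b need not be normalized.
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    v = np.cross(a, b)
    angle = np.arctan2(np.linalg.norm(v), np.dot(a, b))
    return axis_angle_matrix(v, angle)

A = rotation_a_to_b_atan2([0, 1, 0], [0, -1, 1e-6])
print(A)
print(np.linalg.det(A))  # 1 to within rounding
```

This does not cover everything: the axis is undefined when a × b = 0, and under autodiff the norm |a × b| has no well-defined gradient there, so exactly parallel and antiparallel inputs still need a separate branch.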

Don't know if it's related, but here's a nice trick: https://docs.kidger.site/equinox/api/debug/#common-sources-of-nans

YigitElma commented on Apr 07 '25
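Whether or not it is what the linked page covers, a classic source of reverse-mode NaNs around guards like a "safe arccos" is the single-`where` pattern. Writing `where(ok, arccos(x), 0)` masks the forward value, but in reverse mode the zero cotangent on the masked branch is still multiplied by arccos'(x) at the bad input, and 0 × inf = nan. The usual fix, often called the "double where" trick, also replaces the bad inputs *before* calling the function. The NumPy sketch below hand-writes both vector-Jacobian products so the effect is visible without an autodiff library. The function names are hypothetical.

```python
import numpy as np

def darccos(x):
    """Derivative of arccos: -1 / sqrt(1 - x**2), infinite at x = +/-1."""
    with np.errstate(divide="ignore"):
        return -1.0 / np.sqrt(1.0 - x**2)

def vjp_single_where(x, g):
    """Reverse pass of y = where(ok, arccos(x), 0.0)."""
    ok = np.abs(x) < 1.0
    g_branch = np.where(ok, g, 0.0)  # cotangent reaching arccos
    with np.errstate(invalid="ignore"):
        return g_branch * darccos(x)  # derivative taken at the raw x

def vjp_double_where(x, g):
    """Reverse pass of y = where(ok, arccos(where(ok, x, 0.0)), 0.0)."""
    ok = np.abs(x) < 1.0
    x_safe = np.where(ok, x, 0.0)  # harmless input on the masked branch
    g_branch = np.where(ok, g, 0.0)
    # the inner where sends the cotangent back to x only where ok is True
    return np.where(ok, g_branch * darccos(x_safe), 0.0)

x = np.array([0.5, 1.0])
g = np.ones_like(x)
print(vjp_single_where(x, g))  # second entry is nan
print(vjp_double_where(x, g))  # second entry is 0
```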

Memory benchmark result

```diff
|               Test Name                |      %Δ      |    Master (MB)     |      PR (MB)       |    Δ (MB)    |    Time PR (s)     |  Time Master (s)   |
| -------------------------------------- | ------------ | ------------------ | ------------------ | ------------ | ------------------ | ------------------ |
  test_objective_jac_w7x                 |    5.38 %    |     3.869e+03      |     4.077e+03      |    208.05    |       31.63        |       30.64        |
  test_proximal_jac_w7x_with_eq_update   |   -0.89 %    |     6.758e+03      |     6.698e+03      |    -60.11    |       162.95       |       166.76       |
  test_proximal_freeb_jac                |   -0.73 %    |     1.327e+04      |     1.317e+04      |    -97.46    |       78.43        |       76.13        |
  test_proximal_freeb_jac_blocked        |    0.92 %    |     7.517e+03      |     7.586e+03      |    69.25     |       66.49        |       67.90        |
  test_proximal_freeb_jac_batched        |    0.14 %    |     7.516e+03      |     7.527e+03      |    10.81     |       68.80        |       67.37        |
  test_proximal_jac_ripple               |    0.17 %    |     7.515e+03      |     7.528e+03      |    12.41     |       76.03        |       71.92        |
  test_proximal_jac_ripple_spline        |    0.99 %    |     3.499e+03      |     3.534e+03      |    34.58     |       76.74        |       74.34        |
- test_eq_solve                          |   10.30 %    |     1.947e+03      |     2.148e+03      |    200.59    |       125.27       |       123.17       |
```

For the memory plots, go to the summary of Memory Benchmarks workflow and download the artifact.

github-actions[bot] commented on May 27 '25

@dpanici

YigitElma commented on Jun 03 '25