Option to use float32

Open rahulgaur104 opened this issue 1 year ago • 4 comments

Resolves #1033. float32 is clearly desirable because it would let us fit more objectives on the same GPU.

However, it causes half the tests to fail because of the reduced accuracy. Let's find a way to use lower precision without significantly altering the final equilibrium.

Perhaps we could change the precision only in the more sensitive parts of the code while still reducing overall memory consumption.

Suggestions?

rahulgaur104 avatar Jun 10 '24 21:06 rahulgaur104

|             benchmark_name             |         dt(%)          |         dt(s)          |        t_new(s)        |        t_old(s)        |
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
| test_build_transform_fft_lowres         |     +7.67 +/- 6.50     | +4.21e-02 +/- 3.57e-02 |  5.91e-01 +/- 2.0e-02  |  5.49e-01 +/- 2.9e-02  |
| test_build_transform_fft_midres         |     +7.79 +/- 5.38     | +4.86e-02 +/- 3.36e-02 |  6.72e-01 +/- 2.1e-02  |  6.24e-01 +/- 2.6e-02  |
| test_build_transform_fft_highres        |     +0.12 +/- 5.35     | +1.19e-03 +/- 5.52e-02 |  1.03e+00 +/- 4.8e-02  |  1.03e+00 +/- 2.8e-02  |
| test_equilibrium_init_lowres            |     -2.88 +/- 7.93     | -1.22e-01 +/- 3.36e-01 |  4.12e+00 +/- 3.1e-01  |  4.24e+00 +/- 1.3e-01  |
| test_equilibrium_init_medres            |     -2.97 +/- 8.98     | -1.41e-01 +/- 4.26e-01 |  4.60e+00 +/- 4.1e-01  |  4.74e+00 +/- 1.2e-01  |
| test_equilibrium_init_highres           |     -4.29 +/- 2.44     | -2.81e-01 +/- 1.60e-01 |  6.27e+00 +/- 1.1e-01  |  6.55e+00 +/- 1.2e-01  |
| test_objective_compile_dshape_current   |     -2.16 +/- 3.43     | -8.92e-02 +/- 1.42e-01 |  4.04e+00 +/- 1.0e-01  |  4.13e+00 +/- 1.0e-01  |
| test_objective_compile_atf              |     +0.61 +/- 3.21     | +4.90e-02 +/- 2.60e-01 |  8.13e+00 +/- 1.8e-01  |  8.09e+00 +/- 1.8e-01  |
| test_objective_compute_dshape_current   |     -0.30 +/- 1.71     | -1.08e-05 +/- 6.08e-05 |  3.54e-03 +/- 3.1e-05  |  3.55e-03 +/- 5.2e-05  |
| test_objective_compute_atf              |     -2.97 +/- 3.19     | -4.63e-04 +/- 4.97e-04 |  1.51e-02 +/- 2.0e-04  |  1.56e-02 +/- 4.5e-04  |
| test_objective_jac_dshape_current       |     +3.48 +/- 6.96     | +1.42e-03 +/- 2.83e-03 |  4.20e-02 +/- 1.7e-03  |  4.06e-02 +/- 2.3e-03  |
| test_objective_jac_atf                  |     -0.58 +/- 2.71     | -1.15e-02 +/- 5.37e-02 |  1.97e+00 +/- 3.4e-02  |  1.98e+00 +/- 4.2e-02  |
| test_perturb_1                          |     +2.55 +/- 3.18     | +3.47e-01 +/- 4.32e-01 |  1.39e+01 +/- 3.4e-01  |  1.36e+01 +/- 2.7e-01  |
| test_perturb_2                          |     +2.26 +/- 4.68     | +4.20e-01 +/- 8.68e-01 |  1.90e+01 +/- 7.2e-01  |  1.86e+01 +/- 4.8e-01  |
| test_proximal_jac_atf                   |     +0.77 +/- 0.71     | +5.69e-02 +/- 5.23e-02 |  7.44e+00 +/- 3.9e-02  |  7.38e+00 +/- 3.4e-02  |
| test_proximal_freeb_compute             |     +1.63 +/- 0.93     | +3.10e-03 +/- 1.76e-03 |  1.93e-01 +/- 1.1e-03  |  1.90e-01 +/- 1.4e-03  |
| test_proximal_freeb_jac                 |     +0.43 +/- 1.57     | +3.21e-02 +/- 1.17e-01 |  7.50e+00 +/- 1.0e-01  |  7.47e+00 +/- 5.7e-02  |

github-actions[bot] avatar Jun 12 '24 16:06 github-actions[bot]

Once this is working, we should add an assert inside the `in32bit` function to ensure that the output is actually 32-bit before upcasting.
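
A rough sketch of what such a check could look like. The actual `in32bit` implementation is not shown in this thread, so everything below is an assumption: the decorator name `in32bit_sketch`, its signature, and the two toy functions are hypothetical and only illustrate the "downcast, call, assert, upcast" pattern.

```python
import functools

import jax
import jax.numpy as jnp

# Assumption: 64-bit mode is on, so float64 is the default dtype.
jax.config.update("jax_enable_x64", True)


def in32bit_sketch(fun):
    """Hypothetical wrapper: run `fun` in float32, check, then upcast."""

    @functools.wraps(fun)
    def wrapper(*args):
        # Downcast every array argument to float32 before calling `fun`.
        args32 = [jnp.asarray(a, dtype=jnp.float32) for a in args]
        out = fun(*args32)
        # Guard against silent promotion back to float64 inside `fun`.
        assert out.dtype == jnp.float32, f"expected float32, got {out.dtype}"
        # Upcast so downstream code keeps seeing float64.
        return out.astype(jnp.float64)

    return wrapper


@in32bit_sketch
def sum_of_squares(x):
    return jnp.sum(x**2)


@in32bit_sketch
def leaky(x):
    # Multiplying by a float64 array promotes the result to float64,
    # which the assert in the wrapper is meant to catch.
    return x * jnp.ones(x.shape, dtype=jnp.float64)


x = jnp.linspace(0.0, 1.0, 5)  # float64 because x64 is enabled
y = sum_of_squares(x)  # computed in float32, returned as float64
```

Calling `leaky(x)` trips the assert, which is the failure mode the check is meant to surface: without it, the computation would quietly run in float64 and the memory saving would be lost. Because dtypes are static, the assert should also work unchanged under `jax.jit`.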

f0uriest avatar Jun 13 '24 20:06 f0uriest

This might be just what we need: https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.disable_x64.html#jax.experimental.disable_x64

f0uriest avatar Jul 12 '24 14:07 f0uriest

Mixed-precision iterative refinement could be useful?

dpanici avatar Mar 05 '25 20:03 dpanici
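
For reference, a minimal sketch of mixed-precision iterative refinement on a dense linear system. It is not tied to any DESC solver: the function name `solve_mixed`, the iteration count, and the test matrix are assumptions chosen for illustration. The solves run in float32, while the residual is evaluated in float64 and used to correct the iterate.

```python
import jax
import jax.numpy as jnp

# float64 is needed for the residual, so enable 64-bit mode.
jax.config.update("jax_enable_x64", True)


def solve_mixed(A, b, n_iter=5):
    """Solve A x = b with float32 solves and float64 residual corrections."""
    A32 = A.astype(jnp.float32)
    # Initial solve entirely in float32; keep the iterate in float64.
    x = jnp.linalg.solve(A32, b.astype(jnp.float32)).astype(jnp.float64)
    for _ in range(n_iter):
        # Residual in float64: this is where the extra accuracy comes from.
        r = b - A @ x
        # Correction from a float32 solve. A production version would
        # factor A32 once and reuse the factors instead of re-solving.
        dx = jnp.linalg.solve(A32, r.astype(jnp.float32))
        x = x + dx.astype(jnp.float64)
    return x


key_A, key_x = jax.random.split(jax.random.PRNGKey(0))
n = 50
# Diagonally shifted random matrix so the problem is well conditioned.
A = jax.random.normal(key_A, (n, n), dtype=jnp.float64) + n * jnp.eye(n)
x_true = jax.random.normal(key_x, (n,), dtype=jnp.float64)
b = A @ x_true

x32 = jnp.linalg.solve(A.astype(jnp.float32), b.astype(jnp.float32))
x_ref = solve_mixed(A, b)

# Maximum absolute error of the plain float32 solve vs. the refined solve.
err32 = float(jnp.max(jnp.abs(x32.astype(jnp.float64) - x_true)))
err_ref = float(jnp.max(jnp.abs(x_ref - x_true)))
```

On a well-conditioned system like this one, `err_ref` should end up far below `err32`. For ill-conditioned systems the float32 solve may not contract the error and the refinement can stall, so whether this helps for the equilibrium solve would need to be tested.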