DESC
Option to use float32
Resolves #1033. float32 is clearly desirable: each value takes 4 bytes instead of 8, which halves the memory of every array and would let us fit more objectives on the same GPU.
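As a rough illustration of the memory argument, this sketch compares the footprint of one dense array in each precision. The shape is a hypothetical stand-in for a large Jacobian, not a number taken from DESC.

```python
import math

import jax.numpy as jnp


def footprint_mb(shape, dtype):
    """Memory in MB of a dense array with the given shape and dtype."""
    return math.prod(shape) * jnp.dtype(dtype).itemsize / 1e6


# Hypothetical size, e.g. a Jacobian with 1000 residuals and 2000 parameters.
shape = (1000, 2000)
mb64 = footprint_mb(shape, jnp.float64)
mb32 = footprint_mb(shape, jnp.float32)
print(f"float64: {mb64:.1f} MB, float32: {mb32:.1f} MB")  # float64: 16.0 MB, float32: 8.0 MB
```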
But it causes half the tests to fail because of the reduced accuracy. Let's find a way to use lower precision without significantly altering the final equilibrium.
Perhaps we could keep float64 in the more precision-sensitive parts of the code while still reducing overall memory consumption.
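A minimal sketch of why accuracy drops and of the "keep the sensitive parts in float64" idea. `small_difference` is a hypothetical helper standing in for any cancellation-prone step (it is not DESC code); the point is that storing values in float32 costs a few digits per value, while doing the subtraction in float32 can lose the result entirely.

```python
import jax
import jax.numpy as jnp

# Needed so that float64 arrays really are float64 (JAX defaults to 32 bit).
jax.config.update("jax_enable_x64", True)


def small_difference(a, b, dtype):
    """Hypothetical helper: compute (a + b) - a in `dtype`; the exact answer is b."""
    a = jnp.asarray(a, dtype=dtype)
    b = jnp.asarray(b, dtype=dtype)
    return (a + b) - a


# float32 keeps ~7 significant digits (eps ~ 1.2e-7), so 1 + 1e-8 rounds to exactly 1.
lost = small_difference(1.0, 1e-8, jnp.float32)  # 0.0: the small term is gone
kept = small_difference(1.0, 1e-8, jnp.float64)  # ~1e-8

# Mixed scheme: store the inputs in float32 (half the memory) but upcast them
# before the cancellation-prone step. Each stored value loses some digits,
# but the result is no longer wiped out.
a32 = jnp.asarray(1.0, dtype=jnp.float32)
b32 = jnp.asarray(1e-8, dtype=jnp.float32)
mixed = small_difference(a32, b32, jnp.float64)  # ~1e-8, accurate to ~7 digits
```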
Suggestions?
| benchmark_name | dt(%) | dt(s) | t_new(s) | t_old(s) |
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
| test_build_transform_fft_lowres | +7.67 +/- 6.50 | +4.21e-02 +/- 3.57e-02 | 5.91e-01 +/- 2.0e-02 | 5.49e-01 +/- 2.9e-02 |
| test_build_transform_fft_midres | +7.79 +/- 5.38 | +4.86e-02 +/- 3.36e-02 | 6.72e-01 +/- 2.1e-02 | 6.24e-01 +/- 2.6e-02 |
| test_build_transform_fft_highres | +0.12 +/- 5.35 | +1.19e-03 +/- 5.52e-02 | 1.03e+00 +/- 4.8e-02 | 1.03e+00 +/- 2.8e-02 |
| test_equilibrium_init_lowres | -2.88 +/- 7.93 | -1.22e-01 +/- 3.36e-01 | 4.12e+00 +/- 3.1e-01 | 4.24e+00 +/- 1.3e-01 |
| test_equilibrium_init_medres | -2.97 +/- 8.98 | -1.41e-01 +/- 4.26e-01 | 4.60e+00 +/- 4.1e-01 | 4.74e+00 +/- 1.2e-01 |
| test_equilibrium_init_highres | -4.29 +/- 2.44 | -2.81e-01 +/- 1.60e-01 | 6.27e+00 +/- 1.1e-01 | 6.55e+00 +/- 1.2e-01 |
| test_objective_compile_dshape_current | -2.16 +/- 3.43 | -8.92e-02 +/- 1.42e-01 | 4.04e+00 +/- 1.0e-01 | 4.13e+00 +/- 1.0e-01 |
| test_objective_compile_atf | +0.61 +/- 3.21 | +4.90e-02 +/- 2.60e-01 | 8.13e+00 +/- 1.8e-01 | 8.09e+00 +/- 1.8e-01 |
| test_objective_compute_dshape_current | -0.30 +/- 1.71 | -1.08e-05 +/- 6.08e-05 | 3.54e-03 +/- 3.1e-05 | 3.55e-03 +/- 5.2e-05 |
| test_objective_compute_atf | -2.97 +/- 3.19 | -4.63e-04 +/- 4.97e-04 | 1.51e-02 +/- 2.0e-04 | 1.56e-02 +/- 4.5e-04 |
| test_objective_jac_dshape_current | +3.48 +/- 6.96 | +1.42e-03 +/- 2.83e-03 | 4.20e-02 +/- 1.7e-03 | 4.06e-02 +/- 2.3e-03 |
| test_objective_jac_atf | -0.58 +/- 2.71 | -1.15e-02 +/- 5.37e-02 | 1.97e+00 +/- 3.4e-02 | 1.98e+00 +/- 4.2e-02 |
| test_perturb_1 | +2.55 +/- 3.18 | +3.47e-01 +/- 4.32e-01 | 1.39e+01 +/- 3.4e-01 | 1.36e+01 +/- 2.7e-01 |
| test_perturb_2 | +2.26 +/- 4.68 | +4.20e-01 +/- 8.68e-01 | 1.90e+01 +/- 7.2e-01 | 1.86e+01 +/- 4.8e-01 |
| test_proximal_jac_atf | +0.77 +/- 0.71 | +5.69e-02 +/- 5.23e-02 | 7.44e+00 +/- 3.9e-02 | 7.38e+00 +/- 3.4e-02 |
| test_proximal_freeb_compute | +1.63 +/- 0.93 | +3.10e-03 +/- 1.76e-03 | 1.93e-01 +/- 1.1e-03 | 1.90e-01 +/- 1.4e-03 |
| test_proximal_freeb_jac | +0.43 +/- 1.57 | +3.21e-02 +/- 1.17e-01 | 7.50e+00 +/- 1.0e-01 | 7.47e+00 +/- 5.7e-02 |
Once this works, we should add an assert inside the `in32bit` function to check that its output really is 32-bit before upcasting.
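One possible shape for that check, as a sketch only: `run_in_float32` is a hypothetical decorator, not the actual `in32bit` from this PR, whose signature may differ. It casts the inputs to float32, runs the function, asserts the output stayed float32, and only then upcasts. The assert matters because JAX's type promotion silently turns the result into float64 if any float64 array sneaks into the computation, while Python scalars are weakly typed and do not.

```python
import functools

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def run_in_float32(fun):
    """Hypothetical decorator: evaluate `fun` in float32, then upcast its output to float64."""

    @functools.wraps(fun)
    def wrapper(*args):
        args32 = [jnp.asarray(a, dtype=jnp.float32) for a in args]
        out = fun(*args32)
        # The proposed check: make sure the computation really stayed in 32 bit
        # before upcasting (a float64 array inside `fun` would promote the result).
        assert out.dtype == jnp.float32, f"expected float32 output, got {out.dtype}"
        return out.astype(jnp.float64)

    return wrapper


@run_in_float32
def ok(x):
    return jnp.sin(x) * 2.0  # Python scalar is weakly typed, so this stays float32


@run_in_float32
def leaks(x):
    return x * jnp.ones(1, dtype=jnp.float64)  # promotes to float64, so the assert fires


x = jnp.linspace(0.0, 1.0, 5)
y = ok(x)  # computed in float32, returned as float64

caught = None
try:
    leaks(x)
except AssertionError as err:
    caught = err  # the float64 leak is detected instead of silently accepted
```

Since the dtype of a traced value is static, the same assert also works inside `jax.jit`, where it fires at trace time.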
This might be just what we need: https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.disable_x64.html#jax.experimental.disable_x64
Could mixed-precision iterative refinement be useful here (factor and solve in float32, then correct the solution using residuals computed in float64)?
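A textbook sketch of mixed-precision iterative refinement for a dense linear system `A x = b`, not DESC code: where DESC's solvers would actually plug this in is left open. The expensive LU factorization is done once in float32, and each refinement step computes the residual in float64 and reuses the float32 factors for the correction. For a well-conditioned matrix this typically recovers near-float64 accuracy after a few steps; the random test matrix below is hypothetical and deliberately well-conditioned.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import lu_factor, lu_solve

jax.config.update("jax_enable_x64", True)


def refine_solve(A, b, iters=5):
    """Solve A x = b with a float32 LU factorization plus float64 residual corrections."""
    A64 = jnp.asarray(A, dtype=jnp.float64)
    b64 = jnp.asarray(b, dtype=jnp.float64)
    # O(n^3) factorization in float32 (half the memory of a float64 factorization).
    lu_piv = lu_factor(A64.astype(jnp.float32))
    x = lu_solve(lu_piv, b64.astype(jnp.float32)).astype(jnp.float64)
    for _ in range(iters):
        r = b64 - A64 @ x  # residual in float64
        dx = lu_solve(lu_piv, r.astype(jnp.float32))  # O(n^2) correction in float32
        x = x + dx.astype(jnp.float64)
    return x


# Hypothetical well-conditioned test problem (diagonally dominated by n * I).
key_A, key_b = jax.random.split(jax.random.PRNGKey(0))
n = 50
A = jax.random.normal(key_A, (n, n), dtype=jnp.float64) + n * jnp.eye(n)
b = jax.random.normal(key_b, (n,), dtype=jnp.float64)
x_ref = jnp.linalg.solve(A, b)


def rel_err(x):
    return float(jnp.linalg.norm(x - x_ref) / jnp.linalg.norm(x_ref))


err_32 = rel_err(refine_solve(A, b, iters=0))  # plain float32 solve
err_ir = rel_err(refine_solve(A, b))  # with float64 refinement
print(f"float32 only: {err_32:.1e}, with refinement: {err_ir:.1e}")
```

For ill-conditioned systems (roughly, condition number near 1/eps32 ~ 1e7) the refinement can stall or diverge, so it would need a fallback to a full float64 solve.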