DESC
Reduce storage of big arrays in memory during optimization
Reuses the J variable in optimizers to reduce the number of Jacobian-sized matrices held in memory. This can lead to tangible memory savings (for an L=M=N=16 fixed-boundary solve with default grids, J alone is 4 GB), observable after the large memory jump at the first Jacobian calculation.
- Updates docs to give better memory usage insight
- Fixes some benchmark run conditions. This should make the memory profiler run on fork pull requests, and also run the benchmarks when the `run_benchmarks` label is added.
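The `J` reuse amounts to overwriting the existing Jacobian buffer instead of allocating a second Jacobian-sized array. A minimal NumPy sketch of the idea (function names are hypothetical, not DESC's API):

```python
import numpy as np

def scale_jacobian_copy(J, d):
    # Allocates a second Jacobian-sized array: peak memory ~2x size of J.
    return J * d

def scale_jacobian_inplace(J, d):
    # Reuses J's buffer: no extra Jacobian-sized allocation.
    J *= d
    return J

J = np.ones((4, 3))
d = np.arange(3, dtype=float)
out = scale_jacobian_inplace(J, d)
# out is the very same buffer as J, now holding the scaled values
```

For a 4 GB Jacobian, avoiding even one extra copy of this kind is the difference between fitting in memory or not.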
Some VRAM profiling on my laptop GPU with 12 GB memory. The script used:

```python
import os
import sys

sys.path.insert(0, os.path.abspath("."))
sys.path.append(os.path.abspath("../../"))
# Use the platform allocator so reported memory reflects actual usage.
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

from desc import set_device

# must be called before importing desc.backend
set_device("gpu")

from desc.backend import print_backend_info
from desc.examples import get
from desc.objectives import ForceBalance, ObjectiveFunction

print_backend_info()

N = 14
eq = get("precise_QA")
eq.change_resolution(L=N, M=N, N=N, L_grid=2 * N, M_grid=2 * N, N_grid=2 * N)
eq.resolution_summary()
eq.set_initial_guess()

obj = ObjectiveFunction(ForceBalance(eq), jac_chunk_size=500, deriv_mode="batched")
obj.build()
print(f"Objective function deriv mode: {obj._deriv_mode}")
print(f"Objective function chunk size: {obj._jac_chunk_size}")

# Jacobian size is 4375x25650, which is ~0.8 GB
eq.solve(
    objective=obj,
    constraints=None,
    optimizer="lsq-exact",
    ftol=1e-4,
    xtol=1e-6,
    gtol=1e-6,
    maxiter=5,
    x_scale="auto",
    verbose=2,
    copy=False,
)
```
[Plot: memory usage over time on master]
[Plot: memory usage over time on this PR]
For clarity, I deleted the previous result that used only the `J = J * d` trick.
| benchmark_name | dt(%) | dt(s) | t_new(s) | t_old(s) |
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
| test_build_transform_fft_lowres | +4.12 +/- 4.44 | +2.39e-02 +/- 2.58e-02 | 6.04e-01 +/- 2.1e-02 | 5.80e-01 +/- 1.5e-02 |
| test_equilibrium_init_medres | +3.59 +/- 5.19 | +1.66e-01 +/- 2.40e-01 | 4.78e+00 +/- 1.9e-01 | 4.62e+00 +/- 1.5e-01 |
| test_equilibrium_init_highres | +2.21 +/- 3.42 | +1.15e-01 +/- 1.78e-01 | 5.30e+00 +/- 1.4e-01 | 5.19e+00 +/- 1.1e-01 |
| test_objective_compile_dshape_current | +1.52 +/- 1.00 | +5.49e-02 +/- 3.62e-02 | 3.67e+00 +/- 2.8e-02 | 3.61e+00 +/- 2.3e-02 |
| test_objective_compute_dshape_current | -0.30 +/- 2.54 | -1.05e-05 +/- 8.79e-05 | 3.45e-03 +/- 7.2e-05 | 3.46e-03 +/- 5.1e-05 |
| test_objective_jac_dshape_current | -0.43 +/- 14.99 | -1.34e-04 +/- 4.68e-03 | 3.11e-02 +/- 3.4e-03 | 3.12e-02 +/- 3.2e-03 |
| test_perturb_2 | +1.54 +/- 3.37 | +2.79e-01 +/- 6.13e-01 | 1.84e+01 +/- 5.8e-01 | 1.82e+01 +/- 1.9e-01 |
| test_proximal_jac_atf_with_eq_update | +0.40 +/- 0.75 | +6.04e-02 +/- 1.12e-01 | 1.51e+01 +/- 9.0e-02 | 1.51e+01 +/- 6.7e-02 |
| test_proximal_freeb_jac | +2.92 +/- 9.20 | +1.43e-01 +/- 4.51e-01 | 5.04e+00 +/- 3.1e-01 | 4.90e+00 +/- 3.2e-01 |
| test_solve_fixed_iter_compiled | -1.30 +/- 2.53 | -2.37e-01 +/- 4.61e-01 | 1.80e+01 +/- 1.2e-01 | 1.82e+01 +/- 4.5e-01 |
| test_LinearConstraintProjection_build | -0.67 +/- 2.68 | -5.86e-02 +/- 2.34e-01 | 8.69e+00 +/- 2.2e-01 | 8.75e+00 +/- 7.9e-02 |
| test_objective_compute_ripple_spline | -1.90 +/- 5.08 | -6.22e-03 +/- 1.66e-02 | 3.22e-01 +/- 9.1e-03 | 3.28e-01 +/- 1.4e-02 |
| test_objective_grad_ripple_spline | -0.44 +/- 3.85 | -5.52e-03 +/- 4.79e-02 | 1.24e+00 +/- 3.9e-02 | 1.25e+00 +/- 2.7e-02 |
| test_build_transform_fft_midres | +2.40 +/- 4.85 | +1.69e-02 +/- 3.42e-02 | 7.21e-01 +/- 2.1e-02 | 7.04e-01 +/- 2.7e-02 |
| test_build_transform_fft_highres | +0.67 +/- 4.47 | +6.52e-03 +/- 4.38e-02 | 9.86e-01 +/- 4.0e-02 | 9.80e-01 +/- 1.8e-02 |
| test_equilibrium_init_lowres | +3.83 +/- 2.48 | +1.57e-01 +/- 1.01e-01 | 4.24e+00 +/- 8.8e-02 | 4.08e+00 +/- 5.0e-02 |
| test_objective_compile_atf | -1.22 +/- 4.60 | -7.79e-02 +/- 2.94e-01 | 6.32e+00 +/- 1.5e-01 | 6.40e+00 +/- 2.5e-01 |
| test_objective_compute_atf | +6.08 +/- 14.24 | +5.13e-04 +/- 1.20e-03 | 8.95e-03 +/- 1.2e-03 | 8.44e-03 +/- 1.4e-04 |
| test_objective_jac_atf | +2.44 +/- 3.44 | +3.74e-02 +/- 5.27e-02 | 1.57e+00 +/- 4.0e-02 | 1.53e+00 +/- 3.4e-02 |
| test_perturb_1 | +3.21 +/- 5.48 | +4.56e-01 +/- 7.77e-01 | 1.46e+01 +/- 6.9e-01 | 1.42e+01 +/- 3.6e-01 |
| test_proximal_jac_atf | -1.38 +/- 3.07 | -1.06e-01 +/- 2.37e-01 | 7.59e+00 +/- 1.1e-01 | 7.70e+00 +/- 2.1e-01 |
| test_proximal_freeb_compute | +0.59 +/- 3.42 | +1.05e-03 +/- 6.12e-03 | 1.80e-01 +/- 5.3e-03 | 1.79e-01 +/- 3.0e-03 |
| test_solve_fixed_iter | -1.10 +/- 3.52 | -3.33e-01 +/- 1.07e+00 | 3.01e+01 +/- 3.7e-01 | 3.04e+01 +/- 1.0e+00 |
| test_objective_compute_ripple | +0.09 +/- 1.23 | +2.36e-03 +/- 3.37e-02 | 2.75e+00 +/- 2.7e-02 | 2.75e+00 +/- 2.1e-02 |
| test_objective_grad_ripple | +1.07 +/- 2.49 | +5.36e-02 +/- 1.24e-01 | 5.05e+00 +/- 5.5e-02 | 4.99e+00 +/- 1.1e-01 |
Codecov Report
:white_check_mark: All modified and coverable lines are covered by tests.
:white_check_mark: Project coverage is 95.79%. Comparing base (f580c7a) to head (7a461e2).
:warning: Report is 273 commits behind head on master.
Additional details and impacted files
```diff
@@            Coverage Diff             @@
##           master    #1688      +/-   ##
==========================================
+ Coverage   95.67%   95.79%   +0.12%
==========================================
  Files         101      101
  Lines       26731    26753      +22
==========================================
+ Hits        25575    25629      +54
+ Misses       1156     1124      -32
```
| Files with missing lines | Coverage Δ | |
|---|---|---|
| desc/optimize/aug_lagrangian.py | 97.05% <100.00%> (+0.07%) :arrow_up: | |
| desc/optimize/aug_lagrangian_ls.py | 95.83% <100.00%> (+0.09%) :arrow_up: | |
| desc/optimize/fmin_scalar.py | 98.18% <100.00%> (+0.06%) :arrow_up: | |
| desc/optimize/least_squares.py | 99.36% <100.00%> (+0.02%) :arrow_up: | |
To be clear though, this doesn't resolve the memory leak issue #1686, but it does prevent storing multiple big arrays at the same time.
Yeah, I don't see how this actually fixes the underlying issue in #1686. From the plots above, it looks like this doesn't get rid of those spikes, which are likely the cause of the OOM after a few iterations; we should figure out what's actually causing that.
Also, the plots only show a reduction of ~500 MB; shouldn't it be a lot more?
@f0uriest The last plot is with the most recent changes, and there the memory decrease is ~1.8 GB. The peaks are the Jacobian evaluations, and those can be flattened out with `jac_chunk_size=1` (for my test case it was 500).
I think #1686 was right at the edge of the memory limit, and a small leak of 20-40 MB (according to @dpanici's further analysis) caused the OOM. We noticed these changes while looking at that issue, but this PR is not meant to solve it.
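The peak-flattening effect of chunking can be illustrated with a toy example: compute the Jacobian a few rows at a time so only a small block is materialized per step, then stack the blocks. This is only a sketch of the idea; DESC's actual `jac_chunk_size` machinery (which batches the derivative computation) differs in detail.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Toy residual function: R^3 -> R^6.
    return jnp.sin(jnp.outer(jnp.arange(1.0, 7.0), x)).sum(axis=1)

def chunked_jacobian(fun, x, chunk_size):
    """Compute the Jacobian in row blocks instead of all at once."""
    m = fun(x).shape[0]
    chunks = []
    for start in range(0, m, chunk_size):
        stop = min(start + chunk_size, m)
        # Only a (chunk_size x n_params) block is materialized per step.
        chunk_fun = lambda z, s=start, e=stop: fun(z)[s:e]
        chunks.append(jax.jacrev(chunk_fun)(x))
    return jnp.concatenate(chunks, axis=0)

x = jnp.array([0.1, 0.2, 0.3])
J_full = jax.jacfwd(f)(x)
J_chunked = chunked_jacobian(f, x, chunk_size=2)
```

With `chunk_size=1` the peak per step is a single Jacobian row, which is why the spikes flatten out at the cost of more (slower) passes.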
Do it in place; garbage collect manually in between steps.
Check for `fmintr` as well.
I shared some profiling results: https://github.com/PlasmaControl/DESC/issues/1686#issuecomment-2814372946 https://github.com/PlasmaControl/DESC/issues/1686#issuecomment-2814409983
I ran the profiling again before and after garbage collection, and it looks like there is no change during the optimization (the difference stays almost constant at the memory difference from time t=0). Also, `at[].set()` doesn't change anything; I think it is never in-place (except under jit), see https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html.
It looks like garbage collection has some speed effect at lower resolutions; maybe we can just run it once before the optimization?
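The out-of-place behavior of `at[].set()` is easy to demonstrate: outside of jit it always returns a new array and leaves the original untouched (inside jit, XLA may fuse the update into an in-place one as an optimization).

```python
import jax.numpy as jnp

x = jnp.zeros(4)
y = x.at[0].set(1.0)
# Functional update: x is unchanged, y is a brand-new array.
print(x)  # [0. 0. 0. 0.]
print(y)  # [1. 0. 0. 0.]
```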
Note: I should mention again that this was done on my personal laptop. Although GPU profiling is fairly isolated (normal apps don't use NVIDIA VRAM), CPU memory is shared with other apps.
Yeah, maybe we skip the gc if it seems to have a negligible effect. If the in-place stuff doesn't affect speed, we can keep it in, and when #1669 is eventually done, we will actually see the benefits.
Memory benchmark result
| Test Name | %Δ | Master (MB) | PR (MB) | Δ (MB) | Time PR (s) | Time Master (s) |
| -------------------------------------- | ------------ | ------------------ | ------------------ | ------------ | ------------------ | ------------------ |
| test_objective_jac_w7x | 8.31 % | 3.762e+03 | 4.075e+03 | 312.57 | 29.94 | 28.52 |
| test_proximal_jac_w7x_with_eq_update | -0.97 % | 6.925e+03 | 6.857e+03 | -67.33 | 172.99 | 170.98 |
| test_proximal_freeb_jac | 0.06 % | 1.317e+04 | 1.318e+04 | 7.62 | 71.69 | 74.39 |
| test_proximal_freeb_jac_blocked | -0.41 % | 1.320e+04 | 1.314e+04 | -54.11 | 70.18 | 69.37 |
| test_proximal_freeb_jac_batched | -0.97 % | 7.558e+03 | 7.485e+03 | -73.59 | 107.44 | 106.93 |
| test_proximal_jac_ripple | 0.22 % | 7.286e+03 | 7.303e+03 | 16.30 | 68.49 | 69.02 |
| test_proximal_jac_ripple_spline | -0.17 % | 3.826e+03 | 3.820e+03 | -6.64 | 75.41 | 75.93 |
| + test_eq_solve | -11.84 % | 2.249e+03 | 1.983e+03 | -266.44 | 122.44 | 122.35 |
For the memory plots, go to the summary of the Memory Benchmarks workflow and download the artifact.
Memory benchmark result
| Test Name | %Δ | Master (MB) | PR (MB) | Δ (MB) | Time PR (s) | Time Master (s) |
| -------------------------------------- | ------------ | ------------------ | ------------------ | ------------ | ------------------ | ------------------ |
| test_objective_jac_w7x | -0.15 % | 3.945e+03 | 3.939e+03 | -6.04 | 31.84 | 28.83 |
| test_proximal_jac_w7x_with_eq_update | -0.34 % | 6.902e+03 | 6.879e+03 | -23.60 | 174.22 | 172.89 |
| test_proximal_freeb_jac | 0.02 % | 1.317e+04 | 1.318e+04 | 3.01 | 71.00 | 70.69 |
| test_proximal_freeb_jac_blocked | 0.54 % | 1.319e+04 | 1.327e+04 | 71.57 | 69.99 | 71.58 |
| test_proximal_freeb_jac_batched | -2.12 % | 7.644e+03 | 7.482e+03 | -161.99 | 109.36 | 108.75 |
| test_proximal_jac_ripple | -0.07 % | 7.328e+03 | 7.323e+03 | -5.43 | 69.62 | 69.91 |
| test_proximal_jac_ripple_spline | 2.08 % | 3.831e+03 | 3.911e+03 | 79.84 | 77.24 | 77.45 |

For the memory plots, go to the summary of the Memory Benchmarks workflow and download the artifact.
These benchmarks don't have any optimization, so this shouldn't have any effect. If you want, I can add an optimization test with only 1 step.
Yeah, that might be useful. Just for `lsq-exact`, I'd say.
I will add a small tip on deciding which `jac_chunk_size` will make a difference. One can use the stdout at the beginning of the optimization (`Number of parameters: ...`) to decide the minimum chunk size that will help reduce memory.
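As a rough rule of thumb (my own sketch, not DESC output): the memory held by one Jacobian chunk scales as chunk size times the other Jacobian dimension, so chunking only helps once `jac_chunk_size` is well below the full dimension being chunked. Assuming the chunking is over the parameter (column) dimension:

```python
def jac_chunk_gb(chunk_size, n_rows, bytes_per_elem=8):
    # Memory held by one float64 Jacobian chunk of shape (n_rows, chunk_size).
    return chunk_size * n_rows * bytes_per_elem / 1024**3

# With the 4375 x 25650 Jacobian from the profiling script above:
full = jac_chunk_gb(25650, 4375)  # whole Jacobian, ~0.84 GB
chunk = jac_chunk_gb(500, 4375)   # one chunk at jac_chunk_size=500, ~0.016 GB
```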
@ddudt @rahulgaur104
I approve, but I wrote half of it so I can't actually approve.
Same for me @f0uriest @ddudt