Reduce storage of big arrays in memory during optimization

Open dpanici opened this issue 7 months ago • 18 comments

Reuses the J variable in optimizers to reduce the number of Jacobian-sized matrices held in memory. This can lead to tangible memory savings (for an L=M=N=16 fixed-boundary solve with default grids, J is 4 GB), which can be observed after the large memory jump that follows the first Jacobian calculation.

  • Updates docs to give better memory usage insight
  • Fixes some benchmark run conditions. This should make the memory profiler run on pull requests from forks, and also run the benchmarks when the run_benchmarks label is added.
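
The effect of reusing the J name can be sketched with the standard library alone. This is only an illustration, not DESC's optimizer code: `make_jacobian`, `scale`, and `held_after_scaling` are hypothetical names, a `bytearray` stands in for a Jacobian-sized array, and the sketch assumes CPython, where a buffer is freed as soon as its last reference goes away.

```python
import tracemalloc

N_BYTES = 8_000_000  # hypothetical size of one "Jacobian-sized" buffer


def make_jacobian():
    """Allocate a stand-in for the Jacobian."""
    return bytearray(N_BYTES)


def scale(J):
    """Return a scaled copy; like J * d, this allocates a new buffer."""
    return bytearray(J)


def held_after_scaling(reuse_name):
    """Bytes still held after scaling, relative to the start of the call."""
    tracemalloc.start()
    baseline, _ = tracemalloc.get_traced_memory()
    J = make_jacobian()
    if reuse_name:
        J = scale(J)  # old buffer loses its last reference and is freed
    else:
        J_scaled = scale(J)  # both J and J_scaled stay alive
    current, _ = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    return current - baseline


if __name__ == "__main__":
    print("separate names:", held_after_scaling(False))
    print("reused name:   ", held_after_scaling(True))
```

In both variants two buffers exist for a moment while `scale` runs, so the transient peak is the same; only the footprint held afterwards differs (roughly one buffer instead of two).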

dpanici avatar Apr 14 '25 21:04 dpanici

Some VRAM profiling on my laptop GPU with 12 GB of memory. The script used:

import sys
import os

sys.path.insert(0, os.path.abspath("."))
sys.path.append(os.path.abspath("../../"))


os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

from desc import set_device

set_device("gpu")

from desc.backend import print_backend_info
from desc.examples import get
from desc.objectives import ObjectiveFunction, ForceBalance

print_backend_info()

N = 14

eq = get("precise_QA")
eq.change_resolution(L=N, M=N, N=N, L_grid=2 * N, M_grid=2 * N, N_grid=2 * N)
eq.resolution_summary()
eq.set_initial_guess()
obj = ObjectiveFunction(ForceBalance(eq), jac_chunk_size=500, deriv_mode="batched")
obj.build()
print(f"Objective function deriv mode: {obj._deriv_mode}")
print(f"Objective function chunk size: {obj._jac_chunk_size}")

# Jacobian size is 4375x25650, which is ~0.8 GB

eq.solve(
    objective=obj,
    constraints=None,
    optimizer="lsq-exact",
    ftol=1e-4,
    xtol=1e-6,
    gtol=1e-6,
    maxiter=5,
    x_scale="auto",
    verbose=2,
    copy=False,
)
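
The size quoted in the script comment can be checked with a short calculation. This sketch assumes a dense matrix of 8-byte (float64) entries; `jacobian_bytes` is a hypothetical helper, not part of DESC.

```python
def jacobian_bytes(n_rows, n_cols, itemsize=8):
    """Memory of a dense n_rows x n_cols matrix with itemsize-byte entries."""
    return n_rows * n_cols * itemsize


if __name__ == "__main__":
    nbytes = jacobian_bytes(4375, 25650)
    # prints "0.90 GB, 0.84 GiB"
    print(f"{nbytes / 1e9:.2f} GB, {nbytes / 2**30:.2f} GiB")
```

So the "~0.8 GB" in the comment matches the binary (GiB) figure for a single copy of the Jacobian.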

Master

[Figure: test-case-res14-master]

PR

For clarity, I deleted the previous result that used only the J=J*d trick.

[Figure: test-case-res14-with1688-delQR]

[Figure: mem analysis]

YigitElma avatar Apr 14 '25 21:04 YigitElma

|             benchmark_name             |         dt(%)          |         dt(s)          |        t_new(s)        |        t_old(s)        |
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
| test_build_transform_fft_lowres        |     +4.12 +/- 4.44     | +2.39e-02 +/- 2.58e-02 |  6.04e-01 +/- 2.1e-02  |  5.80e-01 +/- 1.5e-02  |
| test_equilibrium_init_medres           |     +3.59 +/- 5.19     | +1.66e-01 +/- 2.40e-01 |  4.78e+00 +/- 1.9e-01  |  4.62e+00 +/- 1.5e-01  |
| test_equilibrium_init_highres          |     +2.21 +/- 3.42     | +1.15e-01 +/- 1.78e-01 |  5.30e+00 +/- 1.4e-01  |  5.19e+00 +/- 1.1e-01  |
| test_objective_compile_dshape_current  |     +1.52 +/- 1.00     | +5.49e-02 +/- 3.62e-02 |  3.67e+00 +/- 2.8e-02  |  3.61e+00 +/- 2.3e-02  |
| test_objective_compute_dshape_current  |     -0.30 +/- 2.54     | -1.05e-05 +/- 8.79e-05 |  3.45e-03 +/- 7.2e-05  |  3.46e-03 +/- 5.1e-05  |
| test_objective_jac_dshape_current      |     -0.43 +/- 14.99    | -1.34e-04 +/- 4.68e-03 |  3.11e-02 +/- 3.4e-03  |  3.12e-02 +/- 3.2e-03  |
| test_perturb_2                         |     +1.54 +/- 3.37     | +2.79e-01 +/- 6.13e-01 |  1.84e+01 +/- 5.8e-01  |  1.82e+01 +/- 1.9e-01  |
| test_proximal_jac_atf_with_eq_update   |     +0.40 +/- 0.75     | +6.04e-02 +/- 1.12e-01 |  1.51e+01 +/- 9.0e-02  |  1.51e+01 +/- 6.7e-02  |
| test_proximal_freeb_jac                |     +2.92 +/- 9.20     | +1.43e-01 +/- 4.51e-01 |  5.04e+00 +/- 3.1e-01  |  4.90e+00 +/- 3.2e-01  |
| test_solve_fixed_iter_compiled         |     -1.30 +/- 2.53     | -2.37e-01 +/- 4.61e-01 |  1.80e+01 +/- 1.2e-01  |  1.82e+01 +/- 4.5e-01  |
| test_LinearConstraintProjection_build  |     -0.67 +/- 2.68     | -5.86e-02 +/- 2.34e-01 |  8.69e+00 +/- 2.2e-01  |  8.75e+00 +/- 7.9e-02  |
| test_objective_compute_ripple_spline   |     -1.90 +/- 5.08     | -6.22e-03 +/- 1.66e-02 |  3.22e-01 +/- 9.1e-03  |  3.28e-01 +/- 1.4e-02  |
| test_objective_grad_ripple_spline      |     -0.44 +/- 3.85     | -5.52e-03 +/- 4.79e-02 |  1.24e+00 +/- 3.9e-02  |  1.25e+00 +/- 2.7e-02  |
| test_build_transform_fft_midres        |     +2.40 +/- 4.85     | +1.69e-02 +/- 3.42e-02 |  7.21e-01 +/- 2.1e-02  |  7.04e-01 +/- 2.7e-02  |
| test_build_transform_fft_highres       |     +0.67 +/- 4.47     | +6.52e-03 +/- 4.38e-02 |  9.86e-01 +/- 4.0e-02  |  9.80e-01 +/- 1.8e-02  |
| test_equilibrium_init_lowres           |     +3.83 +/- 2.48     | +1.57e-01 +/- 1.01e-01 |  4.24e+00 +/- 8.8e-02  |  4.08e+00 +/- 5.0e-02  |
| test_objective_compile_atf             |     -1.22 +/- 4.60     | -7.79e-02 +/- 2.94e-01 |  6.32e+00 +/- 1.5e-01  |  6.40e+00 +/- 2.5e-01  |
| test_objective_compute_atf             |     +6.08 +/- 14.24    | +5.13e-04 +/- 1.20e-03 |  8.95e-03 +/- 1.2e-03  |  8.44e-03 +/- 1.4e-04  |
| test_objective_jac_atf                 |     +2.44 +/- 3.44     | +3.74e-02 +/- 5.27e-02 |  1.57e+00 +/- 4.0e-02  |  1.53e+00 +/- 3.4e-02  |
| test_perturb_1                         |     +3.21 +/- 5.48     | +4.56e-01 +/- 7.77e-01 |  1.46e+01 +/- 6.9e-01  |  1.42e+01 +/- 3.6e-01  |
| test_proximal_jac_atf                  |     -1.38 +/- 3.07     | -1.06e-01 +/- 2.37e-01 |  7.59e+00 +/- 1.1e-01  |  7.70e+00 +/- 2.1e-01  |
| test_proximal_freeb_compute            |     +0.59 +/- 3.42     | +1.05e-03 +/- 6.12e-03 |  1.80e-01 +/- 5.3e-03  |  1.79e-01 +/- 3.0e-03  |
| test_solve_fixed_iter                  |     -1.10 +/- 3.52     | -3.33e-01 +/- 1.07e+00 |  3.01e+01 +/- 3.7e-01  |  3.04e+01 +/- 1.0e+00  |
| test_objective_compute_ripple          |     +0.09 +/- 1.23     | +2.36e-03 +/- 3.37e-02 |  2.75e+00 +/- 2.7e-02  |  2.75e+00 +/- 2.1e-02  |
| test_objective_grad_ripple             |     +1.07 +/- 2.49     | +5.36e-02 +/- 1.24e-01 |  5.05e+00 +/- 5.5e-02  |  4.99e+00 +/- 1.1e-01  |

github-actions[bot] avatar Apr 14 '25 22:04 github-actions[bot]

Codecov Report

:white_check_mark: All modified and coverable lines are covered by tests. :white_check_mark: Project coverage is 95.79%. Comparing base (f580c7a) to head (7a461e2). :warning: Report is 273 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1688      +/-   ##
==========================================
+ Coverage   95.67%   95.79%   +0.12%     
==========================================
  Files         101      101              
  Lines       26731    26753      +22     
==========================================
+ Hits        25575    25629      +54     
+ Misses       1156     1124      -32     
Files with missing lines Coverage Δ
desc/optimize/aug_lagrangian.py 97.05% <100.00%> (+0.07%) :arrow_up:
desc/optimize/aug_lagrangian_ls.py 95.83% <100.00%> (+0.09%) :arrow_up:
desc/optimize/fmin_scalar.py 98.18% <100.00%> (+0.06%) :arrow_up:
desc/optimize/least_squares.py 99.36% <100.00%> (+0.02%) :arrow_up:

... and 3 files with indirect coverage changes

codecov[bot] avatar Apr 14 '25 23:04 codecov[bot]

To be clear though, this doesn't resolve the memory leak in issue #1686; it only avoids holding several big arrays in memory at the same time.

YigitElma avatar Apr 14 '25 23:04 YigitElma

Yeah, I don't see how this actually fixes the underlying issue in #1686. From the plots above it looks like this doesn't get rid of those spikes, which are likely the cause of the OOM after a few iterations. We should figure out what's actually causing that.

f0uriest avatar Apr 15 '25 01:04 f0uriest

Also the plots only show a reduction of ~500 MB, shouldn't it be a lot more?

f0uriest avatar Apr 15 '25 01:04 f0uriest

@f0uriest The last plot is with the most recent changes, and there, the memory decrease is ~1.8GB. The peaks are the Jacobian evaluations, and those can be flattened out by jac_chunk_size=1 (for my test case it was 500).

I think #1686 was at the edge of the memory limit, and a small leak of about 20-40 MB (according to @dpanici's further analysis) caused the OOM. We came up with these changes while looking at that issue, but this PR is not meant to solve it.

YigitElma avatar Apr 15 '25 03:04 YigitElma

Do it in place; run garbage collection manually in between steps

dpanici avatar Apr 16 '25 19:04 dpanici

check for fmintr as well

dpanici avatar Apr 16 '25 19:04 dpanici

I shared some profiling https://github.com/PlasmaControl/DESC/issues/1686#issuecomment-2814372946 https://github.com/PlasmaControl/DESC/issues/1686#issuecomment-2814409983

YigitElma avatar Apr 18 '25 03:04 YigitElma

I ran the profiling again before and after adding garbage collection. During the optimization there is no change (the difference stays almost constant, equal to the memory difference at time t=0). Also, at[].set() doesn't change anything; I think it is never in place (except under jit), see https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html.

It looks like garbage collection has some effect on speed at lower resolutions. Maybe we can just run it once before the optimization?

[Figure: test-case-memory-gpucpu-bagc]

Note: I should mention again that this was done on my personal laptop. Although GPU profiling is fairly isolated (normal apps don't use NVIDIA VRAM), CPU memory is shared with other apps.
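
For reference, what a manual collection does on the Python side can be sketched with the standard library. This is a generic CPython illustration, not the profiling script from this thread: `Workspace` and `leaked_after_drop` are hypothetical names, and the sketch says nothing about GPU buffers, only about host objects caught in reference cycles.

```python
import gc
import weakref


class Workspace:
    """Stand-in for an object that ends up in a reference cycle."""

    def __init__(self):
        self.me = self  # cycle: reference counting alone cannot free this


def leaked_after_drop(collect):
    """Return True if the object is still alive after its name is deleted."""
    gc_was_enabled = gc.isenabled()
    gc.disable()  # keep automatic collection from interfering
    try:
        ws = Workspace()
        probe = weakref.ref(ws)  # lets us check liveness without a strong ref
        del ws
        if collect:
            gc.collect()  # explicit collection breaks the cycle
        return probe() is not None
    finally:
        if gc_was_enabled:
            gc.enable()


if __name__ == "__main__":
    print("alive without collect:", leaked_after_drop(collect=False))
    print("alive with collect:   ", leaked_after_drop(collect=True))
```

Objects that are not part of a cycle are freed by reference counting as soon as their last name goes away, which is consistent with a manual collection making little difference to the memory traces above.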

YigitElma avatar Apr 18 '25 03:04 YigitElma

Also, just for the record, here are the profiling scripts.

Profiling.zip

YigitElma avatar Apr 18 '25 03:04 YigitElma

Yeah, maybe we skip the gc if it seems to have a negligible effect. If the in-place changes don't affect speed, we can keep them in, and once #1669 is eventually done we will actually see the benefits.

dpanici avatar Apr 18 '25 19:04 dpanici

Memory benchmark result

|               Test Name                |      %Δ      |    Master (MB)     |      PR (MB)       |    Δ (MB)    |    Time PR (s)     |  Time Master (s)   |
| -------------------------------------- | ------------ | ------------------ | ------------------ | ------------ | ------------------ | ------------------ |
| test_objective_jac_w7x                 |    8.31 %    |     3.762e+03      |     4.075e+03      |    312.57    |       29.94        |       28.52        |
| test_proximal_jac_w7x_with_eq_update   |   -0.97 %    |     6.925e+03      |     6.857e+03      |    -67.33    |       172.99       |       170.98       |
| test_proximal_freeb_jac                |    0.06 %    |     1.317e+04      |     1.318e+04      |     7.62     |       71.69        |       74.39        |
| test_proximal_freeb_jac_blocked        |   -0.41 %    |     1.320e+04      |     1.314e+04      |    -54.11    |       70.18        |       69.37        |
| test_proximal_freeb_jac_batched        |   -0.97 %    |     7.558e+03      |     7.485e+03      |    -73.59    |       107.44       |       106.93       |
| test_proximal_jac_ripple               |    0.22 %    |     7.286e+03      |     7.303e+03      |    16.30     |       68.49        |       69.02        |
| test_proximal_jac_ripple_spline        |   -0.17 %    |     3.826e+03      |     3.820e+03      |    -6.64     |       75.41        |       75.93        |
| test_eq_solve                          |   -11.84 %   |     2.249e+03      |     1.983e+03      |   -266.44    |       122.44       |       122.35       |

For the memory plots, go to the summary of Memory Benchmarks workflow and download the artifact.

github-actions[bot] avatar May 01 '25 03:05 github-actions[bot]

Memory benchmark result

|               Test Name                |      %Δ      |    Master (MB)     |      PR (MB)       |    Δ (MB)    |    Time PR (s)     |  Time Master (s)   |
| -------------------------------------- | ------------ | ------------------ | ------------------ | ------------ | ------------------ | ------------------ |
| test_objective_jac_w7x                 |   -0.15 %    |     3.945e+03      |     3.939e+03      |    -6.04     |       31.84        |       28.83        |
| test_proximal_jac_w7x_with_eq_update   |   -0.34 %    |     6.902e+03      |     6.879e+03      |    -23.60    |       174.22       |       172.89       |
| test_proximal_freeb_jac                |    0.02 %    |     1.317e+04      |     1.318e+04      |     3.01     |       71.00        |       70.69        |
| test_proximal_freeb_jac_blocked        |    0.54 %    |     1.319e+04      |     1.327e+04      |    71.57     |       69.99        |       71.58        |
| test_proximal_freeb_jac_batched        |   -2.12 %    |     7.644e+03      |     7.482e+03      |   -161.99    |       109.36       |       108.75       |
| test_proximal_jac_ripple               |   -0.07 %    |     7.328e+03      |     7.323e+03      |    -5.43     |       69.62        |       69.91        |
| test_proximal_jac_ripple_spline        |    2.08 %    |     3.831e+03      |     3.911e+03      |    79.84     |       77.24        |       77.45        |

For the memory plots, go to the summary of Memory Benchmarks workflow and download the artifact.

These benchmarks don't have any optimization, so this shouldn't have any effect. If you want, I can add an optimization test with only 1 step.

YigitElma avatar May 01 '25 03:05 YigitElma

Yea that might be useful. Just for lsq exact I'd say

dpanici avatar May 01 '25 13:05 dpanici

I will add a small tip on deciding which jac_chunk_size will make a difference. One can use the "Number of parameters: ...." line printed to stdout at the beginning of the optimization to decide the min chunk size that will help reduce memory.
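
A rough way to reason about this tip is sketched below. It assumes the Jacobian is evaluated jac_chunk_size columns at a time with one column per parameter, so a chunk size at or above the parameter count leaves everything in a single pass. `chunk_plan` and `chunking_helps` are hypothetical helpers, not DESC functions, and `n_params` is the value read from the "Number of parameters" line.

```python
import math


def chunk_plan(n_params, jac_chunk_size):
    """Return (number of chunks, columns in the largest chunk)."""
    if n_params <= 0 or jac_chunk_size <= 0:
        raise ValueError("n_params and jac_chunk_size must be positive")
    width = min(jac_chunk_size, n_params)
    return math.ceil(n_params / width), width


def chunking_helps(n_params, jac_chunk_size):
    """Chunking only changes anything when it actually splits the columns."""
    return jac_chunk_size < n_params


if __name__ == "__main__":
    n_params = 1000  # hypothetical value read from the optimizer output
    for size in (2000, 1000, 500, 1):
        print(size, chunk_plan(n_params, size), chunking_helps(n_params, size))
```

Under this assumption, any jac_chunk_size greater than or equal to the printed parameter count behaves like no chunking at all, so only smaller values can lower the peak during the Jacobian evaluation.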

YigitElma avatar May 07 '25 14:05 YigitElma

@ddudt @rahulgaur104

YigitElma avatar May 08 '25 02:05 YigitElma

I approve but I wrote half so can't actually approve

Same for me @f0uriest @ddudt

YigitElma avatar May 12 '25 20:05 YigitElma