
Parallelization via sharding (multi-device) composed with vmap (across the memory of a single device)

Open • unalmis opened this issue 5 months ago • 8 comments

This is the continuation of #1869. Broadly, there are two methods of parallelizing computation.

  1. The first is to split the computation and parallelize across the memory of a single device. This is ideal if the problem can fit on a single device and if the device has sufficient processing power to complete the computation in parallel.

  2. The second is to split the computation and parallelize across multiple devices, which is useful when the problem size becomes too large for the first strategy.

In JAX, option 1 is done with `vmap`. When the problem size becomes larger, we mimic the benefits of option 2 with `vmap_chunked`: we choose the `chunk_size` so that the conditions of option 1 are best satisfied for each chunk, and then compute the chunks in sequence.
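A minimal sketch of chunked `vmap` in plain JAX, for illustration only. `vmap_chunked_sketch` is a hypothetical stand-in, not DESC's `vmap_chunked` (whose signature may differ). It vectorizes `f` with `jax.vmap` inside each chunk and runs the chunks one after another with `jax.lax.map`, so intermediate memory scales roughly with `chunk_size` rather than with the whole batch. It assumes `f` maps one array to one array.

```python
import jax
import jax.numpy as jnp


def vmap_chunked_sketch(f, chunk_size):
    """Hypothetical chunked vmap: map f over the leading axis of x,
    vectorizing chunk_size items at a time and running chunks sequentially."""
    batched_f = jax.vmap(f)

    def mapped(x):
        size = x.shape[0]
        num_chunks, rem = divmod(size, chunk_size)
        outs = []
        if num_chunks > 0:
            # Divisible portion: reshape to (num_chunks, chunk_size, ...) and let
            # lax.map loop over the chunks, each chunk evaluated with vmap.
            body = x[: num_chunks * chunk_size].reshape(
                (num_chunks, chunk_size) + x.shape[1:]
            )
            out = jax.lax.map(batched_f, body)
            outs.append(out.reshape((num_chunks * chunk_size,) + out.shape[2:]))
        if rem > 0:
            # Remainder that does not fill a whole chunk: one smaller vmap call.
            outs.append(batched_f(x[num_chunks * chunk_size :]))
        return jnp.concatenate(outs, axis=0)

    return mapped
```

Recent JAX releases also expose a `batch_size` argument on `jax.lax.map` that performs a similar chunked evaluation.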

This PR composes multi-device parallelization with `vmap_chunked` to increase throughput.

With the current JAX sharding API, it is most practical to first split the data of size S across n devices, and then, on each device's subset of size S/n, chunk with size m across the memory of that device using `vmap_chunked`. Each device therefore processes its S/n elements as S/(nm) sequential chunks of m elements, with all n devices running simultaneously. These APIs require S to be divisible by n and S/n to be divisible by m, but handling the rest takes only a little extra logic: after the divisible portion finishes, each device computes its remaining (S/n) mod m elements, and then the remaining S mod n elements are computed on a single device.
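A sketch of the splitting scheme described above, assuming `m` is the chunk size. It substitutes `jax.pmap` for the sharding API the PR uses, because `pmap` makes the one-shard-per-device placement explicit; `sharded_chunked_map` and `chunked` are hypothetical names. Each device computes its divisible portion in chunks of `m`, then its (S/n) mod m leftover, and the S mod n leftover runs once on a single device.

```python
import jax
import jax.numpy as jnp


def chunked(f, m):
    """Vectorize f within chunks of m items and run the chunks in sequence.

    Assumes the leading axis length is divisible by m.
    """
    batched = jax.vmap(f)

    def run(x):
        k = x.shape[0] // m  # number of sequential chunks
        out = jax.lax.map(batched, x.reshape((k, m) + x.shape[1:]))
        return out.reshape((k * m,) + out.shape[2:])

    return run


def sharded_chunked_map(f, x, m, n=None):
    """Map f over the leading axis of x: split across n devices, chunk by m."""
    n = jax.local_device_count() if n is None else n
    S = x.shape[0]
    per_device = S // n  # divisible part of S/n
    full = (per_device // m) * m  # part of each shard divisible by m
    batched = jax.vmap(f)

    def on_device(xd):
        parts = []
        if full > 0:
            parts.append(chunked(f, m)(xd[:full]))
        if per_device > full:  # (S/n) mod m leftover on this device
            parts.append(batched(xd[full:]))
        return jnp.concatenate(parts, axis=0)

    outs = []
    if per_device > 0:
        shards = x[: n * per_device].reshape((n, per_device) + x.shape[1:])
        out = jax.pmap(on_device)(shards)  # one shard per device
        outs.append(out.reshape((n * per_device,) + out.shape[2:]))
    if S > n * per_device:  # S mod n leftover, computed on a single device
        outs.append(batched(x[n * per_device :]))
    return jnp.concatenate(outs, axis=0)
```

`n` defaults to the number of local devices because `pmap` cannot map over more shards than there are devices.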

The net result is reduced compute time. It also means that if constructing the Jacobian was a memory bottleneck, the bottleneck shifts to storing the Jacobian. That can later be avoided with iterative subspace methods, so that only a subset of the Jacobian needs to be stored at a time.
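One way to avoid storing the full Jacobian, sketched under the assumption that the iterative method only needs products with it: a Krylov-type solver applied to the Gauss-Newton normal equations needs only J^T J v, which JAX can form from one `jax.jvp` and one `jax.vjp` without materializing J. The PR does not say which subspace method would be used; `gauss_newton_matvec` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def gauss_newton_matvec(fun, x, v):
    """Return J(x)^T J(x) v, where J is the Jacobian of fun at x, without forming J."""
    # Forward mode gives the Jacobian-vector product J v.
    _, jv = jax.jvp(fun, (x,), (v,))
    # Reverse mode applies J^T to J v.
    _, vjp_fun = jax.vjp(fun, x)
    (jtjv,) = vjp_fun(jv)
    return jtjv
```

With `fun` and `x` fixed, `lambda v: gauss_newton_matvec(fun, x, v)` can be passed as the linear operator to `jax.scipy.sparse.linalg.cg`, which accepts a callable in place of a matrix.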

To state the use case explicitly: this would enable parallelization without rebuilding an objective for each flux surface, and therefore without duplicating computation and transforms.
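A toy illustration of this use case, under stated assumptions: `surface_quantity` is a hypothetical per-surface computation, and `coeffs` stands in for data shared across surfaces (for example, precomputed transforms). Mapping with `in_axes=(0, None)` evaluates every flux-surface label in `rhos` with one function and one unbatched copy of the shared data, instead of building a separate objective per surface.

```python
import jax
import jax.numpy as jnp


def surface_quantity(rho, coeffs):
    """Hypothetical per-surface computation: a polynomial in rho."""
    powers = rho ** jnp.arange(coeffs.shape[0])
    return jnp.dot(coeffs, powers)


rhos = jnp.linspace(0.25, 1.0, 4)  # one entry per flux surface
coeffs = jnp.array([1.0, 2.0, 3.0])  # shared across all surfaces
# Map over rhos (axis 0); coeffs is passed unbatched (None) to every surface.
values = jax.vmap(surface_quantity, in_axes=(0, None))(rhos, coeffs)
```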

Also relevant: https://github.com/PlasmaControl/DESC/pull/1773#issuecomment-2981490799

unalmis • Jun 11 '25 20:06

Memory benchmark result

| Test Name | Δ (%) | Master (MB) | PR (MB) | Δ (MB) | Time PR (s) | Time Master (s) |
| --- | --- | --- | --- | --- | --- | --- |
| test_objective_jac_w7x | 2.42 % | 3.847e+03 | 3.941e+03 | 93.20 | 41.59 | 37.36 |
| test_proximal_jac_w7x_with_eq_update | -2.61 % | 6.627e+03 | 6.453e+03 | -173.12 | 166.45 | 164.96 |
| test_proximal_freeb_jac | -0.21 % | 1.319e+04 | 1.316e+04 | -27.98 | 86.57 | 84.74 |
| test_proximal_freeb_jac_blocked | 0.36 % | 7.464e+03 | 7.491e+03 | 26.84 | 74.51 | 74.72 |
| test_proximal_freeb_jac_batched | 0.01 % | 7.485e+03 | 7.486e+03 | 0.46 | 73.91 | 74.62 |
| test_proximal_jac_ripple | -1.53 % | 3.483e+03 | 3.430e+03 | -53.27 | 67.71 | 67.51 |
| test_proximal_jac_ripple_bounce1d | 0.53 % | 3.550e+03 | 3.569e+03 | 18.82 | 78.96 | 78.26 |
| test_eq_solve | -0.94 % | 2.021e+03 | 2.002e+03 | -18.97 | 96.49 | 94.79 |

For the memory plots, go to the summary of the Memory Benchmarks workflow and download the artifact.

github-actions[bot] • Jun 11 '25 20:06

| benchmark_name | Δt (%) | Δt (s) | t_new (s) | t_old (s) |
| --- | --- | --- | --- | --- |
| test_build_transform_fft_lowres | -1.03 +/- 2.90 | -5.93e-03 +/- 1.66e-02 | 5.68e-01 +/- 1.2e-02 | 5.74e-01 +/- 1.2e-02 |
| test_equilibrium_init_medres | -0.82 +/- 2.17 | -4.17e-02 +/- 1.10e-01 | 5.02e+00 +/- 7.0e-02 | 5.06e+00 +/- 8.5e-02 |
| test_equilibrium_init_highres | -0.34 +/- 1.52 | -1.94e-02 +/- 8.54e-02 | 5.61e+00 +/- 6.0e-02 | 5.63e+00 +/- 6.1e-02 |
| test_objective_compile_dshape_current | +0.59 +/- 3.17 | +1.99e-02 +/- 1.07e-01 | 3.40e+00 +/- 9.6e-02 | 3.38e+00 +/- 4.7e-02 |
| test_objective_compute_dshape_current | -4.53 +/- 6.85 | -3.51e-05 +/- 5.30e-05 | 7.39e-04 +/- 3.5e-05 | 7.74e-04 +/- 3.9e-05 |
| test_objective_jac_dshape_current | -1.16 +/- 18.84 | -3.81e-04 +/- 6.21e-03 | 3.26e-02 +/- 4.1e-03 | 3.30e-02 +/- 4.7e-03 |
| test_perturb_2 | +0.60 +/- 1.67 | +1.02e-01 +/- 2.85e-01 | 1.71e+01 +/- 1.1e-01 | 1.70e+01 +/- 2.6e-01 |
| test_proximal_jac_atf_with_eq_update | +0.70 +/- 0.89 | +9.45e-02 +/- 1.21e-01 | 1.37e+01 +/- 9.8e-02 | 1.36e+01 +/- 7.1e-02 |
| test_proximal_freeb_jac | +0.46 +/- 1.81 | +2.32e-02 +/- 9.02e-02 | 5.02e+00 +/- 6.6e-02 | 5.00e+00 +/- 6.2e-02 |
| test_solve_fixed_iter_compiled | +0.31 +/- 1.80 | +5.34e-02 +/- 3.08e-01 | 1.72e+01 +/- 2.7e-01 | 1.71e+01 +/- 1.5e-01 |
| test_LinearConstraintProjection_build | -1.08 +/- 2.23 | -9.17e-02 +/- 1.90e-01 | 8.42e+00 +/- 1.2e-01 | 8.51e+00 +/- 1.5e-01 |
| test_objective_compute_ripple_spline | +1.20 +/- 3.65 | +3.53e-03 +/- 1.07e-02 | 2.98e-01 +/- 8.6e-03 | 2.94e-01 +/- 6.5e-03 |
| test_objective_grad_ripple_spline | -0.41 +/- 3.84 | -4.59e-03 +/- 4.29e-02 | 1.11e+00 +/- 3.8e-02 | 1.12e+00 +/- 1.9e-02 |

github-actions[bot] • Jun 13 '25 17:06

Relevant documentation for us.

  1. https://github.com/orgs/netket/discussions/2063
  2. https://github.com/netket/netket/pull/2059#issue-3114532856
  3. https://github.com/netket/netket/blob/pv/jax-0.6/netket/jax/_map.py

unalmis • Jun 17 '25 18:06

Codecov Report

:x: Patch coverage is 30.43478% with 16 lines in your changes missing coverage. Please review.
:white_check_mark: Project coverage is 95.72%. Comparing base (a846c9b) to head (f8485c2).

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| desc/batching.py | 30.43% | 16 Missing :warning: |

Additional details and impacted files

    @@            Coverage Diff             @@
    ##           master    #1773      +/-   ##
    ==========================================
    - Coverage   95.77%   95.72%   -0.05%
    ==========================================
      Files         101      101
      Lines       27796    27816      +20
    ==========================================
    + Hits        26622    26628       +6
    - Misses       1174     1188      +14


| Files with missing lines | Coverage Δ |
| --- | --- |
| desc/batching.py | `75.93% <30.43%> (-6.11%)` :arrow_down: |

... and 1 file with indirect coverage changes


codecov[bot] • Jul 25 '25 19:07

@PlasmaControl/desc-dev we can merge this unless someone wants to take over.

unalmis • Jul 30 '25 22:07

We will keep as draft until we implement a use case (either on this PR or on a PR pointed to this one)

dpanici • Aug 13 '25 20:08

- The `vmap` changes could still be merged into master; make a new PR for them.

dpanici • Aug 13 '25 20:08

> We will keep as draft until we implement a use case (either on this PR or on a PR pointed to this one)

The use case is parallelization without rebuilding an objective for each flux surface.

unalmis • Sep 04 '25 22:09