Parallelization via sharding (across multiple devices) composed with vmap (across the memory of a single device)
This is a continuation of #1869. Broadly, there are two methods of parallelizing a computation:
- The first is to split the computation and parallelize across the memory of a single device. This is ideal *if the problem can fit on a single device and if the device has sufficient processing power to complete the computation in parallel*.
- The second is to split the computation and parallelize across multiple devices, which is useful when the problem size becomes too large for the first strategy.
In JAX, option 1 is done with `vmap`. When the problem size grows beyond that, we mimic the benefits of option 2 with `vmap_chunked`: we choose `chunk_size` so that the italicized condition of option 1 is best satisfied by each chunk, then compute on each chunk in sequence.
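Schematically, the chunking idea looks like the following minimal sketch (my simplification, assuming the batch size divides evenly by `chunk_size`; the real `vmap_chunked` in `desc/batching.py` also handles remainders):

```python
import jax
import jax.numpy as jnp

def vmap_chunked_sketch(f, chunk_size):
    """Map ``f`` over the leading axis in chunks of ``chunk_size``.

    Assumes the leading axis is divisible by ``chunk_size``; the real
    implementation also handles the remainder.
    """
    def mapped(x):
        num = x.shape[0]
        chunks = x.reshape(num // chunk_size, chunk_size, *x.shape[1:])
        # vmap parallelizes within a chunk (option 1); lax.map loops over
        # the chunks sequentially, which bounds the peak memory use.
        out = jax.lax.map(jax.vmap(f), chunks)
        return out.reshape(num, *out.shape[2:])
    return mapped

f = lambda x: jnp.sin(x) ** 2
x = jnp.arange(12.0)
print(vmap_chunked_sketch(f, chunk_size=4)(x))  # same result as jax.vmap(f)(x)
```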
This PR composes multi-device parallelization with `vmap_chunked` to increase the problem sizes we can compute.
With the current JAX sharding API, it is most practical to first split the data of size S across n devices; then, on each shard of size S/n, we chunk with `chunk_size` m across the memory of a single device using `vmap_chunked`. The devices thus each compute on a chunk of size m simultaneously, stepping through their S/(n·m) chunks in sequence. These APIs require S to be divisible by n and m, but it only takes some simple extra logic to compute the remainder of (S/n)/m on each device after it finishes the divisible portion, and likewise to compute the remainder of S/n on a single device after that.
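As an illustration of the composition (a minimal sketch under the current sharding API, not the code added in this PR; `f`, `m`, `per_device_chunked`, and the sizes are hypothetical, and S is assumed divisible by n·m as discussed above):

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

f = lambda x: jnp.sin(x) ** 2
n = jax.device_count()  # number of devices
m = 4                   # chunk_size within each device
S = n * m * 2           # total size; assumed divisible by n * m
mesh = Mesh(jax.devices(), axis_names=("dev",))

def per_device_chunked(x_shard):
    # x_shard is this device's S/n slice; chunk it and loop sequentially.
    chunks = x_shard.reshape(-1, m)
    return jax.lax.map(jax.vmap(f), chunks).reshape(x_shard.shape)

# All devices compute a size-m chunk simultaneously; each device loops
# over its own S/(n*m) chunks in sequence.
parallel = shard_map(per_device_chunked, mesh=mesh,
                     in_specs=P("dev"), out_specs=P("dev"))
x = jnp.arange(float(S))
y = jax.jit(parallel)(x)
```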
The net result is reduced compute time. This also ensures that if there was a memory bottleneck in constructing the Jacobian, the bottleneck is pushed toward storing the Jacobian; that cost can later be avoided with iterative subspace methods, so that only a subset of the Jacobian needs to be stored at a time.
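To make that last point concrete, here is a matrix-free sketch (hypothetical names, not DESC's Jacobian code) that yields Jacobian columns a chunk at a time, so a downstream iterative method never holds the full matrix:

```python
import jax
import jax.numpy as jnp

def jacobian_column_chunks(f, x, chunk_size):
    """Yield the Jacobian of ``f`` at ``x`` a few columns at a time.

    Each basis tangent ``v`` gives one Jacobian column via ``jvp``; chunking
    the basis vectors bounds how much of the Jacobian exists in memory.
    """
    basis = jnp.eye(x.size)
    # J @ v for a batch of tangents v, one vmap per chunk.
    push = jax.vmap(lambda v: jax.jvp(f, (x,), (v,))[1])
    for start in range(0, x.size, chunk_size):
        yield push(basis[start:start + chunk_size])  # shape (chunk, out_dim)

f = lambda x: jnp.sin(x) * jnp.sum(x)
x = jnp.arange(1.0, 7.0)
for cols in jacobian_column_chunks(f, x, chunk_size=2):
    pass  # an iterative subspace method would consume ``cols`` here
```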
If the use case is not clear: explicitly, this would enable parallelization without rebuilding an objective for each flux surface or duplicating computation and transforms.
Also relevant: https://github.com/PlasmaControl/DESC/pull/1773#issuecomment-2981490799
Memory benchmark results
| Test Name | %Δ | Master (MB) | PR (MB) | Δ (MB) | Time PR (s) | Time Master (s) |
| -------------------------------------- | ------------ | ------------------ | ------------------ | ------------ | ------------------ | ------------------ |
| test_objective_jac_w7x | 2.42 % | 3.847e+03 | 3.941e+03 | 93.20 | 41.59 | 37.36 |
| test_proximal_jac_w7x_with_eq_update | -2.61 % | 6.627e+03 | 6.453e+03 | -173.12 | 166.45 | 164.96 |
| test_proximal_freeb_jac | -0.21 % | 1.319e+04 | 1.316e+04 | -27.98 | 86.57 | 84.74 |
| test_proximal_freeb_jac_blocked | 0.36 % | 7.464e+03 | 7.491e+03 | 26.84 | 74.51 | 74.72 |
| test_proximal_freeb_jac_batched | 0.01 % | 7.485e+03 | 7.486e+03 | 0.46 | 73.91 | 74.62 |
| test_proximal_jac_ripple | -1.53 % | 3.483e+03 | 3.430e+03 | -53.27 | 67.71 | 67.51 |
| test_proximal_jac_ripple_bounce1d | 0.53 % | 3.550e+03 | 3.569e+03 | 18.82 | 78.96 | 78.26 |
| test_eq_solve | -0.94 % | 2.021e+03 | 2.002e+03 | -18.97 | 96.49 | 94.79 |
For the memory plots, go to the summary of the Memory Benchmarks workflow and download the artifact.
| benchmark_name | dt(%) | dt(s) | t_new(s) | t_old(s) |
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
| test_build_transform_fft_lowres | -1.03 +/- 2.90 | -5.93e-03 +/- 1.66e-02 | 5.68e-01 +/- 1.2e-02 | 5.74e-01 +/- 1.2e-02 |
| test_equilibrium_init_medres | -0.82 +/- 2.17 | -4.17e-02 +/- 1.10e-01 | 5.02e+00 +/- 7.0e-02 | 5.06e+00 +/- 8.5e-02 |
| test_equilibrium_init_highres | -0.34 +/- 1.52 | -1.94e-02 +/- 8.54e-02 | 5.61e+00 +/- 6.0e-02 | 5.63e+00 +/- 6.1e-02 |
| test_objective_compile_dshape_current | +0.59 +/- 3.17 | +1.99e-02 +/- 1.07e-01 | 3.40e+00 +/- 9.6e-02 | 3.38e+00 +/- 4.7e-02 |
| test_objective_compute_dshape_current | -4.53 +/- 6.85 | -3.51e-05 +/- 5.30e-05 | 7.39e-04 +/- 3.5e-05 | 7.74e-04 +/- 3.9e-05 |
| test_objective_jac_dshape_current | -1.16 +/- 18.84 | -3.81e-04 +/- 6.21e-03 | 3.26e-02 +/- 4.1e-03 | 3.30e-02 +/- 4.7e-03 |
| test_perturb_2 | +0.60 +/- 1.67 | +1.02e-01 +/- 2.85e-01 | 1.71e+01 +/- 1.1e-01 | 1.70e+01 +/- 2.6e-01 |
| test_proximal_jac_atf_with_eq_update | +0.70 +/- 0.89 | +9.45e-02 +/- 1.21e-01 | 1.37e+01 +/- 9.8e-02 | 1.36e+01 +/- 7.1e-02 |
| test_proximal_freeb_jac | +0.46 +/- 1.81 | +2.32e-02 +/- 9.02e-02 | 5.02e+00 +/- 6.6e-02 | 5.00e+00 +/- 6.2e-02 |
| test_solve_fixed_iter_compiled | +0.31 +/- 1.80 | +5.34e-02 +/- 3.08e-01 | 1.72e+01 +/- 2.7e-01 | 1.71e+01 +/- 1.5e-01 |
| test_LinearConstraintProjection_build | -1.08 +/- 2.23 | -9.17e-02 +/- 1.90e-01 | 8.42e+00 +/- 1.2e-01 | 8.51e+00 +/- 1.5e-01 |
| test_objective_compute_ripple_spline | +1.20 +/- 3.65 | +3.53e-03 +/- 1.07e-02 | 2.98e-01 +/- 8.6e-03 | 2.94e-01 +/- 6.5e-03 |
| test_objective_grad_ripple_spline | -0.41 +/- 3.84 | -4.59e-03 +/- 4.29e-02 | 1.11e+00 +/- 3.8e-02 | 1.12e+00 +/- 1.9e-02 |
Relevant documentation for us:
- https://github.com/orgs/netket/discussions/2063
- https://github.com/netket/netket/pull/2059#issue-3114532856
- https://github.com/netket/netket/blob/pv/jax-0.6/netket/jax/_map.py
Codecov Report
:x: Patch coverage is 30.43478% with 16 lines in your changes missing coverage. Please review.
:white_check_mark: Project coverage is 95.72%. Comparing base (a846c9b) to head (f8485c2).
| Files with missing lines | Patch % | Lines |
|---|---|---|
| desc/batching.py | 30.43% | 16 Missing :warning: |
Additional details and impacted files
```diff
@@            Coverage Diff             @@
##           master    #1773      +/-   ##
==========================================
- Coverage   95.77%   95.72%   -0.05%
==========================================
  Files         101      101
  Lines       27796    27816      +20
==========================================
+ Hits        26622    26628       +6
- Misses       1174     1188      +14
```
| Files with missing lines | Coverage Δ | |
|---|---|---|
| desc/batching.py | 75.93% <30.43%> (-6.11%) | :arrow_down: |
@PlasmaControl/desc-dev we can merge this unless someone wants to take over.
We will keep this as a draft until we implement a use case (either on this PR or in a PR pointing to this one).
- The vmap changes could still be put into master; we can make a new PR for this.
> We will keep this as a draft until we implement a use case (either on this PR or in a PR pointing to this one).
The use case is parallelization without rebuilding an objective for each flux surface.