Dp/jacobian batched vmap
This works well. The plot shows the memory trace vs. time on GPU for an LMN=18 equilibrium solve with a 1.5x oversampled grid and maxiter=10, where we get a 4x memory decrease with negligible runtime increase:
This currently uses the netket package for its chunked vmap function. We don't want netket as a dependency, though, so we will try to implement a lighter-weight version ourselves.
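As a rough illustration (not the final DESC implementation), a lightweight chunked vmap can be built from `jax.vmap` and `jax.lax.map`. The helper below is a minimal sketch that assumes the mapped axis is axis 0 and that its length is divisible by `chunk_size`; the names and usage are placeholders:

```python
import jax
import jax.numpy as jnp


def chunked_vmap(f, chunk_size):
    """Sketch of a vmap that evaluates the mapped axis in chunks to cut peak memory."""
    vf = jax.vmap(f)

    def wrapped(x):
        n = x.shape[0]
        # split the batch axis into (num_chunks, chunk_size, ...)
        chunks = x.reshape(n // chunk_size, chunk_size, *x.shape[1:])
        # lax.map runs one chunk at a time, so only one chunk of
        # intermediates is live on the device at once
        out = jax.lax.map(vf, chunks)
        return out.reshape(n, *out.shape[2:])

    return wrapped


# e.g. a forward-mode Jacobian built 100 tangent directions at a time
# (`fun` and `x` are placeholders):
# jac = chunked_vmap(lambda v: jax.jvp(fun, (x,), (v,))[1], chunk_size=100)(jnp.eye(x.size))
```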
TODO
- [ ] re-implement without relying on netket
- [ ] change `chunk_size` to a better default value (something like 100 would be fine, maybe can dynamically choose based off of size of `dim_x`)
- [ ] Add a `chunk_size` argument to every `Objective` class
- [ ] Add `"chunked"` as a `deriv_mode` to `Derivative` (or, just as an argument to `Derivative` to be used when `"batched"` is used)
- [ ] add to singular integral calculation as well
Resolves #826
Codecov Report
Attention: Patch coverage is 89.44099% with 17 lines in your changes missing coverage. Please review.
Project coverage is 92.19%. Comparing base (9f33691) to head (3e99510). Report is 1889 commits behind head on master.
| Files with missing lines | Patch % | Lines |
|---|---|---|
| desc/batching.py | 85.83% | 17 Missing :warning: |
Additional details and impacted files
```diff
@@            Coverage Diff             @@
##           master    #1052      +/-   ##
==========================================
- Coverage   95.30%   92.19%   -3.12%
==========================================
  Files          95       96       +1
  Lines       23944    23560     -384
==========================================
- Hits        22821    21721    -1100
- Misses       1123     1839     +716
```
| Files with missing lines | Coverage Δ | |
|---|---|---|
| desc/continuation.py | 93.26% <100.00%> (ø) | |
| desc/derivatives.py | 92.85% <100.00%> (-1.97%) | :arrow_down: |
| desc/objectives/_bootstrap.py | 97.14% <ø> (ø) | |
| desc/objectives/_coils.py | 99.17% <ø> (+0.66%) | :arrow_up: |
| desc/objectives/_equilibrium.py | 94.53% <ø> (-0.43%) | :arrow_down: |
| desc/objectives/_free_boundary.py | 82.83% <ø> (-14.20%) | :arrow_down: |
| desc/objectives/_generic.py | 67.76% <ø> (-29.76%) | :arrow_down: |
| desc/objectives/_geometry.py | 96.93% <ø> (ø) | |
| desc/objectives/_omnigenity.py | 96.30% <ø> (ø) | |
| desc/objectives/_power_balance.py | 87.50% <ø> (-2.09%) | :arrow_down: |
| ... and 7 more | | |
| benchmark_name | dt(%) | dt(s) | t_new(s) | t_old(s) |
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
| test_build_transform_fft_lowres | +3.16 +/- 3.96 | +1.69e-02 +/- 2.12e-02 | 5.53e-01 +/- 2.0e-02 | 5.36e-01 +/- 6.8e-03 |
| test_equilibrium_init_medres | -0.44 +/- 5.40 | -1.94e-02 +/- 2.36e-01 | 4.34e+00 +/- 1.0e-01 | 4.36e+00 +/- 2.1e-01 |
| test_equilibrium_init_highres | -0.74 +/- 2.41 | -4.25e-02 +/- 1.39e-01 | 5.73e+00 +/- 1.2e-01 | 5.77e+00 +/- 6.4e-02 |
| test_objective_compile_dshape_current | -1.42 +/- 1.53 | -5.72e-02 +/- 6.15e-02 | 3.97e+00 +/- 5.0e-02 | 4.03e+00 +/- 3.6e-02 |
| test_objective_compute_dshape_current | -1.97 +/- 3.73 | -7.30e-05 +/- 1.39e-04 | 3.64e-03 +/- 4.4e-05 | 3.71e-03 +/- 1.3e-04 |
| test_objective_jac_dshape_current | -0.67 +/- 4.78 | -2.76e-04 +/- 1.96e-03 | 4.08e-02 +/- 1.4e-03 | 4.11e-02 +/- 1.3e-03 |
| test_perturb_2 | +0.42 +/- 3.47 | +7.51e-02 +/- 6.14e-01 | 1.78e+01 +/- 5.1e-01 | 1.77e+01 +/- 3.5e-01 |
| test_proximal_freeb_jac | -0.28 +/- 1.56 | -2.12e-02 +/- 1.17e-01 | 7.51e+00 +/- 7.8e-02 | 7.53e+00 +/- 8.8e-02 |
| test_solve_fixed_iter | +0.40 +/- 57.53 | +2.00e-02 +/- 2.88e+00 | 5.03e+00 +/- 2.0e+00 | 5.01e+00 +/- 2.1e+00 |
| test_build_transform_fft_midres | -0.76 +/- 5.52 | -4.77e-03 +/- 3.46e-02 | 6.23e-01 +/- 1.1e-02 | 6.28e-01 +/- 3.3e-02 |
| test_build_transform_fft_highres | -0.34 +/- 3.28 | -3.46e-03 +/- 3.36e-02 | 1.02e+00 +/- 9.3e-03 | 1.02e+00 +/- 3.2e-02 |
| test_equilibrium_init_lowres | +1.50 +/- 3.83 | +5.83e-02 +/- 1.49e-01 | 3.95e+00 +/- 1.5e-01 | 3.89e+00 +/- 3.4e-02 |
| test_objective_compile_atf | -0.08 +/- 4.11 | -6.25e-03 +/- 3.25e-01 | 7.90e+00 +/- 2.4e-01 | 7.91e+00 +/- 2.2e-01 |
| test_objective_compute_atf | +2.00 +/- 2.81 | +2.10e-04 +/- 2.97e-04 | 1.07e-02 +/- 2.5e-04 | 1.05e-02 +/- 1.5e-04 |
| test_objective_jac_atf | +1.18 +/- 2.10 | +2.33e-02 +/- 4.16e-02 | 2.00e+00 +/- 3.0e-02 | 1.98e+00 +/- 2.8e-02 |
| test_perturb_1 | +7.72 +/- 3.83 | +9.70e-01 +/- 4.81e-01 | 1.35e+01 +/- 4.2e-01 | 1.26e+01 +/- 2.3e-01 |
| test_proximal_jac_atf | +1.08 +/- 0.76 | +8.87e-02 +/- 6.27e-02 | 8.29e+00 +/- 4.7e-02 | 8.20e+00 +/- 4.1e-02 |
| test_proximal_freeb_compute | +2.84 +/- 1.08 | +5.27e-03 +/- 2.00e-03 | 1.91e-01 +/- 1.8e-03 | 1.86e-01 +/- 9.6e-04 |
You might already be aware but fyi: https://github.com/google/jax/pull/19614
If you don't care about jax's native multi-GPU sharding support, it should be easy to just vendor our implementation.
In that case, you can just vendor our `netket/jax/_chunk_utils.py`, `netket/jax/_scanmap.py`, and `netket/jax/_vmap_chunked.py`.
The first two files are standalone on purpose. Only `_vmap_chunked` depends on other things, and only if you are using sharding.
Remove all branches hit when `axis_0_is_sharded == True` and `config.netket_experimental_sharding`, which will allow you to remove `sharding_decorator`, which is a mess only needed to support efficient sharding of jax arrays.
Also replace `HashablePartial` with `functools.partial`.
@dpanici jax batched vmap has been merged to master
@dpanici make a separate branch with the implementation using JAX's version, and in this PR implement the one based on netket
@kianorr @YigitElma
Can you also update the adding-objectives docs for the chunk size option?
@PhilipVinc Thank you very much! This was extremely helpful.
How would you like us to credit your package in our code where we have used these functions? I currently mention the original filename and the netket package, but I am happy to credit it in whatever way you prefer.
Glad I could help.
Standard practice is to include a copyright notice along the lines of
# The following section of this code is derived from the NetKet project
# https://github.com/netket/netket/blob/9881c9fb217a2ac4dc9274a054bf6e6a2993c519/netket/jax/_vmap_chunked.py
#
# The original copyright notice is as follows
# Copyright 2021 The NetKet Authors - All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
#
which would be more than fine.
I'm not terribly picky though. Do not worry.
Slightly tangential comment:
deriv_mode="batched" appears to compute the Jacobian wrt only the optimization variables. deriv_mode="blocked" however, appears to compute each objective's Jacobian wrt all parameters related to that objective (even if they are fixed by constraints). This is problematic for the ExternalObjective, since it requires extra finite differences, which are generally expensive. Is there a way to remedy that?
In theory, yes. I think what we'd need is a "blocked jvp" that computes the total jvp as a sum of smaller jvps (then LinearConstraintProjection will only take jvps in the directions needed). The main reason I didn't do this is that the "blocked" logic will use forward or reverse mode as appropriate, while "jvp" usually implies forward mode only, so if we use reverse mode when doing a jvp, that might cause some issues in places where we assume jvp/vjp mean forward/reverse only. This isn't a dealbreaker, just something we need to think about a bit more.
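As a toy (non-DESC-specific) illustration of why summing smaller jvps is valid: the jvp is linear in the tangent, so splitting the tangent across parameter groups and summing the pieces recovers the full jvp.

```python
import jax
import jax.numpy as jnp

def f(x1, x2):
    return jnp.sin(x1) @ x2

x1, x2 = jnp.ones(3), jnp.ones(3)
v1, v2 = jnp.arange(3.0), 2.0 * jnp.ones(3)

# full jvp along (v1, v2) vs. the sum of per-group jvps
_, full = jax.jvp(f, (x1, x2), (v1, v2))
_, p1 = jax.jvp(f, (x1, x2), (v1, jnp.zeros_like(x2)))
_, p2 = jax.jvp(f, (x1, x2), (jnp.zeros_like(x1), v2))
assert jnp.allclose(full, p1 + p2)  # jvp is linear in the tangent direction
```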
Some laptop benchmarking on the ATF jacobian benchmark example:
The effect on compute time is more than expected, so I think I will implement something that only automatically chooses a smaller chunk size if the expected memory is larger than the available device memory. I will re-do these tests on the GPU if I can as well; this was just on CPU on a Mac.
> I will re-do these tests on the GPU if I can as well; this was just on CPU on a Mac.
Can you also add the benchmarks for `jnp.vectorize`? To be sure that calling `batched_vectorize` takes the same time as the original JAX one.
> I think I will implement something that only automatically chooses a smaller chunk size if the expected memory is larger than the available device memory.

We can do this based on predefined values for `dim_x`.
> I will re-do these tests on the GPU if I can as well; this was just on CPU on a Mac.

> Can you also add the benchmarks for `jnp.vectorize`? To be sure that calling `batched_vectorize` takes the same time as the original JAX one.

Yeah, they are basically the same as having a chunk size equal to `dim_x`.
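For reference, a timing comparison along these lines might look like the sketch below. The `desc.batching.batched_vectorize` import path and keyword arguments are assumptions based on this PR's discussion, not a documented API, and `fun` is a placeholder.

```python
import timeit

import jax.numpy as jnp

from desc.batching import batched_vectorize  # module added in this PR; exact API assumed

def fun(x):
    return jnp.sum(jnp.sin(x) ** 2)

x = jnp.ones((1000, 50))
f_plain = jnp.vectorize(fun, signature="(n)->()")
f_chunked = batched_vectorize(fun, signature="(n)->()", chunk_size=100)  # kwargs assumed

# compare wall time of the plain and chunked vectorized calls
print("jnp.vectorize    :", timeit.timeit(lambda: f_plain(x).block_until_ready(), number=10))
print("batched_vectorize:", timeit.timeit(lambda: f_chunked(x).block_until_ready(), number=10))
```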
Based on these scalings for the memory usage of the jacobian on GPU and CPU for the ATF benchmark, I will have "auto" estimate memory usage based on `dim_x * dim_f` and apply a `chunk_size` that brings the estimated usage down below the device memory size.
A conservative estimate of the actual peak memory usage over the estimated peak memory usage from the above formula, as a function of normalized `chunk_size`:
Putting it together, we have something like
`chunk_size < (device_mem / estimated_mem - b) / a * dim_x`

where `a ~ 0.8` and `b ~ 0.15`, and `device_mem` is the currently available memory on the host device (as given by `desc_config.get("avail_mem")`, for example).

When the RHS of this inequality is < 1, that means we estimate that we will get an OOM error, as even with a chunk size of 1 we cannot fit the jacobian's forward-mode calculation in the device memory.
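A sketch of how that heuristic could be implemented; the coefficients `a ≈ 0.8` and `b ≈ 0.15`, the bytes-per-element factor, and the memory estimate are assumptions taken from this discussion rather than DESC's final implementation.

```python
def auto_chunk_size(dim_x, dim_f, device_mem_bytes, a=0.8, b=0.15, bytes_per_elem=8):
    """Pick a chunk_size so the estimated Jacobian memory fits on the device."""
    # crude estimate of peak memory for a full forward-mode Jacobian
    estimated_mem = dim_x * dim_f * bytes_per_elem
    # rearrange: chunk_size < (device_mem / estimated_mem - b) / a * dim_x
    max_chunk = int((device_mem_bytes / estimated_mem - b) / a * dim_x)
    if max_chunk < 1:
        # even chunk_size=1 is estimated to exceed available device memory
        raise MemoryError("Estimated Jacobian memory exceeds available device memory")
    return min(max_chunk, dim_x)  # chunking beyond dim_x gains nothing
```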
@dpanici Hi, it is irrelevant to this discussion, but what script did you use to create the plot pasted at the top of the PR?
Hi @Qazalbash, sorry for the delayed response. I will share the script in 1 or 2 days (this is a self-reminder 😄)
In the meantime, either @rahulgaur104 or @dpanici, if you have the script in your archive, can you share it? It was the one that used pynvml, I guess.
Hi @Qazalbash
This should do the job, more or less. Replace `DESC_profiler_new_13-04.py` with the script that you are running. Good luck!
```python
#!/usr/bin/env python3
import subprocess
import threading
import time

import matplotlib.pyplot as plt


def monitor_vram(duration, interval, vram_usage_list, timestamps):
    # poll nvidia-smi for the GPU memory currently in use, once per interval
    end_time = time.time() + duration
    while time.time() < end_time:
        result = subprocess.run(
            ['nvidia-smi', '--query-gpu=memory.used', '--format=csv,noheader,nounits'],
            stdout=subprocess.PIPE,
        )
        output = result.stdout.decode('utf-8').strip()
        vram_usage = int(output.split()[0])
        vram_usage_list.append(vram_usage)
        timestamps.append(time.time())
        time.sleep(interval)


if __name__ == "__main__":
    duration = 300  # duration to monitor in seconds
    interval = 0.01  # interval between checks in seconds
    vram_usage_list = []
    timestamps = []

    # create a thread for monitoring VRAM while the GPU code runs
    vram_thread = threading.Thread(
        target=monitor_vram, args=(duration, interval, vram_usage_list, timestamps)
    )

    # start the monitoring thread
    vram_thread.start()

    res = 12
    # run the profiled script without blocking (Popen)
    subprocess.Popen(['python', "DESC_profiler_new_13-04.py", f"{res}"])
    # subprocess.Popen(['python', "DESC_profiler_new_06-04.py", f"{res}"])

    # wait for the monitoring thread to finish
    vram_thread.join()

    # write the VRAM usage to a file
    with open('vram_usage.txt', 'w') as file:
        for usage in vram_usage_list:
            file.write(f"{usage}\n")

    # plot the VRAM usage over time
    plt.figure(figsize=(15, 7))
    plt.plot([t - timestamps[0] for t in timestamps], vram_usage_list, '-or', ms=2)
    plt.xlabel('Time (s)', fontsize=20)
    plt.ylabel('VRAM Usage (MiB)', fontsize=20)
    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)
    plt.title('GPU VRAM Usage Over Time')
    plt.grid(True)
    plt.tight_layout()
    # plt.savefig(f"test0_QP_{res}.png", dpi=400)
    plt.savefig(f"test0_QP_{res}_lowp.png", dpi=400)
    # plt.show()
```
Thanks @YigitElma and @rahulgaur104
You're welcome! Also, remember that you have to set the environment variable
`XLA_PYTHON_CLIENT_ALLOCATOR=platform`
to prevent JAX from preallocating 75% of your memory; otherwise, you will get a flat line.
PS: On the command line, to set an environment variable you should use export, like
`export XLA_PYTHON_CLIENT_ALLOCATOR=platform`
and to unset it, you should use
`unset XLA_PYTHON_CLIENT_ALLOCATOR`
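If you prefer, the same thing can be done from Python, as long as the variable is set before JAX initializes its backend (i.e. before the first `import jax`):

```python
import os

# must be set before jax is imported for the allocator setting to take effect
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import jax  # noqa: E402
```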
I really appreciate that! I will mention these in the docs.