
Dp/jacobian batched vmap

Open · dpanici opened this pull request · 7 comments

This works well. Below is the memory trace vs. time on GPU for an LMN=18 equilibrium solve with a 1.5x oversampled grid and maxiter=10, where we get a 4x memory decrease with negligible runtime increase:

[figure: GPU memory trace vs. time for the equilibrium solve]

This currently uses the netket package for its chunked vmap function; we don't want netket as a dependency though, so we will implement a lighter-weight version ourselves (the sketch after the TODO list shows the basic idea).

TODO

  • [ ] Re-implement without relying on netket
  • [ ] Change chunk_size to a better default value (something like 100 would be fine; maybe choose dynamically based on the size of dim_x)
  • [ ] Add a chunk_size argument to every Objective class
  • [ ] Add "chunked" as a deriv_mode to Derivative (or just as an argument to Derivative, used when "batched" is used)
  • [ ] Add chunking to the singular integral calculation as well
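
For reference, here is a minimal sketch of the basic idea, assuming nothing about the eventual implementation in desc/batching.py (the name chunked_vmap, the remainder handling, and the use of jax.lax.map are all illustrative):

```python
import jax
import jax.numpy as jnp


def chunked_vmap(f, chunk_size):
    """vmap f over the leading axis in chunks of chunk_size to cap peak memory."""

    def wrapped(x):
        n = x.shape[0]
        n_chunks, remainder = divmod(n, chunk_size)
        main = x[: n_chunks * chunk_size].reshape((n_chunks, chunk_size) + x.shape[1:])
        # sequential (scan-based) loop over chunks, vectorized within each chunk
        out = jax.lax.map(jax.vmap(f), main)
        out = out.reshape((-1,) + out.shape[2:])
        if remainder:  # handle the leftover partial chunk
            out = jnp.concatenate([out, jax.vmap(f)(x[n_chunks * chunk_size :])])
        return out

    return wrapped


def chunked_jacfwd(f, chunk_size):
    """Forward-mode Jacobian, pushing basis vectors through f in chunks."""
    return lambda x: chunked_vmap(lambda v: jax.jvp(f, (x,), (v,))[1], chunk_size)(
        jnp.eye(x.size)
    ).T
```

With chunk_size equal to the full batch size, this reduces to a single vmap call, which is consistent with the benchmarks later in this thread showing batched_vectorize matching jnp.vectorize when the chunk size equals dim_x.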

Resolves #826

dpanici · Jun 13 '24

Codecov Report

Attention: Patch coverage is 89.44099% with 17 lines in your changes missing coverage. Please review.

Project coverage is 92.19%. Comparing base (9f33691) to head (3e99510). Report is 1889 commits behind head on master.

| Files with missing lines | Patch % | Lines |
| ------------------------ | ------- | ----- |
| desc/batching.py | 85.83% | 17 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1052      +/-   ##
==========================================
- Coverage   95.30%   92.19%   -3.12%     
==========================================
  Files          95       96       +1     
  Lines       23944    23560     -384     
==========================================
- Hits        22821    21721    -1100     
- Misses       1123     1839     +716     
| Files with missing lines | Coverage Δ |
| ------------------------ | ---------- |
| desc/continuation.py | 93.26% <100.00%> (ø) |
| desc/derivatives.py | 92.85% <100.00%> (-1.97%) ↓ |
| desc/objectives/_bootstrap.py | 97.14% <ø> (ø) |
| desc/objectives/_coils.py | 99.17% <ø> (+0.66%) ↑ |
| desc/objectives/_equilibrium.py | 94.53% <ø> (-0.43%) ↓ |
| desc/objectives/_free_boundary.py | 82.83% <ø> (-14.20%) ↓ |
| desc/objectives/_generic.py | 67.76% <ø> (-29.76%) ↓ |
| desc/objectives/_geometry.py | 96.93% <ø> (ø) |
| desc/objectives/_omnigenity.py | 96.30% <ø> (ø) |
| desc/objectives/_power_balance.py | 87.50% <ø> (-2.09%) ↓ |
| ... and 7 more | |

... and 24 files with indirect coverage changes

codecov[bot] · Jun 13 '24

| benchmark_name | dt(%) | dt(s) | t_new(s) | t_old(s) |
| -------------- | ----- | ----- | -------- | -------- |
| test_build_transform_fft_lowres | +3.16 +/- 3.96 | +1.69e-02 +/- 2.12e-02 | 5.53e-01 +/- 2.0e-02 | 5.36e-01 +/- 6.8e-03 |
| test_equilibrium_init_medres | -0.44 +/- 5.40 | -1.94e-02 +/- 2.36e-01 | 4.34e+00 +/- 1.0e-01 | 4.36e+00 +/- 2.1e-01 |
| test_equilibrium_init_highres | -0.74 +/- 2.41 | -4.25e-02 +/- 1.39e-01 | 5.73e+00 +/- 1.2e-01 | 5.77e+00 +/- 6.4e-02 |
| test_objective_compile_dshape_current | -1.42 +/- 1.53 | -5.72e-02 +/- 6.15e-02 | 3.97e+00 +/- 5.0e-02 | 4.03e+00 +/- 3.6e-02 |
| test_objective_compute_dshape_current | -1.97 +/- 3.73 | -7.30e-05 +/- 1.39e-04 | 3.64e-03 +/- 4.4e-05 | 3.71e-03 +/- 1.3e-04 |
| test_objective_jac_dshape_current | -0.67 +/- 4.78 | -2.76e-04 +/- 1.96e-03 | 4.08e-02 +/- 1.4e-03 | 4.11e-02 +/- 1.3e-03 |
| test_perturb_2 | +0.42 +/- 3.47 | +7.51e-02 +/- 6.14e-01 | 1.78e+01 +/- 5.1e-01 | 1.77e+01 +/- 3.5e-01 |
| test_proximal_freeb_jac | -0.28 +/- 1.56 | -2.12e-02 +/- 1.17e-01 | 7.51e+00 +/- 7.8e-02 | 7.53e+00 +/- 8.8e-02 |
| test_solve_fixed_iter | +0.40 +/- 57.53 | +2.00e-02 +/- 2.88e+00 | 5.03e+00 +/- 2.0e+00 | 5.01e+00 +/- 2.1e+00 |
| test_build_transform_fft_midres | -0.76 +/- 5.52 | -4.77e-03 +/- 3.46e-02 | 6.23e-01 +/- 1.1e-02 | 6.28e-01 +/- 3.3e-02 |
| test_build_transform_fft_highres | -0.34 +/- 3.28 | -3.46e-03 +/- 3.36e-02 | 1.02e+00 +/- 9.3e-03 | 1.02e+00 +/- 3.2e-02 |
| test_equilibrium_init_lowres | +1.50 +/- 3.83 | +5.83e-02 +/- 1.49e-01 | 3.95e+00 +/- 1.5e-01 | 3.89e+00 +/- 3.4e-02 |
| test_objective_compile_atf | -0.08 +/- 4.11 | -6.25e-03 +/- 3.25e-01 | 7.90e+00 +/- 2.4e-01 | 7.91e+00 +/- 2.2e-01 |
| test_objective_compute_atf | +2.00 +/- 2.81 | +2.10e-04 +/- 2.97e-04 | 1.07e-02 +/- 2.5e-04 | 1.05e-02 +/- 1.5e-04 |
| test_objective_jac_atf | +1.18 +/- 2.10 | +2.33e-02 +/- 4.16e-02 | 2.00e+00 +/- 3.0e-02 | 1.98e+00 +/- 2.8e-02 |
| test_perturb_1 | +7.72 +/- 3.83 | +9.70e-01 +/- 4.81e-01 | 1.35e+01 +/- 4.2e-01 | 1.26e+01 +/- 2.3e-01 |
| test_proximal_jac_atf | +1.08 +/- 0.76 | +8.87e-02 +/- 6.27e-02 | 8.29e+00 +/- 4.7e-02 | 8.20e+00 +/- 4.1e-02 |
| test_proximal_freeb_compute | +2.84 +/- 1.08 | +5.27e-03 +/- 2.00e-03 | 1.91e-01 +/- 1.8e-03 | 1.86e-01 +/- 9.6e-04 |

github-actions[bot] · Jun 13 '24

You might already be aware, but FYI: https://github.com/google/jax/pull/19614

unalmis · Jun 14 '24

If you don't care about JAX's native multi-GPU sharding support, it should be easy to just vendor our implementation. In that case, you can vendor our netket/jax/_chunk_utils.py, netket/jax/_scanmap.py, and netket/jax/_vmap_chunked.py.

The first two files are standalone on purpose. Only _vmap_chunked depends on other things, and only if you are using sharding.

Remove all branches hit when axis_0_is_sharded == True and config.netket_experimental_sharding is enabled; that will allow you to remove sharding_decorator, which is a mess only needed to support efficient sharding of jax arrays.

Also replace HashablePartial with functools.partial.

PhilipVinc · Jun 17 '24

@dpanici JAX's batched vmap has been merged to master.
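
For reference, recent JAX versions expose similar chunking through the batch_size argument of jax.lax.map (which may differ from the exact API of the linked PR; check your JAX version). A minimal usage sketch with illustrative shapes:

```python
import jax
import jax.numpy as jnp

f = lambda x: jnp.sum(jnp.sin(x) ** 2)
xs = jnp.ones((10_000, 256))

# evaluate f over the leading axis in sequential chunks of 100 rows,
# instead of materializing the whole batch at once
ys = jax.lax.map(f, xs, batch_size=100)  # shape (10_000,)
```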

dpanici · Jul 10 '24

@dpanici make a separate branch with the implementation using JAX's version, and in this PR implement the one based on netket.

dpanici · Jul 23 '24

@kianorr @YigitElma

dpanici · Aug 07 '24

Can you also update the adding-objectives docs for the chunk_size option?

unalmis · Aug 22 '24

> If you don't care about JAX's native multi-GPU sharding support, it should be easy to just vendor our implementation. In that case, you can vendor our netket/jax/_chunk_utils.py, netket/jax/_scanmap.py, and netket/jax/_vmap_chunked.py.
>
> The first two files are standalone on purpose. Only _vmap_chunked depends on other things, and only if you are using sharding.
>
> Remove all branches hit when axis_0_is_sharded == True and config.netket_experimental_sharding is enabled; that will allow you to remove sharding_decorator, which is a mess only needed to support efficient sharding of jax arrays.
>
> Also replace HashablePartial with functools.partial.

@PhilipVinc Thank you very much! This was extremely helpful.

How would you like us to credit your package in our code where we have used these functions? I currently mention the original filename and the netket package, but I am happy to credit it in whatever way you prefer.

dpanici · Aug 23 '24

Glad I could help.

Standard practice is to include a copyright notice along the lines of

# The following section of this code is derived from the NetKet project
# https://github.com/netket/netket/blob/9881c9fb217a2ac4dc9274a054bf6e6a2993c519/netket/jax/_vmap_chunked.py
# 
# The original copyright notice is as follows
# Copyright 2021 The NetKet Authors - All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# 

which would be more than fine.

I'm not terribly picky though. Do not worry.

PhilipVinc · Aug 26 '24

Slightly tangential comment: deriv_mode="batched" appears to compute the Jacobian wrt only the optimization variables. deriv_mode="blocked", however, appears to compute each objective's Jacobian wrt all parameters related to that objective (even if they are fixed by constraints). This is problematic for the ExternalObjective, since it requires extra finite differences, which are generally expensive. Is there a way to remedy that?

ddudt · Aug 29 '24

> Slightly tangential comment: deriv_mode="batched" appears to compute the Jacobian wrt only the optimization variables. deriv_mode="blocked", however, appears to compute each objective's Jacobian wrt all parameters related to that objective (even if they are fixed by constraints). This is problematic for the ExternalObjective, since it requires extra finite differences, which are generally expensive. Is there a way to remedy that?

In theory, yes. I think what we'd need is a "blocked" JVP that computes the total JVP as a sum of smaller JVPs (then LinearConstraintProjection will only take JVPs in the directions needed). The main reason I didn't do this is that the "blocked" logic will use forward or reverse mode as appropriate, while "jvp" usually implies forward mode only, so using reverse mode inside a JVP might cause issues in places where we assume jvp/vjp mean forward/reverse only. This isn't a dealbreaker, just something we need to think about a bit more.
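
A rough sketch of the idea, using a hypothetical blocked_jvp helper (not DESC's API); blocks with no tangent direction are skipped entirely, which is what would spare the ExternalObjective the extra finite differences:

```python
import jax
import jax.numpy as jnp


def blocked_jvp(f, x_blocks, v_blocks):
    """Total JVP of f(*x_blocks) as a sum of per-block JVPs.

    Blocks whose tangent is None (e.g. parameters fixed by constraints)
    are skipped, so no derivative work is done for them.
    """
    out = None
    for i, v in enumerate(v_blocks):
        if v is None:  # fixed block: contributes nothing to the JVP
            continue
        tangents = tuple(
            v if j == i else jnp.zeros_like(xb) for j, xb in enumerate(x_blocks)
        )
        _, jvp_i = jax.jvp(f, tuple(x_blocks), tangents)
        out = jvp_i if out is None else out + jvp_i
    return out
```

This sketch uses forward mode throughout; the forward/reverse dispatch mentioned above is the part that needs more thought.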

f0uriest · Sep 07 '24

Some laptop benchmarking on the ATF Jacobian benchmark:

[two figures: ATF Jacobian benchmark results vs. chunk size]

The effect on compute time is more than expected; I think I will implement something that only automatically chooses a smaller chunk size if the expected memory usage is larger than the available device memory. I will redo these tests on GPU if I can as well; this was just on CPU on a Mac.

dpanici · Sep 18 '24

> I will redo these tests on GPU if I can as well; this was just on CPU on a Mac.

Can you also add benchmarks for jnp.vectorize, to be sure that calling batched_vectorize takes the same time as the original JAX one?

YigitElma · Sep 18 '24

> I think I will implement something that only automatically chooses a smaller chunk size if the expected memory usage is larger than the available device memory.

We can do this based on predefined values for dim_x

YigitElma · Sep 18 '24

> I will redo these tests on GPU if I can as well; this was just on CPU on a Mac.

> Can you also add benchmarks for jnp.vectorize, to be sure that calling batched_vectorize takes the same time as the original JAX one?

Yeah, they are basically the same as having a chunk size equal to dim_x:

[two figures: batched_vectorize vs. jnp.vectorize timing comparison]

dpanici · Sep 20 '24

Based on these scalings of Jacobian memory usage on GPU and CPU for the ATF benchmark, I will have "auto" estimate memory usage based on dim_x * dim_f and apply a chunk_size that brings the estimated usage below the device memory:

[two figures: Jacobian memory usage scaling on GPU and CPU]

dpanici · Sep 20 '24

A conservative estimate of the actual peak memory usage relative to the estimated peak memory usage from the above formula, as a function of normalized chunk_size:

[figure: actual/estimated peak memory vs. normalized chunk_size]

dpanici · Sep 20 '24

Putting it together, we have something like

chunk_size < (device_mem / estimated_mem - b) / a * dim_x

where a ~ 0.8 and b ~ 0.15, and device_mem is the available memory of the current device (as given by desc_config.get("avail_mem"), for example).

When the RHS of this inequality is < 1, we estimate that we will get an OOM error: even with a chunk size of 1, the Jacobian's forward-mode calculation will not fit in device memory.
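
In code, the heuristic might look something like this (illustrative only; the name auto_chunk_size and the MemoryError behavior are not the final implementation):

```python
A, B = 0.8, 0.15  # fitted constants from the scaling plots above


def auto_chunk_size(dim_x, estimated_mem, device_mem):
    """Largest chunk_size whose predicted peak memory fits on the device.

    Assumes peak_mem ~= estimated_mem * (A * chunk_size / dim_x + B),
    i.e. the fit described above (memory in consistent units).
    """
    max_chunk = (device_mem / estimated_mem - B) / A * dim_x
    if max_chunk < 1:
        # even chunk_size=1 is predicted to OOM in the fwd-mode Jacobian
        raise MemoryError("estimated memory usage exceeds device memory")
    return int(min(dim_x, max_chunk))
```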

dpanici · Sep 20 '24

@dpanici Hi, it is irrelevant to this discussion, but what script did you use to create the plot pasted at the top of the PR?

Qazalbash · Feb 18 '25

> @dpanici Hi, it is irrelevant to this discussion, but what script did you use to create the plot pasted at the top of the PR?

Hi @Qazalbash, sorry for the delayed response. I will share the script in 1 or 2 days (this is a self-reminder 😄).

In the meantime, @rahulgaur104 or @dpanici, if you have the script in your archive, can you share it? It was the one that used pynvml, I guess.
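
If it was indeed pynvml-based, the sampling loop would have looked roughly like this (an illustrative sketch, not the original script):

```python
import time

import pynvml

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)  # first GPU

usage, t0 = [], time.time()
while time.time() - t0 < 60:  # sample for 60 seconds
    info = pynvml.nvmlDeviceGetMemoryInfo(handle)
    usage.append(info.used / 1024**2)  # bytes -> MiB
    time.sleep(0.01)
pynvml.nvmlShutdown()
```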

YigitElma · Feb 22 '25

Hi @Qazalbash

This should do the job, more or less. Replace DESC_profiler_new_13-04.py with the script that you are running. Good luck!

#!/usr/bin/env python3
import subprocess
import threading
import time

import matplotlib.pyplot as plt


def monitor_vram(duration, interval, vram_usage_list, timestamps):
    # poll nvidia-smi for the used VRAM until `duration` seconds have elapsed
    end_time = time.time() + duration
    while time.time() < end_time:
        result = subprocess.run(
            ['nvidia-smi', '--query-gpu=memory.used', '--format=csv,noheader,nounits'],
            stdout=subprocess.PIPE,
        )
        output = result.stdout.decode('utf-8').strip()
        vram_usage = int(output.split()[0])  # MiB used on the first GPU
        vram_usage_list.append(vram_usage)
        timestamps.append(time.time())
        time.sleep(interval)


if __name__ == "__main__":
    duration = 300  # duration to monitor in seconds
    interval = 0.01  # interval between checks in seconds
    vram_usage_list = []
    timestamps = []

    # create a thread that samples VRAM while the GPU code runs
    vram_thread = threading.Thread(
        target=monitor_vram, args=(duration, interval, vram_usage_list, timestamps)
    )

    # start the monitoring thread
    vram_thread.start()

    res = 12
    # run the script under test without blocking (Popen)
    subprocess.Popen(['python', "DESC_profiler_new_13-04.py", f"{res}"])
    # subprocess.Popen(['python', "DESC_profiler_new_06-04.py", f"{res}"])

    # wait for the monitoring thread to finish
    vram_thread.join()

    # write the VRAM usage to a file
    with open('vram_usage.txt', 'w') as file:
        for usage in vram_usage_list:
            file.write(f"{usage}\n")

    plt.figure(figsize=(15, 7))
    # plot the VRAM usage against time since monitoring started
    plt.plot([t - timestamps[0] for t in timestamps], vram_usage_list, '-or', ms=2)
    plt.xlabel('Time (s)', fontsize=20)
    plt.ylabel('VRAM Usage (MiB)', fontsize=20)
    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)
    plt.title('GPU VRAM Usage Over Time')
    plt.grid(True)
    plt.tight_layout()
    # plt.savefig(f"test0_QP_{res}.png", dpi=400)
    plt.savefig(f"test0_QP_{res}_lowp.png", dpi=400)
    # plt.show()

rahulgaur104 · Feb 23 '25

Thanks @YigitElma and @rahulgaur104

Qazalbash · Feb 23 '25

You're welcome! Also, remember that you have to set the environment variable

XLA_PYTHON_CLIENT_ALLOCATOR=platform

to prevent JAX from preallocating 75% of your GPU memory; otherwise, you will get a flat line.

PS: On the command line, you set an environment variable with export:

export XLA_PYTHON_CLIENT_ALLOCATOR=platform

and unset it with

unset XLA_PYTHON_CLIENT_ALLOCATOR
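
You can also set it from Python, as long as it happens before JAX is first imported:

```python
import os

# must be set before importing jax, or it has no effect
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import jax  # noqa: E402
```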

YigitElma · Feb 23 '25

I really appreciate that! I will mention these in the docs.

Qazalbash · Feb 23 '25