pymc
Remove swapaxes before and after scan
Description
Currently, the results of scan are evaluated in _postprocess_samples, and then the axes are fixed in the list comprehension [jnp.swapaxes(t, 0, 1) for _, t in outs]. This seems to unnecessarily double the peak memory footprint of this method. Admittedly, I don't know much about scan and the jaxified function, but it seems that we may not need to transpose before and after.
From what I gather, it doesn't matter if the in/out are of dimension (chains, draws, ...) or (draws, chains, ...). Avoiding the final transpose in the list comp should lower the peak memory footprint by about half (?)
In my testing, outputs were exactly the same after omitting the double transpose.
(But if the axis swaps are indeed necessary, maybe the operations can still be combined in a way that avoids the list comp at the end.)
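A minimal, self-contained sketch of the equivalence described above, assuming a toy two-variable function jax_fn in place of the model graph returned by get_jaxified_graph (the toy function and the random inputs are hypothetical, not from the PR). Both layouts apply the same per-draw function, so the results should match; what changes is which axis the sequential scan walks over and whether full-size transposed copies are materialized.

```python
import jax
import jax.numpy as jnp
from jax.lax import scan


# Hypothetical stand-in for the jaxified model graph: maps a single draw of
# two variables to two derived quantities.
def jax_fn(mu, sigma):
    return mu * 2.0, jnp.exp(sigma)


chains, draws = 4, 100
key_mu, key_sigma = jax.random.split(jax.random.PRNGKey(0))
# Raw samples laid out as (chains, draws, ...), as in the profiling tests.
raw_mcmc_samples = [
    jax.random.normal(key_mu, (chains, draws)),
    jax.random.normal(key_sigma, (chains, draws, 3)),
]
jax_vfn = jax.vmap(jax_fn)

# Current layout: transpose to (draws, chains, ...), scan over draws with
# vmap over chains, then transpose every output back.
t_samples = [jnp.swapaxes(t, 0, 1) for t in raw_mcmc_samples]
_, outs = scan(lambda _, x: ((), jax_vfn(*x)), (), t_samples)
with_transpose = [jnp.swapaxes(t, 0, 1) for t in outs]

# Proposed layout: scan over chains with vmap over draws, no transposes.
_, without_transpose = scan(lambda _, x: ((), jax_vfn(*x)), (), raw_mcmc_samples)

for a, b in zip(with_transpose, without_transpose):
    print(a.shape, b.shape, bool(jnp.allclose(a, b)))
```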
Memory usage tested with the following:
```python
import pickle
from pathlib import Path

import jax
import jax.numpy as jnp
import pytest
from jax.lax import scan

from pymc.sampling.jax import _device_put, get_jaxified_graph

CUR_DIR = Path(__file__).parent
DIR_FIXTURES = CUR_DIR / "../fixtures/profiling"
PATH_MODEL = DIR_FIXTURES / "model_pm.p"
PATH_DATA = DIR_FIXTURES / "raw_mcmc_samples.p"

postprocessing_backend = None


@pytest.fixture
def model():
    with open(PATH_MODEL, "rb") as f:
        return pickle.load(f)


@pytest.fixture
def raw_mcmc_samples():
    with open(PATH_DATA, "rb") as f:
        return pickle.load(f)


def get_jax_fn(model):
    vars_to_sample = [
        v for v in model.unobserved_value_vars if not v.name.endswith("__")
    ]
    jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
    return jax_fn


def test_scan_vmap(model, raw_mcmc_samples):
    jax_fn = get_jax_fn(model)
    # (chains, draws, ...) -> (draws, chains, ...): scan over draws, vmap over chains
    t_raw_mcmc_samples = [jnp.swapaxes(t, 0, 1) for t in raw_mcmc_samples]
    jax_vfn = jax.vmap(jax_fn)
    _, outs = scan(
        lambda _, x: ((), jax_vfn(*x)),
        (),
        _device_put(t_raw_mcmc_samples, postprocessing_backend),
    )
    # transpose the outputs back to (chains, draws, ...)
    ret = [jnp.swapaxes(t, 0, 1) for t in outs]


def test_scan_vmap_wo_transpose(model, raw_mcmc_samples):
    jax_fn = get_jax_fn(model)
    jax_vfn = jax.vmap(jax_fn)
    # no transposes: scan over chains, vmap over draws
    _, outs = scan(
        lambda _, x: ((), jax_vfn(*x)),
        (),
        _device_put(raw_mcmc_samples, postprocessing_backend),
    )
    ret = outs


def test_nested_vmap(model, raw_mcmc_samples):
    jax_fn = get_jax_fn(model)
    ret = jax.vmap(jax.vmap(jax_fn))(
        *_device_put(raw_mcmc_samples, postprocessing_backend)
    )


def test_looped_vmap(model, raw_mcmc_samples, num_chunks=100):
    # https://discourse.pymc.io/t/nameerror-unbound-axis-name-raised-during-transformation-of-variables-after-sample-numpyro-nuts/11167/5
    jax_fn = get_jax_fn(model)
    # dims are vars, chains, draws, ...
    raw_mcmc_samples = _device_put(raw_mcmc_samples, postprocessing_backend)
    f = jax.vmap(jax.vmap(jax_fn))
    draws = len(raw_mcmc_samples[0][0])
    # max(1, ...) avoids a zero step when draws < num_chunks
    segs = list(range(0, draws, max(1, draws // num_chunks))) + [draws]
    # dims are chunks, vars, chains, draws, ...
    outputs = [
        f(*[var_samples[:, i:j] for var_samples in raw_mcmc_samples])
        for i, j in zip(segs[:-1], segs[1:])
    ]
    # dims of var_chunks are chunks, chains, draws, ...
    ret = [jnp.concatenate(var_chunks, axis=1) for var_chunks in zip(*outputs)]
```
(Note: I couldn't get the legacy chunked xmap method to work -- ran into some jax issue I couldn't decipher)
With the following results using memray:
```
========================================================================= MEMRAY REPORT =========================================================================
Allocation results for src/tests/profiling/test_jax_postproc_profile.py::test_looped_vmap at the high watermark
📦 Total memory allocated: 3.4GiB
📏 Total allocations: 5272432
📊 Histogram of allocation sizes: | █ ▁ |
🥇 Biggest allocating functions:
- <lambda>:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/jax/_src/dispatch.py:164 -> 1.3GiB
- __call__:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1149 -> 761.9MiB
- _pjit_call_impl:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/jax/_src/pjit.py:1214 -> 653.7MiB
- _pjit_call_impl:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/jax/_src/pjit.py:1214 -> 653.7MiB
- _pjit_call_impl:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/jax/_src/pjit.py:1214 -> 90.3MiB
Allocation results for src/tests/profiling/test_jax_postproc_profile.py::test_nested_vmap at the high watermark
📦 Total memory allocated: 2.9GiB
📏 Total allocations: 1657858
📊 Histogram of allocation sizes: | █ ▁ |
🥇 Biggest allocating functions:
- _pjit_call_impl:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/jax/_src/pjit.py:1214 -> 653.4MiB
- __call__:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1149 -> 653.4MiB
- __call__:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1149 -> 653.4MiB
- __call__:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1149 -> 653.4MiB
- __call__:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1149 -> 163.3MiB
Allocation results for src/tests/profiling/test_jax_postproc_profile.py::test_scan_vmap at the high watermark
📦 Total memory allocated: 2.8GiB
📏 Total allocations: 1416835
📊 Histogram of allocation sizes: | █ ▁ |
🥇 Biggest allocating functions:
- __call__:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1149 -> 1.4GiB
- __call__:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1149 -> 761.9MiB
- <listcomp>:/Users/jason/Wonder/ -ds/src/tests/profiling/test_jax_postproc_profile.py:50 -> 653.4MiB
- __call__:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1149 -> 18.6MiB
- <lambda>:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/jax/_src/dispatch.py:164 -> 4.1MiB
Allocation results for src/tests/profiling/test_jax_postproc_profile.py::test_scan_vmap_wo_transpose at the high watermark
📦 Total memory allocated: 1.6GiB
📏 Total allocations: 605385
📊 Histogram of allocation sizes: |▁ █ ▁ |
🥇 Biggest allocating functions:
- __call__:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1149 -> 1.6GiB
- <lambda>:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/jax/_src/dispatch.py:164 -> 6.6MiB
- __call__:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/pytensor/link/c/basic.py:1767 -> 3.3MiB
- backend_compile:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/jax/_src/compiler.py:251 -> 1.6MiB
- __init__:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/pytensor/graph/rewriting/basic.py:2277 -> 1.0MiB
```
(Notice: 2.8GiB for test_scan_vmap and 1.6GiB for test_scan_vmap_wo_transpose)
Aside: I originally sought to bring back some notion of an n_chunks param to trade off runtime vs. peak memory, but I guess that didn't really work out. Even at half the memory footprint, _postprocess_samples seems very peaky.
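One way the n_chunks idea could be expressed without any transposes is sketched below. This is an illustration under assumptions, not pymc code: the helper postprocess_chunked and its n_chunks argument are hypothetical, jax_fn is a toy stand-in, and draws must be divisible by n_chunks. It scans over chains, splits each chain's draws into n_chunks contiguous blocks with a reshape (no transpose needed), scans over the blocks, and vmaps within a block. n_chunks=1 reduces to the scan-over-chains/vmap-over-draws layout proposed here, and n_chunks=draws to a fully nested scan; whether intermediate values actually trade runtime for lower peak memory would need to be measured.

```python
import jax
import jax.numpy as jnp
from jax.lax import scan


def jax_fn(mu, sigma):
    # Hypothetical per-draw function standing in for the jaxified model graph.
    return mu * 2.0, jnp.exp(sigma)


def postprocess_chunked(jax_fn, samples, n_chunks):
    """Hypothetical helper. samples: list of arrays shaped (chains, draws, ...)."""
    draws = samples[0].shape[1]
    if draws % n_chunks:
        raise ValueError("this sketch requires draws % n_chunks == 0")
    size = draws // n_chunks
    vfn = jax.vmap(jax_fn)  # vectorize over the draws inside one chunk

    def per_chain(*x):
        # Each x[i] is (draws, ...); split it into (n_chunks, size, ...).
        chunks = [xi.reshape(n_chunks, size, *xi.shape[1:]) for xi in x]
        _, outs = scan(lambda _, c: ((), vfn(*c)), (), chunks)
        # Merge (n_chunks, size, ...) back into (draws, ...).
        return tuple(o.reshape(draws, *o.shape[2:]) for o in outs)

    # Scan over chains, as in this PR, so no swapaxes is needed.
    _, outs = scan(lambda _, x: ((), per_chain(*x)), (), samples)
    return list(outs)


key_mu, key_sigma = jax.random.split(jax.random.PRNGKey(0))
samples = [
    jax.random.normal(key_mu, (4, 100)),
    jax.random.normal(key_sigma, (4, 100, 3)),
]
reference = jax.vmap(jax.vmap(jax_fn))(*samples)
for n_chunks in (1, 10, 100):
    result = postprocess_chunked(jax_fn, samples, n_chunks)
    print(n_chunks, all(bool(jnp.allclose(r, e)) for r, e in zip(result, reference)))
```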
Related Issue
- [x] Closes #6744
Checklist
- [ ] Checked that the pre-commit linting/style checks pass
- [ ] Included tests that prove the fix is effective or that the new feature works
- [ ] Added necessary documentation (docstrings and/or example notebooks)
- [ ] If you are a pro: each commit corresponds to a relevant logical change
Type of change
- [x] New feature / enhancement
- [ ] Bug fix
- [ ] Documentation
- [ ] Maintenance
- [ ] Other (please specify):
📚 Documentation preview 📚: https://pymc--7116.org.readthedocs.build/en/7116/
Codecov Report
All modified and coverable lines are covered by tests :white_check_mark:
Comparison is base (2da4050) 92.21% compared to head (3ab2863) 91.79%.
```diff
@@            Coverage Diff             @@
##             main    #7116      +/-   ##
==========================================
- Coverage   92.21%   91.79%   -0.43%
==========================================
  Files         101      101
  Lines       16901    16900       -1
==========================================
- Hits        15586    15514      -72
- Misses       1315     1386      +71
```
| Files | Coverage Δ | |
|---|---|---|
| pymc/sampling/jax.py | 93.05% <100.00%> (-0.03%) | :arrow_down: |
> From what I gather, it doesn't matter if the in/out are of dimension (chains, draws, ...) or (draws, chains, ...). Avoiding the final transpose in the list comp should lower the peak memory footprint by about half (?)
My understanding is that the transpose is there so that scan iterates over draws (usually 1,000) instead of chains (usually 4); otherwise there's little difference between the scan and vmap options.
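A small illustration of that point, using toy arrays rather than pymc code: lax.scan loops sequentially over the leading axis of its inputs, so with a (chains, draws) layout the loop body runs once per chain, while after swapping the first two axes it runs once per draw. vmap then vectorizes whichever axis scan is not looping over.

```python
import jax.numpy as jnp
from jax.lax import scan

chains, draws = 4, 1000
samples = jnp.zeros((chains, draws))  # layout: (chains, draws)


def count_steps(xs):
    # The carry counts how many times the scan body is executed.
    n, _ = scan(lambda n, x: (n + 1, None), jnp.int32(0), xs)
    return int(n)


steps_over_chains = count_steps(samples)  # one step per chain
steps_over_draws = count_steps(jnp.swapaxes(samples, 0, 1))  # one step per draw
print(steps_over_chains, steps_over_draws)
```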
It may, however, be necessary to jit this function so that JAX avoids duplicating memory. This didn't seem relevant in the original vmap branch.
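A sketch of what jitting could look like for the existing transpose/scan/transpose pipeline, assuming a toy jax_fn in place of the model graph (postprocess_scan is a hypothetical name). Passing the samples as an argument and wrapping all three steps in one jax.jit hands them to XLA as a single program; the jax.numpy.swapaxes documentation says it returns a copy rather than a view, and that under JIT the compiler may optimize such copies away when possible. Whether that actually lowers the peak for a real model would still need to be measured.

```python
import jax
import jax.numpy as jnp
from jax.lax import scan


def jax_fn(mu, sigma):
    # Hypothetical stand-in for the jaxified model graph.
    return mu * 2.0, jnp.exp(sigma)


jax_vfn = jax.vmap(jax_fn)


def postprocess_scan(samples):
    # samples: list of arrays shaped (chains, draws, ...)
    t_samples = [jnp.swapaxes(t, 0, 1) for t in samples]  # -> (draws, chains, ...)
    _, outs = scan(lambda _, x: ((), jax_vfn(*x)), (), t_samples)  # scan over draws
    return [jnp.swapaxes(t, 0, 1) for t in outs]  # back to (chains, draws, ...)


# One compiled program for transpose + scan + transpose.
postprocess_scan_jit = jax.jit(postprocess_scan)

key_mu, key_sigma = jax.random.split(jax.random.PRNGKey(0))
samples = [
    jax.random.normal(key_mu, (4, 100)),
    jax.random.normal(key_sigma, (4, 100, 3)),
]
eager = postprocess_scan(samples)
jitted = postprocess_scan_jit(samples)
```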
Ah, I knew it was there for a reason, I just couldn't figure out why. Now it makes a bit more sense. Then maybe https://github.com/google/jax/issues/2509 would be helpful, but that issue has been stale for quite some time.
Some more variants I was curious to throw into the bake-off:
1. What if we swapped what scan and vmap are used for, i.e., scan over draws and vmap over chains?
2. What if we used scan for both chains and draws?

(2) showed a smaller footprint (1.8GiB), but it was still not as small as what's in the PR (scan for chains and vmap for draws, which used 1.6GiB).
I'm also not 100% confident in my testing methodology (memray, with a smaller trace of only 100 draws). But FWIW, this change does allow my large model to finish sampling with a 30 GB memory limit, when it previously ran out of memory under a 48 GB limit.
```python
def test_vmap_scan(model, raw_mcmc_samples):
    jax_fn = get_jax_fn(model)

    def scan_over_draws(*x):
        _, outs = scan(
            f=lambda _, xx: ((), jax_fn(*xx)),
            init=(),
            xs=x,
        )
        return outs

    final_fn = jax.vmap(
        fun=scan_over_draws,
        in_axes=0,  # chains
        out_axes=0,  # output it back as the leading axis
    )
    ret = final_fn(*_device_put(raw_mcmc_samples, postprocessing_backend))


def test_scan_scan(model, raw_mcmc_samples):
    jax_fn = get_jax_fn(model)

    def scan_over_draws(*x):
        _, outs = scan(
            f=lambda _, xx: ((), jax_fn(*xx)),
            init=(),
            xs=x,
        )
        return outs

    def scan_over_chains(*x):
        _, outs = scan(
            f=lambda _, xx: ((), scan_over_draws(*xx)),
            init=(),
            xs=x,
        )
        return outs

    ret = scan_over_chains(*_device_put(raw_mcmc_samples, postprocessing_backend))
```
With the following memory footprints:
```
================================================== MEMRAY REPORT ==================================================
Allocation results for src/tests/profiling/test_jax_postproc_profile.py::test_vmap_scan at the high watermark
📦 Total memory allocated: 2.8GiB
📏 Total allocations: 1451992
📊 Histogram of allocation sizes: | █ ▁ |
🥇 Biggest allocating functions:
- __call__:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1149 -> 1.4GiB
- __call__:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1149 -> 761.9MiB
- <lambda>:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/jax/_src/dispatch.py:164 -> 653.4MiB
- refresh:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/pytensor/link/c/cmodule.py:851 -> 6.0MiB
- <lambda>:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/jax/_src/dispatch.py:164 -> 4.1MiB
Allocation results for src/tests/profiling/test_jax_postproc_profile.py::test_scan_scan at the high watermark
📦 Total memory allocated: 1.8GiB
📏 Total allocations: 960009
📊 Histogram of allocation sizes: |▁ █ |
🥇 Biggest allocating functions:
- __call__:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1149 -> 1.7GiB
- <lambda>:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/jax/_src/dispatch.py:164 -> 6.6MiB
- __call__:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/pytensor/link/c/basic.py:1767 -> 3.3MiB
- backend_compile:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/jax/_src/compiler.py:251 -> 2.4MiB
- no_nan:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/pytensor/tensor/variable.py:1023 -> 1.6MiB
```
Did you try jitting? Does it change anything, even for a single use like this?
Regarding the best option: we added this because we were getting OOM errors with everything vmapped.
You may want to check whether your OOM is related to JAX preallocating too much. There's a config flag for that.
> Did you try jitting? Does it change anything, even for a single use like this?
I'm not exactly sure which function should be jit-compiled. Here I'm trying the whole thing (1.7GiB, slightly more than the non-jit fn at 1.6GiB):
```python
def test_jit_scan_vmap_wo_transpose(model, raw_mcmc_samples):
    jax_fn = get_jax_fn(model)

    def raw_fn():
        jax_vfn = jax.vmap(jax_fn)
        _, outs = scan(
            lambda _, x: ((), jax_vfn(*x)),
            (),
            _device_put(raw_mcmc_samples, postprocessing_backend),
        )
        return outs

    jit_fn = jax.jit(raw_fn)
    ret = jit_fn()
```
```
Allocation results for src/tests/profiling/test_jax_postproc_profile.py::test_jit_scan_vmap_wo_transpose at the high watermark
📦 Total memory allocated: 1.7GiB
📏 Total allocations: 523197
📊 Histogram of allocation sizes: |▁ █ ▁ |
🥇 Biggest allocating functions:
- __call__:/Users/jason/mambaforge/envs/whisk/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1149 -> 1.5GiB
- backend_compile:/Users/jason/mambaforge/envs/whisk/lib/python3.10/site-packages/jax/_src/compiler.py:251 -> 88.0MiB
- _numpy_array_constant:/Users/jason/mambaforge/envs/whisk/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:252 -> 25.1MiB
- <lambda>:/Users/jason/mambaforge/envs/whisk/lib/python3.10/site-packages/jax/_src/dispatch.py:164 -> 6.6MiB
- __call__:/Users/jason/mambaforge/envs/whisk/lib/python3.10/site-packages/pytensor/link/c/basic.py:1767 -> 3.3MiB
```
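In test_jit_scan_vmap_wo_transpose, raw_fn takes no arguments, so the device-put samples are closed over and get embedded in the traced program as constants; the _numpy_array_constant and backend_compile entries in the log may be related to that. A variant that passes the samples in as an argument keeps them as runtime inputs instead. This sketch assumes a toy jax_fn (hypothetical) and has not been profiled against the real model.

```python
import jax
import jax.numpy as jnp
from jax.lax import scan


def jax_fn(mu, sigma):
    # Hypothetical stand-in for the jaxified model graph.
    return mu * 2.0, jnp.exp(sigma)


jax_vfn = jax.vmap(jax_fn)


@jax.jit
def scan_vmap_wo_transpose(samples):
    # samples are traced inputs here, not constants captured from a closure.
    # Scan over chains, vmap over draws, no transposes.
    _, outs = scan(lambda _, x: ((), jax_vfn(*x)), (), samples)
    return outs


key_mu, key_sigma = jax.random.split(jax.random.PRNGKey(0))
samples = [
    jax.random.normal(key_mu, (4, 100)),
    jax.random.normal(key_sigma, (4, 100, 3)),
]
outs = scan_vmap_wo_transpose(samples)
```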
> Regarding the best option: we added this because we were getting OOM errors with everything vmapped.
> You may want to check whether your OOM is related to JAX preallocating too much. There's a config flag for that.
Oh, I haven't read into jax's preallocation at all. What's the config flag you're referring to?
> Oh, I haven't read into jax's preallocation at all. What's the config flag you're referring to?
It seems to only matter for GPU: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
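For reference, the linked page describes environment variables that control this. A minimal sketch, assuming a GPU backend (on CPU these settings should make no difference) and that they are set before JAX is first imported:

```python
import os

# Must be set before JAX initializes its GPU backend; setting it before the
# first `import jax` is the safe option.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # allocate on demand
# Alternative: keep preallocation but reserve a smaller fraction of GPU memory.
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"

import jax  # noqa: E402  (imported after the environment is configured)

print(jax.devices())
```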
> (2) showed a smaller footprint (1.8GiB), but it was still not as small as what's in the PR (scan for chains and vmap for draws, which used 1.6GiB).
I'm surprised by this, but I guess a naively traced transpose + scan in JAX is just bad memory-wise (the JAX docs say swapaxes may return copies, which I guess is what happens in your case).
The problem is that I don't know whether this generalizes. @fonnesbeck, could you test whether this also fixes the memory problems you were seeing in your model?
@JasonTam do you want to try nested scan as well? Since you already did so many permutations :b
That should be the most extreme option, at the opposite end from pure vmap.
> @JasonTam do you want to try nested scan as well? Since you already did so many permutations :b
> That should be the most extreme option, at the opposite end from pure vmap.
Nested scan is implementation (2) in my earlier comment:
> I was curious to see:
> 1. What if we swapped what scan and vmap are used for, i.e., scan over draws and vmap over chains?
> 2. What if we used scan for both chains and draws?
>
> (2) showed a smaller footprint (1.8GiB), but it was still not as small as what's in the PR (scan for chains and vmap for draws, which used 1.6GiB).

There, scan_over_chains calls scan_over_draws (see test_scan_scan).
Thanks @JasonTam, I missed it.
For fear of overfitting to one example, I would perhaps go for nested scan? WDYT?
> For fear of overfitting to one example, I would perhaps go for nested scan? WDYT?
I too am afraid that my test setup doesn't generalize well, and I'm going to try to run some more tests. But since this is probably a function most users rely on, I'd also consider putting it under another option in postprocessing_vectorize: Literal["vmap", "scan", "nested_scan"], and potentially keeping "scan" as the default.
I also think the existing options "vmap" and "scan" are a little misleading, since both use vmap over the chain dimension. They feel more like "nested_vmap", "vmap_scan", etc.
I would definitely also appreciate feedback from @ferrine, as the previous author of these bits.
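A rough sketch of the option proposed above, postprocessing_vectorize: Literal["vmap", "scan", "nested_scan"]. The function postprocess and the helper _scan are hypothetical illustrations, not pymc's _postprocess_samples; here "scan" means the layout proposed in this PR (scan over chains, vmap over draws), and the toy jax_fn stands in for the model graph.

```python
from typing import Literal

import jax
import jax.numpy as jnp
from jax.lax import scan


def _scan(f, xs):
    # Apply f sequentially over the leading axis of every array in xs.
    _, ys = scan(lambda _, x: ((), f(*x)), (), xs)
    return ys


def postprocess(
    jax_fn,
    samples,
    postprocessing_vectorize: Literal["vmap", "scan", "nested_scan"] = "scan",
):
    """Hypothetical dispatcher; samples are arrays shaped (chains, draws, ...)."""
    if postprocessing_vectorize == "vmap":  # vmap over chains and draws
        return list(jax.vmap(jax.vmap(jax_fn))(*samples))
    if postprocessing_vectorize == "scan":  # scan over chains, vmap over draws
        return list(_scan(jax.vmap(jax_fn), samples))
    if postprocessing_vectorize == "nested_scan":  # scan over chains and draws
        return list(_scan(lambda *x: _scan(jax_fn, x), samples))
    raise ValueError(f"Unknown postprocessing_vectorize: {postprocessing_vectorize!r}")


def jax_fn(mu, sigma):
    # Hypothetical stand-in for the jaxified model graph.
    return mu * 2.0, jnp.exp(sigma)


key_mu, key_sigma = jax.random.split(jax.random.PRNGKey(0))
samples = [
    jax.random.normal(key_mu, (4, 100)),
    jax.random.normal(key_sigma, (4, 100, 3)),
]
results = {
    option: postprocess(jax_fn, samples, option)
    for option in ("vmap", "scan", "nested_scan")
}
```

All three options should return the same values in the same (chains, draws, ...) layout; they differ only in how much work is vectorized at once.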
@JasonTam any news? Would be great to patch this one up :)
@ricardoV94 I haven't had time to play with this, unfortunately. The only news I have is that I was not able to reproduce this earlier claim of mine:
> But FWIW, this change does allow my large model to finish sampling with a 30 GB memory limit, when it previously ran out of memory under a 48 GB limit.
When testing this larger model, some of these methods were failing. But I need to wait my turn on a cluster to make sure there's no interference, so testing these methods has been slow.
@ricardoV94 here are some results from a larger test:
The shapes in raw_mcmc_samples look like:
```python
[
    (4, 1000),
    (4, 1000),
    (4, 1000, 468),
    (4, 1000, 7, 468),
    (4, 1000, 2, 9, 468),
]
```
i.e., 5 variables, each of shape (4 chains, 1000 draws, ...).
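For scale, the raw inputs listed above can be sized with a quick count. This assumes the samples are float64 (8 bytes per element); the dtype is not stated in the thread.

```python
import math

# Shapes of raw_mcmc_samples as listed above: (chains, draws, ...)
shapes = [
    (4, 1000),
    (4, 1000),
    (4, 1000, 468),
    (4, 1000, 7, 468),
    (4, 1000, 2, 9, 468),
]
n_elements = sum(math.prod(shape) for shape in shapes)
bytes_per_element = 8  # assumption: float64 samples
size_gib = n_elements * bytes_per_element / 2**30
print(f"{n_elements:,} elements, about {size_gib:.2f} GiB")  # 48,680,000 elements, about 0.36 GiB
```

Under that assumption the raw inputs are well under 1 GiB, which is small next to the peak-memory column in the results below.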
Tested on an Azure k8s cluster with Epdsv5-series VMs (3.0 GHz CPU), where each job has plenty of CPU and memory to spare.
| method name | method for chains dim | method for draws dim | Peak Memory [GiB] | Total allocations | Call Duration [s] |
|---|---|---|---|---|---|
| scan_vmap via listcomp transpose (control) | vmap | scan | 28.765 | 1451658 | 8.219 |
| scan_vmap_wo_transpose | scan | vmap | 15.938 | 828079 | 6.579 |
| vmap_scan | vmap | scan | 28.532 | 1452441 | 8.318 |
| nested_scan | scan | scan | 17.903 | 1277702 | 8.293 |
| nested_vmap | vmap | vmap | 30.123 | 1827578 | 9.175 |
| vmap_map | vmap | map | 28.532 | 1451486 | 8.280 |
| nested_xmap | xmap | xmap | 20.703 | 623790 | 5.592 |
| looped_vmap | python loop | vmap | 34.933 | 6474397 | 20.379 |
I hope I'm understanding correctly which method goes with which dimension. From these results, it does seem like scan_vmap_wo_transpose has the lowest peak memory (which simply removes the list-comp transposes, as in this PR).
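Expressed relative to the control row, using the peak-memory column of the table above (values copied as reported, nothing re-measured), scan_vmap_wo_transpose peaks at roughly 55% of the control and nested_scan at roughly 62%.

```python
# Peak memory in GiB, copied from the table above.
peak_gib = {
    "scan_vmap via listcomp transpose (control)": 28.765,
    "scan_vmap_wo_transpose": 15.938,
    "vmap_scan": 28.532,
    "nested_scan": 17.903,
    "nested_vmap": 30.123,
    "vmap_map": 28.532,
    "nested_xmap": 20.703,
    "looped_vmap": 34.933,
}
control = peak_gib["scan_vmap via listcomp transpose (control)"]
ratios = {name: gib / control for name, gib in peak_gib.items()}

# Print methods from lowest to highest peak memory.
for name, ratio in sorted(ratios.items(), key=lambda item: item[1]):
    print(f"{name:44s} {peak_gib[name]:7.3f} GiB  {ratio:6.1%} of control")
```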
I still like the nested scan better, because we know that vmap was the source of the problem in the case that first motivated these changes. Unfortunately, we don't have a way to retrieve that example, but I suspect the current solution in this PR would be a regression there.