
Remove swapaxes before and after scan

Open JasonTam opened this issue 1 year ago • 17 comments

Description

Currently, the results of scan are evaluated in _postprocess_samples, and then the axes are fixed in the list comprehension [jnp.swapaxes(t, 0, 1) for _, t in outs]. This seems to unnecessarily double the peak memory footprint of this method. Admittedly, I don't know much about scan and the jaxified function, but it seems that we may not need to transpose before and after. From what I gather, it doesn't matter if the in/out are of dimension (chains, draws, ...) or (draws, chains, ...). Avoiding the final transpose in the list comp should lower the peak memory footprint by about half (?). In my testing, outputs were exactly the same after omitting the double transpose.

(But if the axis swaps are indeed necessary, maybe the operations can still be combined in a way that avoids the list comp at the end.)
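
To make the change concrete, here is a rough sketch of the two patterns, with hypothetical function names -- this is not the exact pymc source, just the shape of the before/after:

import jax
import jax.numpy as jnp
from jax.lax import scan


def postprocess_with_transpose(jax_fn, raw_mcmc_samples):
    # current pattern: transpose in, scan, transpose back out
    jax_vfn = jax.vmap(jax_fn)
    # (chains, draws, ...) -> (draws, chains, ...)
    xs = [jnp.swapaxes(t, 0, 1) for t in raw_mcmc_samples]
    _, outs = scan(lambda _, x: ((), jax_vfn(*x)), (), xs)
    # (draws, chains, ...) -> (chains, draws, ...); this second pass is the
    # one that roughly doubles the peak memory
    return [jnp.swapaxes(t, 0, 1) for t in outs]


def postprocess_without_transpose(jax_fn, raw_mcmc_samples):
    # proposed pattern: feed (chains, draws, ...) straight through and keep
    # whatever layout scan produces
    jax_vfn = jax.vmap(jax_fn)
    _, outs = scan(lambda _, x: ((), jax_vfn(*x)), (), raw_mcmc_samples)
    return list(outs)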

Memory usage tested with the following:

import pickle
from pathlib import Path

import jax
import jax.numpy as jnp
import pytest
from jax.experimental.maps import SerialLoop, xmap
from jax.lax import scan
from pymc.sampling.jax import _device_put, get_jaxified_graph

CUR_DIR = Path(__file__).parent

DIR_FIXTURES = CUR_DIR / "../fixtures/profiling"

PATH_MODEL = DIR_FIXTURES / "model_pm.p"
PATH_DATA = DIR_FIXTURES / "raw_mcmc_samples.p"

postprocessing_backend = None


@pytest.fixture
def model():
    return pickle.load(open(PATH_MODEL, "rb"))


@pytest.fixture
def raw_mcmc_samples():
    return pickle.load(open(PATH_DATA, "rb"))


def get_jax_fn(model):
    vars_to_sample = [
        v for v in model.unobserved_value_vars if not v.name.endswith("__")
    ]
    jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
    return jax_fn


def test_scan_vmap(model, raw_mcmc_samples):
    jax_fn = get_jax_fn(model)

    # transpose to (draws, chains, ...) so scan iterates over draws and vmap covers chains
    t_raw_mcmc_samples = [jnp.swapaxes(t, 0, 1) for t in raw_mcmc_samples]
    jax_vfn = jax.vmap(jax_fn)
    _, outs = scan(
        lambda _, x: ((), jax_vfn(*x)),
        (),
        _device_put(t_raw_mcmc_samples, postprocessing_backend),
    )
    # transpose outputs back to (chains, draws, ...)
    ret = [jnp.swapaxes(t, 0, 1) for t in outs]


def test_scan_vmap_wo_transpose(model, raw_mcmc_samples):
    jax_fn = get_jax_fn(model)

    jax_vfn = jax.vmap(jax_fn)
    # samples stay as (chains, draws, ...): scan iterates over chains, vmap over draws
    _, outs = scan(
        lambda _, x: ((), jax_vfn(*x)),
        (),
        _device_put(raw_mcmc_samples, postprocessing_backend),
    )
    ret = outs


def test_nested_vmap(model, raw_mcmc_samples):
    jax_fn = get_jax_fn(model)
    ret = jax.vmap(jax.vmap(jax_fn))(
        *_device_put(raw_mcmc_samples, postprocessing_backend)
    )


def test_looped_vmap(model, raw_mcmc_samples, num_chunks=100):
    # https://discourse.pymc.io/t/nameerror-unbound-axis-name-raised-during-transformation-of-variables-after-sample-numpyro-nuts/11167/5
    jax_fn = get_jax_fn(model)

    # dims are vars, chains, draws, ...
    raw_mcmc_samples = _device_put(raw_mcmc_samples, postprocessing_backend)
    f = jax.vmap(jax.vmap(jax_fn))
    draws = len(raw_mcmc_samples[0][0])
    segs = list(range(0, draws, draws // num_chunks)) + [draws]
    # dims are chunks, vars, chains, draws, ...
    outputs = [
        f(*[var_samples[:, i:j] for var_samples in raw_mcmc_samples])
        for i, j in zip(segs[:-1], segs[1:])
    ]
    # dims of var_chunks are chunks, chains, draws, ...
    ret = [jnp.concatenate(var_chunks, axis=1) for var_chunks in zip(*outputs)]

(Note: I couldn't get the legacy chunked xmap method to work -- ran into some jax issue I couldn't decipher)

With the following results using memray:

========================================================================= MEMRAY REPORT =========================================================================
Allocation results for src/tests/profiling/test_jax_postproc_profile.py::test_looped_vmap at the high watermark

         📦 Total memory allocated: 3.4GiB
         📏 Total allocations: 5272432
         📊 Histogram of allocation sizes: |  █ ▁    |
         🥇 Biggest allocating functions:
                - <lambda>:/Users/jason/mambaforge/envs/  /lib/python3.10/site-packages/jax/_src/dispatch.py:164 -> 1.3GiB
                - __call__:/Users/jason/mambaforge/envs/  /lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1149 -> 761.9MiB
                - _pjit_call_impl:/Users/jason/mambaforge/envs/  /lib/python3.10/site-packages/jax/_src/pjit.py:1214 -> 653.7MiB
                - _pjit_call_impl:/Users/jason/mambaforge/envs/  /lib/python3.10/site-packages/jax/_src/pjit.py:1214 -> 653.7MiB
                - _pjit_call_impl:/Users/jason/mambaforge/envs/  /lib/python3.10/site-packages/jax/_src/pjit.py:1214 -> 90.3MiB


Allocation results for src/tests/profiling/test_jax_postproc_profile.py::test_nested_vmap at the high watermark

         📦 Total memory allocated: 2.9GiB
         📏 Total allocations: 1657858
         📊 Histogram of allocation sizes: |  █ ▁    |
         🥇 Biggest allocating functions:
                - _pjit_call_impl:/Users/jason/mambaforge/envs/  /lib/python3.10/site-packages/jax/_src/pjit.py:1214 -> 653.4MiB
                - __call__:/Users/jason/mambaforge/envs/  /lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1149 -> 653.4MiB
                - __call__:/Users/jason/mambaforge/envs/  /lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1149 -> 653.4MiB
                - __call__:/Users/jason/mambaforge/envs/  /lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1149 -> 653.4MiB
                - __call__:/Users/jason/mambaforge/envs/  /lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1149 -> 163.3MiB


Allocation results for src/tests/profiling/test_jax_postproc_profile.py::test_scan_vmap at the high watermark

         📦 Total memory allocated: 2.8GiB
         📏 Total allocations: 1416835
         📊 Histogram of allocation sizes: |  █ ▁    |
         🥇 Biggest allocating functions:
                - __call__:/Users/jason/mambaforge/envs/  /lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1149 -> 1.4GiB
                - __call__:/Users/jason/mambaforge/envs/  /lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1149 -> 761.9MiB
                - <listcomp>:/Users/jason/Wonder/  -ds/src/tests/profiling/test_jax_postproc_profile.py:50 -> 653.4MiB
                - __call__:/Users/jason/mambaforge/envs/  /lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1149 -> 18.6MiB
                - <lambda>:/Users/jason/mambaforge/envs/  /lib/python3.10/site-packages/jax/_src/dispatch.py:164 -> 4.1MiB


Allocation results for src/tests/profiling/test_jax_postproc_profile.py::test_scan_vmap_wo_transpose at the high watermark

         📦 Total memory allocated: 1.6GiB
         📏 Total allocations: 605385
         📊 Histogram of allocation sizes: |▁ █ ▁    |
         🥇 Biggest allocating functions:
                - __call__:/Users/jason/mambaforge/envs/  /lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1149 -> 1.6GiB
                - <lambda>:/Users/jason/mambaforge/envs/  /lib/python3.10/site-packages/jax/_src/dispatch.py:164 -> 6.6MiB
                - __call__:/Users/jason/mambaforge/envs/  /lib/python3.10/site-packages/pytensor/link/c/basic.py:1767 -> 3.3MiB
                - backend_compile:/Users/jason/mambaforge/envs/  /lib/python3.10/site-packages/jax/_src/compiler.py:251 -> 1.6MiB
                - __init__:/Users/jason/mambaforge/envs/  /lib/python3.10/site-packages/pytensor/graph/rewriting/basic.py:2277 -> 1.0MiB

(Notice: 2.8GiB for test_scan_vmap and 1.6GiB for test_scan_vmap_wo_transpose)

Aside: I originally sought to bring back some notion of an n_chunks param to trade off runtime vs peak memory. But I guess that didn't really work out. Even at half the memory footprint, _postprocess_samples seems very peaky.

Related Issue

  • [x] Closes #6744

Checklist

Type of change

  • [x] New feature / enhancement
  • [ ] Bug fix
  • [ ] Documentation
  • [ ] Maintenance
  • [ ] Other (please specify):

📚 Documentation preview 📚: https://pymc--7116.org.readthedocs.build/en/7116/

JasonTam avatar Jan 24 '24 05:01 JasonTam

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison is base (2da4050) 92.21% compared to head (3ab2863) 91.79%.

Additional details and impacted files


@@            Coverage Diff             @@
##             main    #7116      +/-   ##
==========================================
- Coverage   92.21%   91.79%   -0.43%     
==========================================
  Files         101      101              
  Lines       16901    16900       -1     
==========================================
- Hits        15586    15514      -72     
- Misses       1315     1386      +71     
Files                  Coverage Δ
pymc/sampling/jax.py   93.05% <100.00%> (-0.03%) ⬇️

... and 3 files with indirect coverage changes

codecov[bot] avatar Jan 24 '24 06:01 codecov[bot]

From what I gather, it doesn't matter if the in/out are of dimension (chains, draws, ...) or (draws, chains, ...). Avoiding the final transpose in the list comp should lower the peak memory footprint by about half (?)

My understanding is that the transpose is there so scan iterates over draws (usually 1k) instead of chains (usually 4), otherwise there's little difference between the scan and vmap option.
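
(In other words -- a toy sketch, not the pymc code -- lax.scan always loops over the leading axis of xs, so the transpose decides whether the serial loop has length draws or length chains:)

import jax
import jax.numpy as jnp
from jax.lax import scan

chains, draws = 4, 1000
samples = jnp.zeros((chains, draws, 3))  # (chains, draws, ...)
f = jax.vmap(lambda x: x * 2)  # stand-in for the jaxified graph

# without the transpose: only 4 serial steps, each vmapping over 1000 draws --
# not much different from vmapping everything
_, out_a = scan(lambda _, x: ((), f(x)), (), samples)

# with the transpose to (draws, chains, ...): 1000 small serial steps, each
# vmapping over just the 4 chains
_, out_b = scan(lambda _, x: ((), f(x)), (), jnp.swapaxes(samples, 0, 1))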

It may however be needed to jit this function so JAX avoids duplicating memory. This didn't seem relevant in the original vmap branch.

ricardoV94 avatar Jan 24 '24 12:01 ricardoV94

Ah, I knew it was there for a reason, just couldn't figure out why. Now it makes a bit more sense. Then maybe https://github.com/google/jax/issues/2509 would be helpful, but that issue's been stale for quite some time.

JasonTam avatar Jan 24 '24 19:01 JasonTam

Some more variants to throw into the bake-off that I was curious to see:

  1. What if we switched what scan / vmap were being used for, i.e. scan over draws and vmap over chains
  2. What if we used scan for both chains and draws

(2) showed a smaller footprint (1.8GiB), but still not as small as what's in the PR (scan for chains and vmap for draws -- which used 1.6GiB)

I'm also not 100% confident in my testing methodology of using memray and a smaller trace with only 100 draws. But FWIW, this change does allow my large model to finish sampling now with a 30gb memory limit when it was previously OOM under a 48gb limit.

def test_vmap_scan(model, raw_mcmc_samples):
    jax_fn = get_jax_fn(model)

    def scan_over_draws(*x):
        _, outs = scan(
            f=lambda _, xx: ((), jax_fn(*xx)),
            init=(),
            xs=x,
        )
        return outs

    final_fn = jax.vmap(
        fun=scan_over_draws,
        in_axes=0,  # chains
        out_axes=0,  # output it back as the leading axis
    )

    ret = final_fn(*_device_put(raw_mcmc_samples, postprocessing_backend))


def test_scan_scan(model, raw_mcmc_samples):
    jax_fn = get_jax_fn(model)
    
    def scan_over_draws(*x):
        _, outs = scan(
            f=lambda _, xx: ((), jax_fn(*xx)),
            init=(),
            xs=x,
        )
        return outs

    def scan_over_chains(*x):
        _, outs = scan(
            f=lambda _, xx: ((), scan_over_draws(*xx)),
            init=(),
            xs=x,
        )
        return outs

    ret = scan_over_chains(*_device_put(raw_mcmc_samples, postprocessing_backend))

With the following memory footprints:

================================================== MEMRAY REPORT ==================================================
Allocation results for src/tests/profiling/test_jax_postproc_profile.py::test_vmap_scan at the high watermark

         📦 Total memory allocated: 2.8GiB
         📏 Total allocations: 1451992
         📊 Histogram of allocation sizes: |  █ ▁    |
         🥇 Biggest allocating functions:
                - __call__:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1149 -> 1.4GiB
                - __call__:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1149 -> 761.9MiB
                - <lambda>:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/jax/_src/dispatch.py:164 -> 653.4MiB
                - refresh:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/pytensor/link/c/cmodule.py:851 -> 6.0MiB
                - <lambda>:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/jax/_src/dispatch.py:164 -> 4.1MiB


Allocation results for src/tests/profiling/test_jax_postproc_profile.py::test_scan_scan at the high watermark

         📦 Total memory allocated: 1.8GiB
         📏 Total allocations: 960009
         📊 Histogram of allocation sizes: |▁ █      |
         🥇 Biggest allocating functions:
                - __call__:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1149 -> 1.7GiB
                - <lambda>:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/jax/_src/dispatch.py:164 -> 6.6MiB
                - __call__:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/pytensor/link/c/basic.py:1767 -> 3.3MiB
                - backend_compile:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/jax/_src/compiler.py:251 -> 2.4MiB
                - no_nan:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/pytensor/tensor/variable.py:1023 -> 1.6MiB

JasonTam avatar Jan 24 '24 20:01 JasonTam

Did you try jitting? Does it change anything for even a single use like this?

ricardoV94 avatar Jan 24 '24 21:01 ricardoV94

Regarding the best option: we added this because we were getting OOM with everything vmapped.

You may want to check whether your OOM is related to JAX preallocating too much; there's a config flag for that.

ricardoV94 avatar Jan 24 '24 21:01 ricardoV94

Did you try jitting? Does it change anything for even a single use like this?

I'm not exactly sure which function would be jit-compiled. Here I'm trying the whole thing (1.7GiB, slightly more than the non-jit fn at 1.6GiB):

def test_jit_scan_vmap_wo_transpose(model, raw_mcmc_samples):
    jax_fn = get_jax_fn(model)

    def raw_fn():
        jax_vfn = jax.vmap(jax_fn)
        _, outs = scan(
            lambda _, x: ((), jax_vfn(*x)),
            (),
            _device_put(raw_mcmc_samples, postprocessing_backend),
        )
        return outs
    jit_fn = jax.jit(raw_fn)
    ret = jit_fn()

Allocation results for src/tests/profiling/test_jax_postproc_profile.py::test_jit_scan_vmap_wo_transpose at the high watermark

         📦 Total memory allocated: 1.7GiB
         📏 Total allocations: 523197
         📊 Histogram of allocation sizes: |▁ █ ▁    |
         🥇 Biggest allocating functions:
                - __call__:/Users/jason/mambaforge/envs/whisk/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1149 -> 1.5GiB
                - backend_compile:/Users/jason/mambaforge/envs/whisk/lib/python3.10/site-packages/jax/_src/compiler.py:251 -> 88.0MiB
                - _numpy_array_constant:/Users/jason/mambaforge/envs/whisk/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:252 -> 25.1MiB
                - <lambda>:/Users/jason/mambaforge/envs/whisk/lib/python3.10/site-packages/jax/_src/dispatch.py:164 -> 6.6MiB
                - __call__:/Users/jason/mambaforge/envs/whisk/lib/python3.10/site-packages/pytensor/link/c/basic.py:1767 -> 3.3MiB

Regarding the best option: we added this because we were getting OOM with everything vmapped.

You may want to check whether your OOM is related to JAX preallocating too much; there's a config flag for that.

Oh, I haven't read into jax's preallocation at all. What's the config flag you're referring to?

JasonTam avatar Jan 24 '24 22:01 JasonTam

Oh, I haven't read into jax's preallocation at all. What's the config flag you're referring to?

Seems to only matter for GPU https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
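
(Those are environment variables that have to be set before JAX initializes the backend, e.g.:)

import os

# must happen before importing jax / running the first JAX computation
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # don't preallocate most of the GPU memory
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"  # or cap the preallocated fraction

import jax  # noqa: E402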

ricardoV94 avatar Jan 25 '24 04:01 ricardoV94

(2) showed a smaller footprint (1.8GiB), but still not as small as what's in the PR (scan for chains and vmap for draws -- which used 1.6GiB)

I'm surprised by this, but I guess a naively traced transpose + scan in JAX just sucks memory-wise (the JAX docs say swapaxes may return copies, which I guess it is doing in your case).

The problem is that I don't know if this is general. @fonnesbeck could you test if this also fixes the memory problems you were seeing in your model?

ricardoV94 avatar Jan 25 '24 04:01 ricardoV94

@JasonTam do you want to try nested scan as well? Since you already did so many permutations :b

That should be the most extreme option at the opposite end from just vmap.

ricardoV94 avatar Jan 25 '24 04:01 ricardoV94

@JasonTam do you want to try nested scan as well? Since you already did so many permutations :b

That should be the most extreme option at the opposite end from just vmap.

Nested scan is implementation (2) in my earlier comment:

I was curious to see:

  1. What if we switched what scan / vmap were being used for, i.e. scan over draws and vmap over chains
  2. What if we used scan for both chains and draws

(2) showed a smaller footprint (1.8GiB), but still not as small as what's in the PR (scan for chains and vmap for draws -- which used 1.6GiB)

where scan_over_chains calls scan_over_draws

JasonTam avatar Jan 25 '24 05:01 JasonTam

Thanks @JasonTam I missed it.

Due to the fear of overfitting to one example, I would perhaps go for nested scan? WDYT?

ricardoV94 avatar Jan 25 '24 12:01 ricardoV94

Due to the fear of overfitting to one example, I would perhaps go for nested scan? WDYT?

I too am afraid of my test set-up not generalizing well. I'm going to try to run some more tests. But also, since this is probably a widely used function for most users, I'd consider putting it under another option in postprocessing_vectorize: Literal["vmap", "scan", "nested_scan"], and potentially keeping the default as "scan".
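
For illustration, something along these lines -- purely a sketch with hypothetical names, not the actual signature of _postprocess_samples in pymc/sampling/jax.py:

import jax
from jax.lax import scan


def postprocess(jax_fn, raw_mcmc_samples, postprocessing_vectorize="scan"):
    # hypothetical dispatch sketch; transposes/jit details omitted
    if postprocessing_vectorize == "vmap":
        return jax.vmap(jax.vmap(jax_fn))(*raw_mcmc_samples)
    if postprocessing_vectorize == "scan":
        # scan over the leading (chains) axis, vmap over draws inside each
        # step -- the no-transpose variant from this PR
        _, outs = scan(lambda _, x: ((), jax.vmap(jax_fn)(*x)), (), raw_mcmc_samples)
        return outs
    if postprocessing_vectorize == "nested_scan":
        # scan over chains on the outside, scan over draws on the inside,
        # as in test_scan_scan above
        def scan_over_draws(*x):
            _, inner = scan(lambda _, xx: ((), jax_fn(*xx)), (), x)
            return inner

        _, outs = scan(lambda _, x: ((), scan_over_draws(*x)), (), raw_mcmc_samples)
        return outs
    raise ValueError(f"Unknown postprocessing_vectorize: {postprocessing_vectorize}")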

I also think the existing options "vmap" and "scan" are a little misleading, since both use vmap over the chain dimension. They feel more like "nested_vmap", "vmap_scan", etc.

I would definitely also appreciate feedback from @ferrine, as the previous author of these bits.

JasonTam avatar Jan 25 '24 16:01 JasonTam

@JasonTam any news? Would be great to patch this one up :)

ricardoV94 avatar Feb 06 '24 09:02 ricardoV94

@ricardoV94 I haven't had time to play with this, unfortunately. The only news I have is that I was not able to reproduce this earlier claim of mine:

But FWIW, this change does allow my large model to finish sampling now with a 30gb memory limit when it was previously OOM under a 48gb limit.

When testing this larger model, some of these methods were failing. But I need to wait my turn on a cluster to make sure there's no interference, so testing these methods has been slow.

JasonTam avatar Feb 06 '24 21:02 JasonTam

@ricardoV94 here are some results from a larger test. raw_mcmc_samples looks like:

[
    (4, 1000),
    (4, 1000),
    (4, 1000, 468),
    (4, 1000, 7, 468),
    (4, 1000, 2, 9, 468),
]

5 variables of (4 chains, 1000 draws, ...)

Tested on an Azure k8s cluster with Epdsv5-series VMs (3.0 GHz CPUs), where each job has plenty of CPU and memory to spare.

| method name | method for chains dim | method for draws dim | Peak Memory [GiB] | Total allocations | Call Duration [s] |
| --- | --- | --- | --- | --- | --- |
| scan_vmap via listcomp transpose (control) | vmap | scan | 28.765 | 1451658 | 8.219 |
| scan_vmap_wo_transpose | scan | vmap | 15.938 | 828079 | 6.579 |
| vmap_scan | vmap | scan | 28.532 | 1452441 | 8.318 |
| nested_scan | scan | scan | 17.903 | 1277702 | 8.293 |
| nested_vmap | vmap | vmap | 30.123 | 1827578 | 9.175 |
| vmap_map | vmap | map | 28.532 | 1451486 | 8.280 |
| nested_xmap | xmap | xmap | 20.703 | 623790 | 5.592 |
| looped_vmap | python loop | vmap | 34.933 | 6474397 | 20.379 |

I hope I'm understanding which method goes to which dimension correctly. From these results, it does seem like scan_vmap_wo_transpose has the lowest peak memory (which is simply removing the transposes via the list comp, as seen in this PR).

JasonTam avatar Feb 10 '24 23:02 JasonTam

I still like the nested scan better because we know that vmap was the source of the problem in the case that first motivated these changes. Unfortunately we don't have a way to retrieve that example, but I suspect the current solution in this PR would be a regression there.

ricardoV94 avatar Feb 11 '24 08:02 ricardoV94