BUG: Jax-based samplers crash at transformation stage

Open fonnesbeck opened this issue 2 years ago • 9 comments

Describe the issue:

The Jax-based samplers crash after sampling, right after the "Transforming variables..." message, on medium-to-large models (thousands of rows, hundreds of parameters). This occurs on both GPU and CPU systems, with either the numpyro or the blackjax sampler. The failure on GPU returns a backtrace that isolates the issue at the vmap in _postprocess_samples. On a CPU (MacBook Pro M1), the process is simply killed without any error message. I have tried running the GPU model with the postprocessing_backend="cpu" argument for the numpyro sampler, but this does not seem to make a difference. Should it be using vmap when the postprocessing backend is CPU?

Reproducible code example:

Will add an example when I can come up with one.

Error message:

CPU machine error:


Compilation time =  0:00:09.225151
Sampling...
Running chain 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [12:00:33<00:00, 21.62s/it]
Running chain 1: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [12:00:33<00:00, 21.62s/it]
Sampling time =  12:00:35.215191
Transforming variables...
Killed: 9
/Users/cfonnesbeck/mambaforge/envs/pymc/lib/python3.11/multiprocessing/resource_tracker.py:224: UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '

PyMC version information:

PyMC 5.3.0 PyTensor 2.11.1

Context for the issue:

The numpyro sampler is currently unusable for moderate-sized models due to this issue.

fonnesbeck avatar May 30 '23 12:05 fonnesbeck

Setting postprocessing_chunks to somewhat large values (~10) seems to prevent this, since it appears to be an issue with vmap.
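
For readers unfamiliar with the idea, here is a minimal JAX sketch of what chunked postprocessing could look like. It is not PyMC's implementation: `transform` and `transform_in_chunks` are hypothetical names, and the per-draw function is a stand-in. The point is only that vectorizing one chunk of draws at a time bounds the size of any intermediate arrays that vmap would otherwise build for every draw at once.

```python
import jax
import jax.numpy as jnp


def transform(draw):
    # Hypothetical stand-in for a per-draw transformation (not PyMC code).
    return jnp.exp(draw)


def transform_in_chunks(draws, n_chunks):
    # Split along the draw axis so only one chunk is vectorized at a time.
    chunks = jnp.array_split(draws, n_chunks, axis=0)
    # Each chunk is vmapped separately, then the results are stitched back together.
    return jnp.concatenate(
        [jax.vmap(transform)(chunk) for chunk in chunks], axis=0
    )


# 2000 draws of 4 parameters, processed in 10 chunks of 200 draws each.
draws = jnp.linspace(-1.0, 1.0, 2000 * 4).reshape(2000, 4)
chunked = transform_in_chunks(draws, n_chunks=10)
full = jax.vmap(transform)(draws)
```

Both paths should give the same values; only the batch size seen by each vmap call differs.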

fonnesbeck avatar Jun 07 '23 14:06 fonnesbeck

I think this was solved by switching to scan as the default.

ricardoV94 avatar Nov 25 '23 07:11 ricardoV94

I'm still getting out of memory crashes after sampling even when using v5.10. Is it still possible to set postprocessing_chunks? It seemed to work previously.

fonnesbeck avatar Dec 12 '23 02:12 fonnesbeck

The options are now scan or vmap. scan is the default and is the more memory-conscious of the two: https://github.com/pymc-devs/pymc/blob/c53277bee7cd2bdbbc82e349c420d2926a9d1140/pymc/sampling/jax.py#L188
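
To illustrate the difference in plain JAX (a sketch, not the code at the link above): `jax.vmap` batches a function over all draws at once, so any intermediate inside the function can gain a leading draw dimension, while `jax.lax.map`, which is built on scan, applies the function one draw at a time. The `transform` function below is a hypothetical example chosen because it has a large intermediate.

```python
import jax
import jax.numpy as jnp


def transform(draw):
    # Hypothetical per-draw transformation (not PyMC code).
    outer = jnp.outer(draw, draw)  # (n_params, n_params) intermediate
    return outer.sum(axis=1)


draws = jnp.ones((500, 50))  # 500 draws of 50 parameters

# vmap: the intermediate can be materialized as (500, 50, 50).
vectorized = jax.vmap(transform)(draws)

# lax.map: scan-based, so the intermediate stays (50, 50) per step.
sequential = jax.lax.map(transform, draws)
```

The two results agree; the trade-off is that the sequential version is usually slower but needs less peak memory.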

ricardoV94 avatar Dec 12 '23 08:12 ricardoV94

Yeah, I saw that. I still get crashes post-processing on GPU for large models (even with postprocessing_backend="cpu").

fonnesbeck avatar Dec 12 '23 13:12 fonnesbeck

This looks like it might help, though it is not implemented in Jax yet. We should probably keep the option for using xmap in the interim.

fonnesbeck avatar Dec 12 '23 13:12 fonnesbeck

We are already using scan by default, so I don't think it would help.

ricardoV94 avatar Dec 12 '23 13:12 ricardoV94

I'm running into the same OOM issue in post-processing with the default postprocessing_vectorize="scan". Could postprocessing_chunks be brought back as an experimental, use-at-your-own-risk parameter?

JasonTam avatar Jan 22 '24 01:01 JasonTam

IIRC postprocessing_chunks is just using scan under the hood anyway, so it shouldn't help. Can you check it actually helps in your case?

We need an example to investigate this issue, but if you see a difference we can consider temporarily reverting while we figure it out.

ricardoV94 avatar Jan 22 '24 01:01 ricardoV94