BUG: JAX-based samplers crash at the transformation stage
Describe the issue:
The JAX-based samplers crash after sampling, right after the "Transforming variables..." message, on medium-to-large models (thousands of rows, hundreds of parameters). This happens on both GPU and CPU systems, with either the numpyro or the blackjax sampler. On GPU, the backtrace points to the vmap in _postprocess_samples. On CPU (MacBook Pro M1), the process is simply killed without any error message. I have tried running the GPU model with the postprocessing_backend="cpu" argument for the numpyro sampler, but it does not seem to make a difference. Should it be using vmap when the postprocessing backend is CPU?
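The question about vmap comes down to how the per-draw transformation is vectorized. Below is a minimal JAX sketch, not PyMC's actual `_postprocess_samples`: it assumes a hypothetical `transform` that maps one unconstrained draw to constrained values plus a per-observation deterministic (`mu`, one value per data row). `jax.vmap` evaluates all draws as a single batched computation, so the intermediates for every draw are in memory at once. `jax.lax.map` is built on `lax.scan` and processes one draw per step. Both produce the same values.

```python
import jax
import jax.numpy as jnp

# Hypothetical design matrix: 1000 data rows, 50 coefficients.
X = jax.random.normal(jax.random.PRNGKey(1), (1000, 50))

def transform(draw):
    """Hypothetical per-draw transform: unconstrained draw -> constrained values."""
    beta = draw[:-1]
    sigma = jnp.exp(draw[-1])  # log-scale parameter mapped back to positive
    mu = X @ beta              # per-observation deterministic, shape (1000,)
    return {"beta": beta, "sigma": sigma, "mu": mu}

# Flattened posterior draws: (chains * draws, n_params) = (2 * 250, 51).
draws = jax.random.normal(jax.random.PRNGKey(0), (2 * 250, 51))

# vmap: one batched computation; all draws' intermediates are materialized together.
out_vmap = jax.jit(jax.vmap(transform))(draws)

# lax.map: a scan over the leading axis; one draw is transformed per step.
out_scan = jax.jit(lambda d: jax.lax.map(transform, d))(draws)
```

Note that the stacked output `mu` still has shape (draws, rows) under either strategy; scan only reduces the size of the intermediates held at any one time.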
Reproducible code example:
Will add an example when I can come up with one.
Error message:
CPU machine error:
Compilation time = 0:00:09.225151
Sampling...
Running chain 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [12:00:33<00:00, 21.62s/it]
Running chain 1: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [12:00:33<00:00, 21.62s/it]
Sampling time = 12:00:35.215191
Transforming variables...
Killed: 9
/Users/cfonnesbeck/mambaforge/envs/pymc/lib/python3.11/multiprocessing/resource_tracker.py:224: UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown
warnings.warn('resource_tracker: There appear to be %d '
PyMC version information:
PyMC 5.3.0 PyTensor 2.11.1
Context for the issue:
The numpyro sampler is currently unusable for moderate-sized models due to this issue.
Setting postprocessing_chunks to a moderately large value (~10) seems to prevent this, which suggests the problem lies with vmap.
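One way to picture why chunking helps is a hybrid of the two strategies: split the draws into `n_chunks` groups, vmap within a group, and scan across groups, so peak intermediate memory scales with draws / n_chunks rather than with all draws. The helper `postprocess_in_chunks` and the toy `transform` below are hypothetical illustrations, not PyMC's implementation, and the sketch assumes the number of draws is divisible by `n_chunks`.

```python
import jax
import jax.numpy as jnp

def postprocess_in_chunks(transform, draws, n_chunks):
    """Hypothetical helper: vmap `transform` inside each chunk, scan across chunks."""
    n_draws = draws.shape[0]
    if n_draws % n_chunks != 0:
        raise ValueError("n_chunks must divide the number of draws")
    chunk_size = n_draws // n_chunks
    # (n_draws, ...) -> (n_chunks, chunk_size, ...)
    chunked = draws.reshape((n_chunks, chunk_size) + draws.shape[1:])
    # lax.map scans over the chunk axis; vmap batches the draws within one chunk.
    out = jax.lax.map(jax.vmap(transform), chunked)
    # Merge the chunk axis back: (n_chunks, chunk_size, ...) -> (n_draws, ...)
    return jax.tree_util.tree_map(
        lambda x: x.reshape((n_draws,) + x.shape[2:]), out
    )

def transform(draw):
    """Hypothetical per-draw transform used only for this illustration."""
    return {"sigma": jnp.exp(draw[-1]), "beta": draw[:-1]}

draws = jax.random.normal(jax.random.PRNGKey(0), (1000, 11))
out = postprocess_in_chunks(transform, draws, n_chunks=10)
```

With `n_chunks=1` this reduces to a plain vmap, and with `n_chunks` equal to the number of draws it reduces to a one-draw-per-step scan.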
I think this was solved by switching to scan as the default.
I'm still getting out-of-memory crashes after sampling, even with v5.10. Is it still possible to set postprocessing_chunks? It seemed to work previously.
The options are now scan or vmap, scan is the default which is more memory conscious: https://github.com/pymc-devs/pymc/blob/c53277bee7cd2bdbbc82e349c420d2926a9d1140/pymc/sampling/jax.py#L188
Yeah, I saw that. I still get crashes during postprocessing on GPU for large models (even with postprocessing_backend="cpu").
This looks like it might help, though it is not implemented in JAX yet. We should probably keep the option of using xmap in the interim.
We are already using scan by default, so I don't think it would help.
I'm running into the same OOM issue in postprocessing with the default postprocessing_vectorize="scan".
Could postprocessing_chunks be brought back as an experimental, use-at-your-own-risk parameter?
IIRC postprocessing_chunks just uses scan under the hood anyway, so it shouldn't help. Can you check whether it actually helps in your case?
We need an example to investigate this issue, but if you see a difference, we can consider temporarily reverting while we figure it out.