pymc
BUG: kernel crashes on GPU during transformation stage
Describe the issue:
I'm repeatedly seeing kernel crashes at the end of sampling, during the variable transformation stage, for large models. This happens even though I pass the postprocessing_backend="cpu" option to pm.sample. In the past I could work around this by specifying postprocessing_chunks, but that option is no longer available. I've kept the number of deterministics to a minimum, but some are unavoidable. If postprocessing were really running on the CPU, I would not expect this crash, which makes me suspect that the CPU is not actually being used.
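For reference, the chunked-postprocessing workaround mentioned above can be sketched in plain JAX. This is a minimal sketch, not PyMC's internal implementation: `transform_in_chunks`, `transform_fn`, and `chunk_size` are hypothetical names, and it assumes all draws are stacked along the first axis of a single array. Each chunk is placed on the CPU with `jax.device_put`, which commits the computation on that chunk to the CPU device, so only one chunk's worth of transformed values is materialized at a time.

```python
import numpy as np
import jax


def transform_in_chunks(transform_fn, samples, chunk_size):
    """Apply transform_fn to every draw (axis 0), chunk_size draws at a time.

    Hypothetical helper: inputs are committed to the CPU device, so the
    jitted computation runs there, and each chunk's result is copied back
    to host memory (NumPy) before the next chunk is processed.
    """
    cpu = jax.devices("cpu")[0]
    # vmap maps transform_fn over the leading (draw) axis of each chunk.
    batched = jax.jit(jax.vmap(transform_fn))
    outputs = []
    for start in range(0, samples.shape[0], chunk_size):
        # Committing the chunk to the CPU makes the jitted call run on the CPU.
        chunk = jax.device_put(samples[start:start + chunk_size], cpu)
        outputs.append(np.asarray(batched(chunk)))
    return np.concatenate(outputs, axis=0)
```

Here `transform_fn` stands in for whatever maps one unconstrained draw to its constrained values and deterministics (for a log-transformed positive variable it could be `jax.numpy.exp`). Smaller chunks lower peak memory at the cost of more dispatches, plus one extra trace when the final chunk is shorter.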
Reproducible code example:
(It is difficult to produce a reproducible example, since the crash depends more on the data and platform than on the model, but I can provide a model on request.)
Error message:
Transforming variables...
INFO:pymc.sampling.jax:Transforming variables...
The Kernel crashed while executing code in the the current cell or a previous cell. Please review the code in the cell(s) to identify a possible cause of the failure. Click here for more info. View Jupyter log for further details.
PyMC version information:
PyMC: 5.10.3 PyTensor: 2.18.4
Context for the issue:
No response