UltraNest
Vectorised sampling and memory consumption
- UltraNest version: 3.6.4
- Python version: 3.10.13
- Operating System: macOS 14.1.1
- JAX version: 0.4.13
Description
I am using UltraNest to sample from a posterior that is essentially a multidimensional Gaussian (~20-30 parameters), defined in JAX and compiled with jax.jit. The code works well without vectorisation. With vectorisation enabled, sampling is much faster, but memory consumption rises sharply: around 14 GB with 20 parameters. When I increase the dimensionality to 30 parameters, memory consumption becomes excessive (around 60 GB).
None of this happens with vectorized=False.
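For anyone trying to reproduce this, here is a minimal sketch of a likelihood and prior transform in the batched convention UltraNest uses with vectorized=True: each function receives an array of shape (n, d) and returns one result per row. The mean `mu`, covariance `cov`, dimension `d` and the [-10, 10] prior box are hypothetical placeholders, not the model from this report. The sketch uses NumPy; the same body should also work with jax.numpy under jax.jit.

```python
import numpy as np

# Hypothetical setup for illustration: a d-dimensional Gaussian likelihood
# with mean `mu` and covariance `cov`, and a uniform prior on [-10, 10]^d.
d = 20
rng = np.random.default_rng(0)
mu = rng.normal(size=d)
A = rng.normal(size=(d, d))
cov = A @ A.T + d * np.eye(d)      # symmetric positive definite by construction
prec = np.linalg.inv(cov)          # precision matrix
_, logdet = np.linalg.slogdet(cov)
log_norm = -0.5 * (d * np.log(2.0 * np.pi) + logdet)


def log_likelihood_vectorised(theta):
    """Batched Gaussian log-density: theta has shape (n, d), result has shape (n,)."""
    r = theta - mu                 # residuals, shape (n, d)
    return log_norm - 0.5 * np.einsum("ni,ij,nj->n", r, prec, r)


def prior_transform_vectorised(cube):
    """Map unit-cube samples of shape (n, d) to the box [-10, 10]^d."""
    return 20.0 * cube - 10.0
```

With vectorized=False, UltraNest calls the functions one point at a time with shape (d,), so a jitted function only ever sees a single input shape. With vectorized=True, the number of rows n can differ from one call to the next.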
What I Did
I tried limiting the maximum number of points drawn at once with ndraw_max=500. This mitigates the problem, but memory consumption still grows with the number of likelihood calls (reaching around 30 GB).
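```python
import ultranest

# `parameters` (list of parameter names), `log_likelihood_vectorised` and
# `weight_minimization_prior` (prior transform) are defined elsewhere in my script.
sampler = ultranest.ReactiveNestedSampler(
    parameters,
    log_likelihood_vectorised,
    weight_minimization_prior,
    vectorized=True,  # functions receive arrays of shape (n, d)
    ndraw_max=500,    # cap on the number of points proposed simultaneously
)
```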
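One possible explanation, not confirmed in this report: jax.jit compiles and caches a separate version of a function for every distinct input shape, so if the batch size varies between calls the cache keeps growing with the number of likelihood calls. That would also fit the observation that lowering ndraw_max helps. Under that assumption, a minimal sketch of a workaround is to wrap the jitted likelihood so it only ever sees one fixed shape, splitting large batches into chunks and padding the last one. `fixed_shape_wrapper` and `chunk_size` are hypothetical names, not part of UltraNest or JAX; the bookkeeping is done in NumPy, and the wrapper records the shapes it forwards so the effect can be checked.

```python
import numpy as np


def fixed_shape_wrapper(f, chunk_size=512):
    """Wrap a batched log-likelihood f so it is only called on arrays of
    shape (chunk_size, d). Returns a function mapping (n, d) -> (n,).

    Assumption: if f is jax.jit-compiled, a single input shape means a
    single compiled version instead of one per batch size.
    """
    seen_shapes = set()  # shapes actually passed to f, kept for inspection

    def wrapped(theta):
        theta = np.asarray(theta)
        n = theta.shape[0]
        out = np.empty(n)
        for start in range(0, n, chunk_size):
            block = theta[start:start + chunk_size]
            m = block.shape[0]
            if m < chunk_size:
                # Pad the final block by repeating its first row; the extra
                # results are computed but discarded below.
                pad = np.repeat(block[:1], chunk_size - m, axis=0)
                block = np.concatenate([block, pad], axis=0)
            seen_shapes.add(block.shape)
            out[start:start + m] = np.asarray(f(block))[:m]
        return out

    wrapped.seen_shapes = seen_shapes
    return wrapped
```

The wrapped function would be passed to ReactiveNestedSampler in place of the jitted likelihood; a jitted prior transform would need analogous treatment (this sketch only handles one scalar output per row). The trade-off is up to chunk_size - 1 wasted evaluations per call, so a chunk size comparable to the typical batch size keeps the overhead small.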