
Vectorised sampling and memory consumption

Status: Open. LucaMantani opened this issue 7 months ago (4 comments).

  • UltraNest version: 3.6.4
  • Python version: 3.10.13
  • Operating system: macOS 14.1.1
  • JAX version: 0.4.13

Description

I am using UltraNest to sample from a posterior that is essentially a multidimensional Gaussian (about 20 to 30 parameters), defined in JAX and compiled with `jax.jit`. The code works well without vectorisation. When I turn vectorisation on, sampling becomes much faster, but memory consumption rises sharply: around 14 GB with 20 parameters. When I increase the dimensionality of the problem to 30 parameters, memory consumption becomes excessive (around 60 GB).
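One possible contributor, not confirmed in this report: `jax.jit` traces, compiles and caches a separate executable for every distinct input shape, and a vectorised sampler may request batches of many different sizes. The sketch below counts compilations using a Python side effect, which runs only while JAX traces the function, not on cached calls. The batch sizes and the stand-in Gaussian are hypothetical.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented only when JAX traces (and compiles) a new version


@jax.jit
def loglike(theta):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, skipped on cached calls
    # Stand-in isotropic Gaussian log-likelihood (up to a constant) for a batch (n, d).
    return -0.5 * jnp.sum(theta ** 2, axis=1)


# Hypothetical batch sizes of the kind a vectorised sampler might request.
for n in [400, 128, 37, 500, 128, 37]:
    loglike(jnp.ones((n, 30)))

print(trace_count)  # 4: one compilation per distinct shape (400, 128, 37, 500)
```

If batch sizes take many different values over a run, the number of cached executables, and the memory they occupy, would grow with the number of likelihood calls, which would be consistent with the behaviour described here.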

None of this happens with `vectorized=False`.

What I Did

I tried capping the number of points proposed per batch with `ndraw_max=500`. This tames the issue somewhat, but memory consumption still grows with the number of likelihood calls (reaching around 30 GB).
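For scale, a quick estimate, assuming the sampler hands the likelihood float64 NumPy arrays, shows that one batch at the cap is tiny. The batch arrays themselves therefore cannot account for tens of gigabytes, which suggests that something accumulates across calls.

```python
ndraw_max = 500      # batch cap from the report
ndim = 30            # larger problem size from the report
bytes_per_value = 8  # float64

batch_bytes = ndraw_max * ndim * bytes_per_value
print(batch_bytes)  # 120000, i.e. about 0.12 MB per batch

# Number of such batches that would have to be retained to reach 30 GB.
print(int(30e9 // batch_bytes))  # 250000
```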

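```python
import ultranest

# parameters: list of parameter names
# log_likelihood_vectorised: jax.jit-compiled log-likelihood taking an (n, d) array
# weight_minimization_prior: prior transform from the unit cube to parameter space
sampler = ultranest.ReactiveNestedSampler(
    parameters,
    log_likelihood_vectorised,
    weight_minimization_prior,
    vectorized=True,
    ndraw_max=500,
)
```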

Posted by LucaMantani on Nov 23 '23, 16:11.