
Vectorised sampling and memory consumption

Open LucaMantani opened this issue 2 years ago • 4 comments

  • UltraNest version: 3.6.4
  • Python version: 3.10.13
  • Operating System: MacOS 14.1.1
  • Jax version: 0.4.13

Description

I am using UltraNest to sample from a posterior that is essentially a multidimensional Gaussian (~20-30 parameters), defined in JAX and jax.jit-compiled. The code works well without vectorisation. When I turn vectorisation on, sampling gets much faster, but I observe a rather severe increase in memory consumption: around 14 GB with 20 parameters. When I increase the dimensionality of the problem to 30 parameters, the memory consumption becomes excessive (around 60 GB).

None of this happens with vectorized=False.
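For concreteness, a minimal sketch of the kind of likelihood described (the mean and inverse covariance here are hypothetical placeholders, not the actual model):

import jax
import jax.numpy as jnp

ndim = 20
mu = jnp.zeros(ndim)       # hypothetical posterior mean
inv_cov = jnp.eye(ndim)    # hypothetical inverse covariance

@jax.jit
def log_likelihood_vectorised(theta):
    # theta has shape (n_points, ndim) when vectorized=True
    d = theta - mu
    return -0.5 * jnp.einsum('ij,jk,ik->i', d, inv_cov, d)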

What I Did

I tried setting a maximum number of draws with ndraw_max=500. This tames the issue, but memory consumption still grows with the number of likelihood calls (up to around 30 GB).

import ultranest

sampler = ultranest.ReactiveNestedSampler(
    parameters,
    log_likelihood_vectorised,
    weight_minimization_prior,
    vectorized=True,
    ndraw_max=500,
)

LucaMantani commented Nov 23 '23

@JohannesBuchner I observed the same problem. Running with vectorized=True on a multicore machine led to excessive memory use. In my case, a fit with 30 parameters reached over 200 GB of memory consumption.

comane commented Feb 02 '24

Isn't that your likelihood, though? You can test with something like:

import numpy as np

# ndim, prior_transform and log_likelihood are from your setup
ndraw_max = 500
for i in range(100):
    us = np.random.uniform(size=(ndraw_max, ndim))
    ps = prior_transform(us)
    Ls = log_likelihood(ps)

(from here)
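To watch the process memory while such a loop runs, one option is the third-party psutil package (a sketch, not part of the original suggestion):

import os
import psutil

# print the resident set size of the current process in GB
rss = psutil.Process(os.getpid()).memory_info().rss
print(rss / 1024**3, 'GB')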

JohannesBuchner commented Feb 04 '24

I ran the tests:

ndraw_max = 500
us = np.random.uniform(size=(ndraw_max, 15))

%timeit prior_transform(us)

6.24 µs ± 21.9 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

ndraw_max = 500
us = np.random.uniform(size=(ndraw_max, 15))

p = prior_transform(us)

%timeit log_likelihood(p)

41.4 ms ± 656 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Then I ran the suggested loop:

ndraw_max = 500
for i in range(100000):
    us = np.random.uniform(size=(ndraw_max, 15))
    ps = prior_transform(us)
    Ls = log_likelihood(ps)

Here I do not see a problem with memory: the Python process needs ~1 GB, and memory usage does not increase over time, contrary to what I observe when I run UltraNest.

LucaMantani commented Feb 05 '24

However, I suspect that running a vectorized Gaussian likelihood with UltraNest would not show this extreme memory usage.

So maybe there is a memory leak somewhere. You may need to use some Python memory-tracing tools.
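For example, a rough way to locate where allocations accumulate (a sketch using the standard-library tracemalloc module; note it only sees allocations made through Python's allocator, so JAX device buffers may not appear):

import tracemalloc

tracemalloc.start()

# ... run the sampler or the likelihood loop here ...

# show the source lines responsible for the most live memory
snapshot = tracemalloc.take_snapshot()
for stat in snapshot.statistics('lineno')[:10]:
    print(stat)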

JohannesBuchner commented Feb 05 '24

Please reopen if you can reproduce this issue with a non-JAX toy likelihood function.
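Such a toy might look like this (a minimal sketch with a vectorised standard-normal likelihood; the setup is illustrative, not a confirmed reproducer):

import numpy as np
import ultranest

ndim = 30
parameters = ['p%d' % i for i in range(ndim)]

def prior_transform(cube):
    # map the unit hypercube to [-5, 5] in each dimension
    return cube * 10 - 5

def log_likelihood(theta):
    # vectorised standard-normal log-likelihood; theta has shape (n, ndim)
    return -0.5 * np.sum(theta**2, axis=1)

sampler = ultranest.ReactiveNestedSampler(
    parameters, log_likelihood, prior_transform,
    vectorized=True, ndraw_max=500)
results = sampler.run()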

This page suggests you can use ulimit or prlimit to limit the memory allowance of a program.
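From within Python, a rough equivalent on Linux is the standard-library resource module (a sketch; the 4 GB cap is an arbitrary example, and enforcement of RLIMIT_AS on macOS is limited):

import resource

# cap the address space of this process at 4 GB (soft and hard limits)
limit = 4 * 1024**3
resource.setrlimit(resource.RLIMIT_AS, (limit, limit))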

JohannesBuchner commented May 21 '24