
Adding a new MCMC method

Open reubenharry opened this issue 1 year ago • 5 comments

I'm a collaborator on this project https://github.com/JakobRobnik/MicroCanonicalHMC, and we're interested in either adding our algorithm to NumPyro, or using NumPyro in our codebase. With that in mind, we have a couple of questions:

1. The simplest thing we'd like to do is to write a probabilistic program in NumPyro like

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def rosenbrock(d, Q):
    x = numpyro.sample("x", dist.Normal(jnp.ones(d // 2), jnp.ones(d // 2)))
    numpyro.sample("y", dist.Normal(jnp.square(x), jnp.sqrt(Q) * jnp.ones(d // 2)))
```

and then extract the density function $f : \mathbb{R}^d \to \mathbb{R}$. We want $f$ explicitly because it's what we need to pass to our code to run our inference algorithm. However, we had some difficulty extracting it from NumPyro in a simple way, and I'm currently doing something a bit hacky, like:

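```python
from numpyro.handlers import condition, seed, trace
# Run the model once to record the names of its sample sites
init_model_trace = trace(seed(model, rng_seed=0)).get_trace()
site_names = [name for name, site in init_model_trace.items() if site["type"] == "sample"]

def potential_fn(arr):
    # Condition each sample site on the matching entry of `arr` and re-run the model
    tr = trace(condition(model, dict(zip(site_names, arr)))).get_trace()
    # Potential energy = negative joint log density (summed over vector-valued sites)
    return -sum(tr[name]["fn"].log_prob(tr[name]["value"]).sum() for name in site_names)
```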

Is there a simpler and better way?

2. We'd potentially be interested in adding our algorithm to NumPyro as a kernel alongside NUTS and HMC. Would that be of interest, and if so, do you have any guidance? The code for the HMC kernel looks quite complex, but perhaps there's a simpler example somewhere to follow.
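For the Rosenbrock model in point 1 as written (assuming $d$ is even and neither site is observed), the latent vector is $(x, y) \in \mathbb{R}^d$. The potential energy, i.e. the negative log joint density that a function like potential_fn is meant to return, follows directly from the two Normal sites. The second site has standard deviation $\sqrt{Q}$, so its variance is $Q$:

$$
U(x, y) = -\log p(x, y) = \sum_{i=1}^{d/2}\left[\frac{(x_i - 1)^2}{2} + \frac{(y_i - x_i^2)^2}{2Q}\right] + \frac{d}{4}\log Q + \frac{d}{2}\log(2\pi)
$$

Only the bracketed terms depend on $(x, y)$. The constant matters only when the normalized density is needed.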

Thanks!

Reuben

reubenharry • Oct 12 '23 18:10

hi @reubenharry -- yes it'd be great to add your mcmc method to numpyro and we can certainly help guide you along the process of getting a PR merged.

i see at least two options: i) you could implement a self-contained kernel that introduces no new dependencies; ii) you could implement a kernel that mostly consists of boilerplate and hands off the core of the algorithm to the MicroCanonicalHMC repo.

the benefit of the former is that there are no new dependencies, unit testing is contained within numpyro, etc., but the disadvantage is that any improvements to the algorithm won't filter down to numpyro without an explicit PR that implements the new functionality.

the benefit of the second approach is that numpyro can benefit from any algorithm improvements in the upstream repo, but the disadvantages include introducing a new dependency, the possibility that breaking changes are introduced upstream, and the possibility that maintenance of the upstream repo slows down or ceases entirely.

currently i believe we have one instance of the second path, namely nested sampling, which introduces a dependency on jaxns. @reubenharry do you have a preference? the second path probably only makes sense if you're pretty committed to maintaining the repo and if you foresee the algorithm evolving for the better over time.

@fehiepsi do you have suggestions for point 1?

martinjankowiak • Oct 14 '23 18:10

Thanks for your advice! In the medium term, the first option seems appealing, particularly since we're also working in JAX. The kernel itself is actually quite simple; the difficulty is that there is autotuning code which is a little more involved, and it wasn't immediately obvious to me how much control over it I would have if I went with option one. I also got a little intimidated looking at the HMC code, which has a few layers of abstraction, but I'm sure with guidance it wouldn't be so hard to do something similar :)

Currently we've opted for a simpler third option, which is just to express a program in numpyro, extract its density function, and then use that in our repo. I also did something somewhat like option 1. You can see both here: https://github.com/JakobRobnik/MicroCanonicalHMC/pull/18/files#diff-e81cce67759d32ecde8fc48bb864dd0ac7ecc01286a35ceab268e62e9181c0e3

I'll discuss with the other developers and see what their preferences are - perhaps at some point further down the road we can all chat in person.

reubenharry • Oct 14 '23 18:10

Hi @reubenharry, you can make a new kernel as a subclass of MCMCKernel. To convert a NumPyro model to a potential function, you can use the initialize_model helper. This helper also returns a postprocess function that converts unconstrained values into constrained values, which can be used in the postprocess_fn method of the MCMCKernel. Let us know if something is unclear.

fehiepsi • Oct 14 '23 20:10

Thanks! I'm having one issue:

```python
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer.util import initialize_model

def m():
    mu = numpyro.sample('mu', dist.Normal(3, 1))
    nu = numpyro.sample('nu', dist.Normal(mu + 1, 2))

rng_key = jax.random.PRNGKey(0)
rng_key, init_key = jax.random.split(rng_key)
init_params, potential_fn_gen, *_ = initialize_model(
    init_key,
    m,
    model_args=(),
    dynamic_args=True,
)

print(potential_fn_gen()(jnp.array([4, 2])))
```

```
  File "/opt/homebrew/lib/python3.11/site-packages/numpyro/distributions/continuous.py", line 2035, in sample
    assert is_prng_key(key)
AssertionError
```

reubenharry • Oct 24 '23 09:10

This should work:

```python
print(potential_fn_gen()(init_params.z))
```

Please check what init_params.z looks like.

tare • Oct 24 '23 13:10

please check out https://num.pyro.ai/en/latest/tutorials/other_samplers.html for using numpyro with other samplers

fehiepsi • Aug 10 '24 22:08