

Open dirmeier opened this issue 2 years ago • 11 comments

Hi all, are there any plans to implement Pathfinder (arXiv:2108.03782)? If so, I would be up for implementing it.

Cheers, Simon

dirmeier, Oct 02 '22

Hi Simon, we don't have plans to implement that algorithm. Contributions are welcome!

fehiepsi, Oct 02 '22

@dirmeier BlackJax has implemented Pathfinder, and it's compatible with NumPyro models. Here is a script I use to run BlackJax's Pathfinder on NumPyro models:

import jax
import numpyro
from jax.flatten_util import ravel_pytree
from numpyro.infer.util import initialize_model
import blackjax

def inference_kit_generator(model_func):
    """
    model_func: A NumPyro model function
    """
    key = jax.random.PRNGKey(0)  # The key doesn't matter here
    def _inner(*args, **kwargs):
        init_funcs = initialize_model(key, model_func, model_args=args, model_kwargs=kwargs)
        param_template = init_funcs.param_info.z
        potential_func = init_funcs.potential_fn
        transform_func = init_funcs.postprocess_fn
        flatten_func = lambda p: ravel_pytree(p)[0]
        unflatten_func = ravel_pytree(param_template)[1]
        return {
            'param_template': param_template,
            'potential_func': potential_func,
            'transform_func': transform_func,
            'flatten_func': flatten_func,
            'unflatten_func': unflatten_func,
        }
    return _inner

def pathfinder_vi(model_func, seed, num_samples, *args, **kwargs):
    """
    model_func: A NumPyro model function
    seed: Random seed
    num_samples: Number of samples to generate from the variational posterior
    *args, **kwargs: Input data for the model
    """
    inference_kit = inference_kit_generator(model_func)(*args, **kwargs)
    param_template = inference_kit["param_template"]
    potential_func = jax.jit(inference_kit["potential_func"])
    transform_func = inference_kit["transform_func"]
    flatten_func = inference_kit["flatten_func"]
    unflatten_func = inference_kit["unflatten_func"]

    flattened_params = flatten_func(param_template)
    key = jax.random.PRNGKey(seed)

    def inference_loop(rng_key, kernel, initial_state, num_samples):
        def one_step(state, rng_key):
            state, info = kernel(rng_key, state)
            return state, (state, info)
        keys = jax.random.split(rng_key, num_samples)
        return jax.lax.scan(one_step, initial_state, keys)

    # BlackJax expects a log density, i.e. the negative of NumPyro's potential
    logprob_fn = lambda x: -potential_func(x)
    key, _key = jax.random.split(key)
    w0 = jax.random.normal(_key, (len(flattened_params),))
    w0 = unflatten_func(w0)
    key, _key = jax.random.split(key)
    # BlackJax API at the time of writing; newer releases may expose Pathfinder differently
    pathfinder = blackjax.kernels.pathfinder(_key, logprob_fn, maxiter=10000)
    state = pathfinder.init(w0)
    key, _key = jax.random.split(key)
    _, (_, samples) = inference_loop(_key, pathfinder.step, state, num_samples)
    # Samples live in NumPyro's unconstrained space; transform_func maps a point back
    return samples

I've also thought about implementing Pathfinder in pure NumPyro and JAX, but it seems we would need a more sophisticated version of L-BFGS (jax.scipy.optimize.minimize doesn't seem to be flexible enough).
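To make that limitation concrete, here is a minimal sketch assuming a toy quadratic target (hypothetical, not from the thread). `jax.scipy.optimize.minimize` returns only the final optimizer state, whereas Pathfinder needs the intermediate iterates and gradients of the L-BFGS run to build its sequence of Gaussian approximations.

```python
import jax.numpy as jnp
from jax.scipy.optimize import minimize


def neg_log_density(z):
    # Hypothetical target: an isotropic Gaussian centred at (1, 2)
    return 0.5 * jnp.sum((z - jnp.array([1.0, 2.0])) ** 2)


result = minimize(neg_log_density, jnp.zeros(2), method="BFGS")

# OptimizeResults carries only the final state (x, fun, jac, hess_inv,
# success/status and iteration counters); the per-iteration positions and
# gradients that Pathfinder relies on are not part of the result.
print(result._fields)
print(result.x, result.nit)
```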

xidulu, Dec 07 '22

Hi @xidulu, yeah, I saw that a few days after I wrote the original post and have been using it in the meantime. I still think it would be nice to have this in NumPyro too, and I was planning on doing it over the Christmas holidays.

dirmeier, Dec 07 '22

It is a shame to duplicate efforts between Blackjax and NumPyro when we could at the same time have a better PPL and more samplers 😁

rlouf, Dec 11 '22

@dirmeier Please feel free to work on this if you find a NumPyro implementation easier to use, or defer to BlackJax if you think there is no need to duplicate efforts. Thanks!

@xidulu jaxopt has L-BFGS implemented. We're running some benchmarks on top of it to see whether it is fast (when batched) and stable. I'm not sure if it is flexible enough for Pathfinder.

Good to know that blackjax has this sampler. It is great for the community.

fehiepsi, Dec 16 '22

@xidulu Looking at your script, it seems that the MCMCKernel API can serve as a wrapper for it. There we can

  • convert a model to a potential function
  • perform an update step and scan the samples

I guess this pattern works for many BlackJax samplers. It is similar to how we wrap TFPKernel to use TFP samplers for NumPyro models.

This way we can take advantage of the MCMC class, which lets us get unconstrained samples, deterministic samples, diagnostics, vectorized/parallel chains, a progress bar, thinning, etc. ArviZ is also compatible with the MCMC class, so we can do further exploratory analysis.

fehiepsi, Dec 16 '22

Hi! I'm trying to use BlackJax's Pathfinder with NumPyro following the approach above. This gets me samples in the transformed space; how can I map them back to the original space? Thanks

I have tried

samples_transformed = jax.vmap(transform_func(*args, **kwargs))(samples)

But I got nonsensical values.

juanitorduz, Dec 01 '23

Could you provide reproducible code? The approach above uses the default dynamic_args=False in the call to initialize_model. As mentioned in the initialize_model docs, transform_func then does not take args and kwargs. I wonder why your code still runs.

fehiepsi, Dec 03 '23

Thanks @fehiepsi! Here is an example (sorry if I am doing something nonsensical!)

import arviz as az
import blackjax
import jax
import matplotlib.pyplot as plt
import numpy as np
import numpyro
import numpyro.distributions as dist

from jax import random
from numpyro.infer.util import initialize_model

az.style.use("arviz-darkgrid")
plt.rcParams["figure.figsize"] = [7, 6]
plt.rcParams["figure.dpi"] = 100
plt.rcParams["figure.facecolor"] = "white"


rng_key = random.PRNGKey(seed=0)

n = 100

rng_key, rng_subkey = random.split(rng_key)
x = random.normal(rng_subkey, (n,))

a = 1.0  #  <- true values to recover
b = 2.0  #  <- true values to recover
sigma = 0.5  #  <- true values to recover

rng_key, rng_subkey = random.split(rng_key)
epsilon = sigma * random.normal(rng_subkey, (n,))

y = a + b * x + epsilon

fig, ax = plt.subplots()
ax.plot(x, y, "o")
ax.set(xlabel="x", ylabel="y", title="Raw Data")


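def model(x, y=None):
    # Not included in the original comment; these priors are assumed
    a = numpyro.sample("a", dist.Normal(loc=0.0, scale=2.0))
    b = numpyro.sample("b", dist.Normal(loc=0.0, scale=2.0))
    sigma = numpyro.sample("sigma", dist.HalfNormal(scale=2.0))
    mu = a + b * x
    numpyro.sample("likelihood", dist.Normal(loc=mu, scale=sigma), obs=y)

rng_key, rng_subkey = random.split(rng_key)
param_info, potential_fn, postprocess_fn, *_ = initialize_model(
    rng_subkey, model, model_args=(x, y), dynamic_args=True
)

def logdensity_fn(position):
    func = potential_fn(x, y)
    return -func(position)

initial_position = param_info.z

rng_key, rng_subkey = random.split(rng_key)
pathfinder_state, _ = blackjax.vi.pathfinder.approximate(
    rng_key=rng_subkey,
    logdensity_fn=lambda x: -logdensity_fn(x),
    initial_position=initial_position,
    num_samples=1_000,  # illustrative value; the original was not preserved
)

rng_key, rng_subkey = random.split(rng_key)
posterior_samples, _ = blackjax.vi.pathfinder.sample(
    rng_key=rng_subkey,
    state=pathfinder_state,
    num_samples=5_000,  # illustrative value; the original was not preserved
)

idata_pathfinder = az.from_dict(
    posterior={
        k: np.expand_dims(a=np.asarray(v), axis=0) for k, v in posterior_samples.items()
    }
)

axes = az.plot_trace(
    data=idata_pathfinder,
    backend_kwargs={"figsize": (12, 6), "layout": "constrained"},
)
plt.gcf().suptitle(t="Pathfinder Trace", fontsize=18, fontweight="bold")

posterior_samples_transformed = jax.vmap(postprocess_fn(x, y))(posterior_samples)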

>> {'a': Array([49.598484, 47.983875, 50.057056, ..., 49.849007, 50.13445 ,
        50.864384], dtype=float32),
 'b': Array([2353.3398, 2354.5476, 2352.5825, ..., 2352.4006, 2353.875 ,
        2354.1428], dtype=float32),
 'sigma': Array([inf, inf, inf, ..., inf, inf, inf], dtype=float32)}

What could I be missing?

Thanks ☺️

juanitorduz, Dec 03 '23

So you're using dynamic_args=True, which is different. In that case, postprocess_fn should do such a transform. Did you check your BlackJax results? You may also want to check whether you get correct results with manual transforms. Also check what you provide to BlackJax, e.g. whether logdensity_fn has the correct sign.

fehiepsi, Dec 03 '23


OMG 🤦 , actually

def logdensity_fn(position):
    func = potential_fn(x, y)
    return func(position)

had the wrong sign (the minus sign should not have been there), and now it works as expected:




>> {'a': Array([0.9645677 , 0.88383734, 0.9891541 , ..., 0.98004097, 0.9882706 ,
        1.0226959 ], dtype=float32),
 'b': Array([1.9266534, 1.9878191, 1.8892442, ..., 1.8811033, 1.9484416,
        1.9543549], dtype=float32),
 'sigma': Array([0.49401087, 0.49979746, 0.46184635, ..., 0.44290245, 0.51779646,
        0.49234945], dtype=float32)}

I hope this can help other users :) Thanks!

Is there still interest in having a wrapper based on MCMCKernel?

juanitorduz, Dec 03 '23


@fehiepsi Is this still open? Meaning, wrapping this in the MCMC class? Or is it enough to have an example in the documentation?

juanitorduz, Aug 01 '24

Hi @juanitorduz, I think it would be great to have a tutorial/example showing how to use NumPyro with other libraries via initialize_model. If you think the MCMC class is helpful, you can also show in the example how to define an MCMCKernel for the inference algorithm.

The issues are also related.

fehiepsi, Aug 01 '24

Ok! I can work on the example! I actually wrote something for myself based on this issue. Is this example too easy? Should I do something more complex? I chose an easy one so as to focus on the Pathfinder part rather than get lost in the model. But maybe NumPyro's audience wants a less "trivial" example. I'm easy :)

juanitorduz, Aug 01 '24

I would prefer simple examples to complex ones. :)

fehiepsi, Aug 01 '24


Same here 😀 !

juanitorduz, Aug 01 '24

Please check out for how to use NumPyro with other APIs.

fehiepsi, Aug 10 '24