numpyro
Pathfinder
Hi all, are there any plans to implement Pathfinder (arXiv:2108.03782)? If so, I would be up for implementing it.
Cheers, Simon
Hi Simon, we don't have plans to implement that algorithm. You're welcome to contribute!
@dirmeier BlackJax (https://github.com/blackjax-devs/blackjax) has implemented Pathfinder and it's compatible with NumPyro models. I have this script that runs BlackJax's Pathfinder on NumPyro models:
import jax
import numpyro
from jax.flatten_util import ravel_pytree
from numpyro.infer.util import initialize_model
import blackjax


def inference_kit_generator(model_func):
    '''
    model_func: A NumPyro model function
    '''
    key = jax.random.PRNGKey(0)  # The key doesn't matter here

    def _inner(*args, **kwargs):
        init_funcs = initialize_model(
            key,
            model_func,
            model_args=args,
            model_kwargs=kwargs
        )
        param_template = init_funcs.param_info.z
        potential_func = init_funcs.potential_fn
        transform_func = init_funcs.postprocess_fn
        # Flatten a parameter pytree into a 1-D array, and build the inverse map
        flatten_func = lambda p: ravel_pytree(p)[0]
        unflatten_func = ravel_pytree(param_template)[1]
        return {
            'param_template': param_template,
            'potential_func': potential_func,
            'transform_func': transform_func,
            'flatten_func': flatten_func,
            'unflatten_func': unflatten_func
        }
    return _inner


def pathfinder_vi(model_func, seed, num_samples, *args, **kwargs):
    """
    model_func: A NumPyro model function
    seed: Random seed
    num_samples: Number of samples to generate from the variational posterior
    *args, **kwargs: Input data for the model
    """
    inference_kit = inference_kit_generator(model_func)(*args, **kwargs)
    param_template = inference_kit["param_template"]
    potential_func = jax.jit(inference_kit["potential_func"])
    transform_func = inference_kit["transform_func"]
    flatten_func = inference_kit["flatten_func"]
    unflatten_func = inference_kit["unflatten_func"]
    flattened_params = flatten_func(param_template)
    key = jax.random.PRNGKey(seed)

    def inference_loop(rng_key, kernel, initial_state, num_samples):
        @jax.jit
        def one_step(state, rng_key):
            state, info = kernel(rng_key, state)
            return state, (state, info)
        keys = jax.random.split(rng_key, num_samples)
        return jax.lax.scan(one_step, initial_state, keys)

    # The potential is the negative log density, so flip the sign
    logprob_fn = lambda x: -potential_func(x)
    key, _key = jax.random.split(key)
    # Random starting point in the unconstrained space
    w0 = jax.random.normal(_key, (len(flattened_params),))
    w0 = unflatten_func(w0)
    key, _key = jax.random.split(key)
    # NOTE: blackjax.kernels.pathfinder is the API this script was written
    # against; later posts in this thread use blackjax.vi.pathfinder instead.
    pathfinder = blackjax.kernels.pathfinder(_key, logprob_fn, maxiter=10000)
    state = pathfinder.init(w0)
    key, _key = jax.random.split(key)
    _, (_, samples) = inference_loop(_key, pathfinder.step, state, num_samples)
    return samples
I've also thought about implementing Pathfinder in pure NumPyro and JAX, but it seems that we would need a more sophisticated version of L-BFGS (jax.scipy.optimize.minimize doesn't seem to be flexible enough).
Hi @xidulu, yeah I saw that some days after I wrote the original post and have been using it in the meantime. I still think it would be nice to have this in NumPyro, too, and I was planning on doing it over the Christmas holidays.
It is a shame to duplicate efforts between Blackjax and Numpyro when we could at the same time have a better PPL and more samplers 😁
@dirmeier Please feel free to work on this if you find it easier to use with a numpyro implementation, or defer to blackjax if you find there is no need to duplicate efforts. Thanks!
@xidulu jaxopt has LBFGS implemented. We're trying to do some benchmarks on top of it to see if it's fast (upon batching) and stable. I'm not sure if it is flexible enough for pathfinder.
Good to know that blackjax has this sampler. It is great for the community.
@xidulu Looking at your script, it seems that the MCMCKernel API can serve as a wrapper for it. There we can
- convert a model to a potential function
- perform an update step and scan the samples

I guess this pattern works for many blackjax samplers. This is similar to how we wrap TFPKernel to use TFP samplers for numpyro models.
This way we can benefit from the MCMC class, which allows us to get unconstrained samples, deterministic samples, diagnostics, vectorized/parallel chains, progress bar, thinning, and so on. ArviZ is also compatible with the MCMC class, so we can do further exploratory analysis.
Hi! I'm trying to use blackjax's pathfinder method with numpyro using the script above. This gets me the samples in the transformed (unconstrained) space; how can I get them back to the original space? Thanks
I have tried
samples_transformed = jax.vmap(transform_func(*args, **kwargs))(samples)
but got nonsensical values.
Could you provide reproducible code? The above approach uses the default dynamic_args=False in the call to initialize_model. As mentioned in the initialize_model docs, transform_func will then not take args, kwargs. I wonder why your code still runs.
Thanks @fehiepsi! Here is an example (sorry if I am doing something nonsensical!)
import arviz as az
import blackjax
import jax
import matplotlib.pyplot as plt
import numpy as np
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer.util import initialize_model
az.style.use("arviz-darkgrid")
plt.rcParams["figure.figsize"] = [7, 6]
plt.rcParams["figure.dpi"] = 100
plt.rcParams["figure.facecolor"] = "white"
numpyro.set_host_device_count(n=4)
rng_key = random.PRNGKey(seed=0)
n = 100
rng_key, rng_subkey = random.split(rng_key)
x = random.normal(rng_subkey, (n,))
a = 1.0 # <- true values to recover
b = 2.0 # <- true values to recover
sigma = 0.5 # <- true values to recover
rng_key, rng_subkey = random.split(rng_key)
epsilon = sigma * random.normal(rng_subkey, (n,))
y = a + b * x + epsilon
fig, ax = plt.subplots()
ax.plot(x, y, "o")
ax.set(xlabel="x", ylabel="y", title="Raw Data")
rng_key, rng_subkey = random.split(rng_key)
param_info, potential_fn, postprocess_fn, *_ = initialize_model(
    rng_subkey,
    model,
    model_args=(x, y),
    dynamic_args=True,
)
def logdensity_fn(position):
    func = potential_fn(x, y)
    return -func(position)
initial_position = param_info.z
rng_key, rng_subkey = random.split(rng_key)
pathfinder_state, _ = blackjax.vi.pathfinder.approximate(
    rng_key=rng_subkey,
    logdensity_fn=lambda x: -logdensity_fn(x),
    initial_position=initial_position,
    ftol=1e-4,
)
rng_key, rng_subkey = random.split(rng_key)
posterior_samples, _ = blackjax.vi.pathfinder.sample(
    rng_key=rng_subkey,
    state=pathfinder_state,
    num_samples=4_000,
)
idata_pathfinder = az.from_dict(
    posterior={
        k: np.expand_dims(a=np.asarray(v), axis=0) for k, v in posterior_samples.items()
    },
)
axes = az.plot_trace(
    data=idata_pathfinder,
    compact=True,
    backend_kwargs={"figsize": (12, 6), "layout": "constrained"},
)
plt.gcf().suptitle(t="Pathfinder Trace", fontsize=18, fontweight="bold")
Also,
posterior_samples_transformed = jax.vmap(postprocess_fn(x, y))(posterior_samples)
>> {'a': Array([49.598484, 47.983875, 50.057056, ..., 49.849007, 50.13445 ,
50.864384], dtype=float32),
'b': Array([2353.3398, 2354.5476, 2352.5825, ..., 2352.4006, 2353.875 ,
2354.1428], dtype=float32),
'sigma': Array([inf, inf, inf, ..., inf, inf, inf], dtype=float32)}
What could I be missing?
Thanks ☺️
So you're using dynamic_args=True, which is different. Then postprocess_fn should do such a transform. Did you check your blackjax results? You may also want to check whether you get correct results with manual transforms. Also check what you provide to blackjax, e.g. whether logdensity_fn has the correct sign.
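A quick way to see what a flipped sign does, sketched in plain Python on a toy target (a standard normal, not the regression model from the post). NumPyro's potential function is the negative log density, so the log density handed to an optimizer should be `-potential`. All names below are made up for illustration.

```python
import math


def potential_fn(z):
    # Negative log density of a standard normal.
    return 0.5 * z * z + 0.5 * math.log(2.0 * math.pi)


def logdensity_fn(z):
    # Correct sign: log density = -potential.
    return -potential_fn(z)


def gradient_ascent(fn, z0, step=0.1, num_steps=100, eps=1e-5):
    # Plain gradient ascent with a central finite-difference gradient.
    z = z0
    for _ in range(num_steps):
        grad = (fn(z + eps) - fn(z - eps)) / (2.0 * eps)
        z = z + step * grad
    return z


# Correct sign: the iterate moves toward the mode at 0.
z_good = gradient_ascent(logdensity_fn, 1.0)

# Double negation (the potential passed as if it were a log density):
# the iterate runs away from the mode instead.
z_bad = gradient_ascent(lambda z: -logdensity_fn(z), 1.0)
```

With the wrong sign the optimizer climbs away from the high-density region, which is consistent with the huge values and `inf` seen for `b` and `sigma` in the output above.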
OMG 🤦, actually
def logdensity_fn(position):
    func = potential_fn(x, y)
    return func(position)
had the wrong sign (the minus sign was wrong), and now it is working as expected:
posterior_samples_transformed
>> {'a': Array([0.9645677 , 0.88383734, 0.9891541 , ..., 0.98004097, 0.9882706 ,
1.0226959 ], dtype=float32),
'b': Array([1.9266534, 1.9878191, 1.8892442, ..., 1.8811033, 1.9484416,
1.9543549], dtype=float32),
'sigma': Array([0.49401087, 0.49979746, 0.46184635, ..., 0.44290245, 0.51779646,
0.49234945], dtype=float32)}
I hope this can help other users :) Thanks!
Is there still interest in having a wrapper around MCMCKernel?
@fehiepsi Is this still open? Meaning, wrapping this in the MCMC class? Or is it enough to have an example in the documentation?
Hi @juanitorduz, I think it would be great to have a tutorial / example showing how to use numpyro with other libraries via initialize_model. If you think that the MCMC class is helpful, you can also show in the example how to define an MCMCKernel for the inference algorithm.
The issues https://github.com/pyro-ppl/numpyro/issues/1734 https://github.com/pyro-ppl/numpyro/issues/1662 https://github.com/pyro-ppl/numpyro/issues/950 are also related.
Ok! I can work on the example! I actually wrote something for myself based on this issue: https://juanitorduz.github.io/numpyro_pathfinder/. Is this example too easy? Shall I do something more complex? I did an easy one so as not to get lost in the model and to focus on the Pathfinder part. But maybe NumPyro’s audience wants a less “trivial” example. I'm easy :)
I would prefer simple examples to complex ones. :)
Same here 😀 !
Please check out https://num.pyro.ai/en/latest/tutorials/other_samplers.html for how to use numpyro with other APIs.