numpyro
Customize MCMC algorithm
Hi all, thanks for your powerful library, and sorry for the beginner question. I have implemented a Metropolis-Hastings algorithm following this link, and I need to customize the score calculation. My code:
from collections import namedtuple
import os, time
import jax
import copy
from jax import random
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, BarkerMH
# Immutable per-step chain state: the current sample `u`, its `score`
# (as returned by potential_fn), and the PRNG key used for the next step.
MHState = namedtuple("MHState", "u score rng_key")
class MetropolisHastings(numpyro.infer.mcmc.MCMCKernel):
    """Metropolis-Hastings kernel driven by a user-supplied proposal.

    The chain targets the unnormalized density ``exp(-beta * score)``, where
    ``score = potential_fn(u, **potential_fn_kwargs)``; lower scores are more
    probable. ``translation_fn`` must be a *symmetric* proposal, so the
    Hastings correction cancels and a move is accepted with probability
    ``min(1, exp(beta * (score - score_new)))``.

    Per-run keyword arguments come from ``mcmc.run(...)``:
    ``potential_fn_kwargs`` and ``translation_fn_kwargs`` (each a dict;
    omitting one means "no extra arguments").

    Args:
        potential_fn: callable ``potential_fn(u, **potential_fn_kwargs)``
            returning a scalar score.
        translation_fn: callable
            ``translation_fn(u, rng_key, step_size, **translation_fn_kwargs)``
            returning a proposal with the same (pytree) structure as ``u``.
        beta: inverse temperature multiplying the score difference.
        step_size: forwarded unchanged to ``translation_fn``.
    """

    # MCMC collects this state field as the chain's samples.
    sample_field = "u"

    def __init__(self, potential_fn, translation_fn, beta=100, step_size=0.1):
        self.potential_fn = potential_fn
        self.translation_fn = translation_fn
        self.step_size = step_size
        self.beta = beta

    @staticmethod
    def _fn_kwargs(model_kwargs, name):
        """Return the dict stored under ``name`` in ``model_kwargs``, or ``{}``."""
        return (model_kwargs or {}).get(name) or {}

    def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
        """Score ``init_params`` and build the initial ``MHState``.

        ``num_warmup`` and ``model_args`` are unused: this kernel performs no
        adaptation.
        """
        u = init_params
        score = self.potential_fn(u, **self._fn_kwargs(model_kwargs, "potential_fn_kwargs"))
        return MHState(u, score, rng_key)

    def sample(self, state, model_args, model_kwargs):
        """Run one propose/accept step and return the next ``MHState``."""
        u, score, rng_key = state
        rng_key, key_accept, key_translation = random.split(rng_key, 3)
        u_proposal = self.translation_fn(
            u,
            key_translation,
            self.step_size,
            **self._fn_kwargs(model_kwargs, "translation_fn_kwargs"),
        )
        score_proposal = self.potential_fn(
            u_proposal, **self._fn_kwargs(model_kwargs, "potential_fn_kwargs")
        )
        # Accept iff U < exp(beta * (score - score_proposal)) with U ~ Uniform(0, 1).
        # The comparison is done in log space so a large beta cannot overflow
        # exp(). A NaN score makes the comparison False, so the move is rejected.
        log_u = jnp.log(random.uniform(key_accept))
        accept = log_u < self.beta * (score - score_proposal)
        # tree_map lets `u` be any pytree (e.g. a dict of arrays), not only a
        # single array; jnp.where applied directly to a dict would fail.
        u_new = jax.tree_util.tree_map(
            lambda new, old: jnp.where(accept, new, old), u_proposal, u
        )
        score_new = jnp.where(accept, score_proposal, score)
        return MHState(u_new, score_new, rng_key)
def potential_fn(params, constraints, verbose, print_func=None):
    """Return the scalar score of ``params``; lower is better.

    Placeholder: the scoring logic is elided in this snippet, and ``score``
    is not defined anywhere visible here. Replace the body with the real
    computation.

    Args:
        params: the parameters being scored (the chain's current sample).
        constraints: constraints on the params, used to compute the score.
        verbose: whether to print debug information.
        print_func: optional print callable; presumably used when
            ``verbose`` is true (TODO confirm).

    Returns:
        A scalar score. The kernel favours lower values: it accepts a move
        with probability ``min(1, exp(beta * (score - score_new)))``.
    """
    # do something to calc the score
    return score
def translation_fn(params, key, step_size):
    """Propose new parameters by perturbing ``params`` (placeholder).

    The proposal must be symmetric (q(a -> b) == q(b -> a)), because the
    kernel's acceptance rule omits the Hastings correction. The proposal
    logic is elided in this snippet, and ``params_new`` is not defined
    anywhere visible here.

    Args:
        params: current parameters.
        key: JAX PRNG key reserved for this proposal.
        step_size: proposal scale, taken from the kernel's ``step_size``.

    Returns:
        The proposed parameters; presumably the same structure as
        ``params`` (TODO confirm).
    """
    # do something to translate params
    # it is a symmetric translation, so I can use the simple formulation of accept probability in Metropolis-Hastings algorithm
    return params_new
kernel = MetropolisHastings(potential_fn=potential_fn, translation_fn=translation_fn)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=10000, num_chains=1, progress_bar=True)
# trainable_params: the initial params to be optimized (defined elsewhere).
# constrains: the constraints of the params, forwarded to `potential_fn` as its
#     `constraints` argument and used there to compute the score.
# verbose: whether `potential_fn` prints debug info.
mcmc.run(
    random.PRNGKey(0),
    init_params=trainable_params,
    extra_fields=('score',),
    # The keyword must match potential_fn's parameter name `constraints`;
    # passing `constrains=` raises TypeError (unexpected keyword argument).
    potential_fn_kwargs=dict(constraints=constrains, verbose=False),
    translation_fn_kwargs=dict(),
)
It works, but this algorithm performs poorly. NumPyro provides many powerful MCMC algorithms (https://num.pyro.ai/en/stable/mcmc.html), and I want to replace this simple MH algorithm with one of them. How can I do that?
I found that the code below may work, but I don't know how to use my translation_fn with it:
from collections import namedtuple
import os, time
import jax
import copy
from jax import random
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, BarkerMH, NUTS
# Immutable per-step chain state: the current sample `u`, its `score`
# (as returned by potential_fn), and the PRNG key used for the next step.
MHState = namedtuple("MHState", "u score rng_key")
class MetropolisHastings(numpyro.infer.mcmc.MCMCKernel):
    """Metropolis-Hastings kernel driven by a user-supplied proposal.

    The chain targets the unnormalized density ``exp(-beta * score)``, where
    ``score = potential_fn(u, **potential_fn_kwargs)``; lower scores are more
    probable. ``translation_fn`` must be a *symmetric* proposal, so the
    Hastings correction cancels and a move is accepted with probability
    ``min(1, exp(beta * (score - score_new)))``.

    Per-run keyword arguments come from ``mcmc.run(...)``:
    ``potential_fn_kwargs`` and ``translation_fn_kwargs`` (each a dict;
    omitting one means "no extra arguments").

    Args:
        potential_fn: callable ``potential_fn(u, **potential_fn_kwargs)``
            returning a scalar score.
        translation_fn: callable
            ``translation_fn(u, rng_key, step_size, **translation_fn_kwargs)``
            returning a proposal with the same (pytree) structure as ``u``.
        beta: inverse temperature multiplying the score difference.
        step_size: forwarded unchanged to ``translation_fn``.
    """

    # MCMC collects this state field as the chain's samples.
    sample_field = "u"

    def __init__(self, potential_fn, translation_fn, beta=100, step_size=0.1):
        self.potential_fn = potential_fn
        self.translation_fn = translation_fn
        self.step_size = step_size
        self.beta = beta

    @staticmethod
    def _fn_kwargs(model_kwargs, name):
        """Return the dict stored under ``name`` in ``model_kwargs``, or ``{}``."""
        return (model_kwargs or {}).get(name) or {}

    def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
        """Score ``init_params`` and build the initial ``MHState``.

        ``num_warmup`` and ``model_args`` are unused: this kernel performs no
        adaptation.
        """
        u = init_params
        score = self.potential_fn(u, **self._fn_kwargs(model_kwargs, "potential_fn_kwargs"))
        return MHState(u, score, rng_key)

    def sample(self, state, model_args, model_kwargs):
        """Run one propose/accept step and return the next ``MHState``."""
        u, score, rng_key = state
        rng_key, key_accept, key_translation = random.split(rng_key, 3)
        u_proposal = self.translation_fn(
            u,
            key_translation,
            self.step_size,
            **self._fn_kwargs(model_kwargs, "translation_fn_kwargs"),
        )
        score_proposal = self.potential_fn(
            u_proposal, **self._fn_kwargs(model_kwargs, "potential_fn_kwargs")
        )
        # Accept iff U < exp(beta * (score - score_proposal)) with U ~ Uniform(0, 1).
        # The comparison is done in log space so a large beta cannot overflow
        # exp(). A NaN score makes the comparison False, so the move is rejected.
        log_u = jnp.log(random.uniform(key_accept))
        accept = log_u < self.beta * (score - score_proposal)
        # tree_map lets `u` be any pytree (e.g. a dict of arrays), not only a
        # single array; jnp.where applied directly to a dict would fail.
        u_new = jax.tree_util.tree_map(
            lambda new, old: jnp.where(accept, new, old), u_proposal, u
        )
        score_new = jnp.where(accept, score_proposal, score)
        return MHState(u_new, score_new, rng_key)
# Declared as a module-level global: this version of potential_fn takes no
# `constraints` argument, so it presumably reads this global instead (TODO confirm).
constraints = None
def potential_fn(params, verbose=False, print_func=None):
    """Return the scalar score of ``params``; lower is better (placeholder).

    This variant takes no ``constraints`` argument; it presumably reads the
    module-level ``constraints`` global instead (TODO confirm). The scoring
    logic is elided, and ``score`` is not defined anywhere visible here.
    Because NUTS differentiates its potential, the real body must be
    written with JAX-traceable operations.

    Args:
        params: the parameters being scored.
        verbose: whether to print debug information.
        print_func: optional print callable; presumably used when
            ``verbose`` is true (TODO confirm).
    """
    # do something to calc the score
    return score
# def translation_fn(params, key, step_size):
# # do something to translate params
# # it is a symmetric translation, so I can use the simple formulation of accept probability in Metropolis-Hastings algorithm
# return params_new
# NUTS samples exp(-potential_energy). The MetropolisHastings kernel targets
# exp(-beta * score) with beta=100 by default, so scale the score by the same
# beta to sample the same distribution. NUTS differentiates this potential,
# so potential_fn must be JAX-traceable.
beta = 100
kernel = NUTS(potential_fn=lambda params: beta * potential_fn(params))
mcmc = MCMC(kernel, num_warmup=1000, num_samples=10000, num_chains=1, progress_bar=True)
# trainable_params: the initial params to be optimized (defined elsewhere).
# constraints: the module-level global, presumably read inside `potential_fn`.
mcmc.run(random.PRNGKey(0), init_params=trainable_params, extra_fields=('potential_energy',))
# Undo the beta scaling so `scores` holds raw potential_fn scores; argmin is unaffected.
scores = mcmc.get_extra_fields()['potential_energy'] / beta
idx = jnp.argmin(scores)
# Assumes trainable_params is a dict with a 'params' entry (TODO confirm).
result = mcmc.get_samples()['params'][idx]
Hi @DQSSSSS, could you clarify your question? Is the code not working or something?
Hi @fehiepsi, sorry for the ambiguity. The code works, but it uses the simple MH algorithm. I want to try the more advanced algorithms in NumPyro, such as NUTS and HMCECS, and select the best one.
I don't know how to use my own translation function, translation_fn, together with NUTS; my best attempt is the code in my second comment.
I need a custom translation function because, given the complexity of the problem, the default one makes my algorithm run slowly.
Because your translation function requires a random key, I think you need to wrap the NUTS kernel to implement that logic. See the HMCGibbs implementation for an example design. There is also a pending issue, #898, about composing kernels.
Closed. Please feel free to follow up the discussion on the forum: https://forum.pyro.ai/