
Customize MCMC algorithm

DQSSSSS opened this issue 2 years ago • 4 comments

Hi all, thanks for your powerful library, and sorry for my beginner question. I have implemented a Metropolis-Hastings algorithm following this link, because I need to customize the score calculation. My code:

from collections import namedtuple
import os, time
import jax
import copy
from jax import random
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, BarkerMH

MHState = namedtuple("MHState", ["u", "score", "rng_key"])

class MetropolisHastings(numpyro.infer.mcmc.MCMCKernel):
    sample_field = "u"

    def __init__(self, potential_fn, translation_fn, beta=100, step_size=0.1):
        self.potential_fn = potential_fn
        self.translation_fn = translation_fn
        self.step_size = step_size
        self.beta = beta

    def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
        u = init_params
        return MHState(u, self.potential_fn(u, **model_kwargs["potential_fn_kwargs"]), rng_key)

    def sample(self, state, model_args, model_kwargs):
        u, score, rng_key = state
        rng_key, key_accept, key_translation = random.split(rng_key, 3)
        u_proposal = self.translation_fn(u, key_translation, self.step_size, **model_kwargs["translation_fn_kwargs"])
        score_new = self.potential_fn(u_proposal, **model_kwargs["potential_fn_kwargs"])
        accept_prob = jnp.exp(self.beta*(score - score_new)) # exp(-beta*s_next)/exp(-beta*s_now)
        alpha = dist.Uniform().sample(key_accept)
        u_new = jnp.where(alpha < accept_prob, u_proposal, u)
        score_new = jnp.where(alpha < accept_prob, score_new, score)
        return MHState(u_new, score_new, rng_key)

def potential_fn(params, constraints, verbose, print_func=None):
    # do something to calc the score
    return score

def translation_fn(params, key, step_size):
    # do something to translate params
    # it is a symmetric proposal, so I can use the simple form of the acceptance probability in the Metropolis-Hastings algorithm
    return params_new

kernel = MetropolisHastings(potential_fn=potential_fn, translation_fn=translation_fn)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=10000, num_chains=1, progress_bar=True)
# trainable_params: the params to be optimized
# constraints: the constraints of the params, used to calculate the score in `potential_fn`
# verbose: whether to print the debug info in function `potential_fn`
mcmc.run(random.PRNGKey(0), init_params=trainable_params, extra_fields=('score',),
        potential_fn_kwargs=dict(constraints=constraints, verbose=False),
        translation_fn_kwargs=dict(),
)
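
For reference, this is how I run the kernel end to end with toy placeholders standing in for my real potential_fn / translation_fn (a hypothetical quadratic score and a Gaussian random-walk proposal; these are just illustrative, not my actual functions):

import jax.numpy as jnp
from jax import random
from numpyro.infer import MCMC

def toy_potential_fn(params, constraints=None, verbose=False):
    # placeholder quadratic "energy"; the kernel above targets exp(-beta * score)
    return jnp.sum(params ** 2)

def toy_translation_fn(params, key, step_size):
    # symmetric Gaussian random-walk proposal
    return params + step_size * random.normal(key, params.shape)

kernel = MetropolisHastings(potential_fn=toy_potential_fn, translation_fn=toy_translation_fn)
mcmc = MCMC(kernel, num_warmup=100, num_samples=1000, num_chains=1, progress_bar=True)
mcmc.run(
    random.PRNGKey(0),
    init_params=jnp.zeros(3),
    extra_fields=("score",),
    potential_fn_kwargs=dict(constraints=None, verbose=False),
    translation_fn_kwargs=dict(),
)
samples = mcmc.get_samples()  # array of shape (1000, 3), collected from the "u" field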

It works, but this simple algorithm performs poorly. I found that numpyro has many powerful sampling algorithms (https://num.pyro.ai/en/stable/mcmc.html), and I would like to replace this simple MH algorithm with one of them. How can I do that?

DQSSSSS avatar Aug 24 '23 15:08 DQSSSSS

I found that this code might work, but I don't know how to use my translation_fn with it:

from collections import namedtuple
import os, time
import jax
import copy
from jax import random
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, BarkerMH, NUTS

MHState = namedtuple("MHState", ["u", "score", "rng_key"])

class MetropolisHastings(numpyro.infer.mcmc.MCMCKernel):
    sample_field = "u"

    def __init__(self, potential_fn, translation_fn, beta=100, step_size=0.1):
        self.potential_fn = potential_fn
        self.translation_fn = translation_fn
        self.step_size = step_size
        self.beta = beta

    def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
        u = init_params
        return MHState(u, self.potential_fn(u, **model_kwargs["potential_fn_kwargs"]), rng_key)

    def sample(self, state, model_args, model_kwargs):
        u, score, rng_key = state
        rng_key, key_accept, key_translation = random.split(rng_key, 3)
        u_proposal = self.translation_fn(u, key_translation, self.step_size, **model_kwargs["translation_fn_kwargs"])
        score_new = self.potential_fn(u_proposal, **model_kwargs["potential_fn_kwargs"])
        accept_prob = jnp.exp(self.beta*(score - score_new)) # exp(-beta*s_next)/exp(-beta*s_now)
        alpha = dist.Uniform().sample(key_accept)
        u_new = jnp.where(alpha < accept_prob, u_proposal, u)
        score_new = jnp.where(alpha < accept_prob, score_new, score)
        return MHState(u_new, score_new, rng_key)

constraints = None  # declared as a global variable so that potential_fn can use it

def potential_fn(params, verbose=False, print_func=None):
    # do something to calc the score
    return score

# def translation_fn(params, key, step_size):
#     # do something to translate params
#     # it is a symmetric translation, so I can use the simple formulation of accept probability in Metropolis-Hastings algorithm
#     return params_new

kernel = NUTS(potential_fn=potential_fn)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=10000, num_chains=1, progress_bar=True)
# trainable_params: the params to be optimized
# constraints: the constraints of the params; now held in the global variable above and used to calculate the score in `potential_fn`
# verbose: whether to print the debug info in `potential_fn`
mcmc.run(random.PRNGKey(0), init_params=trainable_params, extra_fields=('potential_energy',))
scores = mcmc.get_extra_fields()['potential_energy']
idx = jnp.argmin(scores)
result = mcmc.get_samples()['params'][idx]
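
Side note: since NUTS(potential_fn=...) only ever calls the potential function with the parameters, an alternative to the global constraints would be to bind the extra arguments up front with functools.partial. A toy sketch of that pattern (placeholder quadratic score, made-up names):

from functools import partial
import jax.numpy as jnp
from jax import random
from numpyro.infer import MCMC, NUTS

def potential_fn_full(params, constraints, verbose=False):
    # placeholder score; the real calculation would use `constraints`
    return jnp.sum(params ** 2)

# bind the extra arguments so NUTS sees a one-argument potential function
bound_potential_fn = partial(potential_fn_full, constraints=None, verbose=False)

kernel = NUTS(potential_fn=bound_potential_fn)
mcmc = MCMC(kernel, num_warmup=100, num_samples=1000)
mcmc.run(random.PRNGKey(0), init_params=jnp.zeros(3), extra_fields=('potential_energy',))
scores = mcmc.get_extra_fields()['potential_energy']
best = mcmc.get_samples()[jnp.argmin(scores)]  # sample with the lowest potential energy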

DQSSSSS avatar Aug 24 '23 19:08 DQSSSSS

Hi @DQSSSSS, could you clarify your question? Is the code not working or something?

fehiepsi avatar Aug 28 '23 14:08 fehiepsi

Hi @fehiepsi, sorry for the ambiguity. The code is working, but it uses the simple MH algorithm; I want to try some of the more powerful algorithms in numpyro, such as NUTS, HMCECS, etc., and select the best one. However, I don't know how to define my translation function translation_fn and use NUTS at the same time; my best attempt so far is the code in my second comment. Using the default proposal makes my algorithm run slowly because of the complexity of the problem, so I really need the custom translation_fn.

DQSSSSS avatar Aug 28 '23 14:08 DQSSSSS

Because your translation function requires a random key, I think you need to wrap the NUTS kernel to incorporate that logic. You can look at the HMCGibbs implementation for a design reference. We have a pending issue #898 about composing kernels.
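
A rough sketch of that wrapping idea (illustrative only, not the HMCGibbs implementation; the class name and state bookkeeping here are made up for the example): keep an inner NUTS kernel and, on every step, first apply the custom translation as a Metropolis move, then hand the updated state to NUTS. Both steps must target the same potential, so any beta factor has to be folded into potential_fn.

import jax
import jax.numpy as jnp
from jax import random
import numpyro.distributions as dist
from numpyro.infer import NUTS
from numpyro.infer.mcmc import MCMCKernel

class TranslationThenNUTS(MCMCKernel):
    sample_field = "z"  # HMCState stores the position under `z`

    def __init__(self, potential_fn, translation_fn, step_size=0.1):
        self._potential_fn = potential_fn
        self._translation_fn = translation_fn
        self._step_size = step_size
        self._inner = NUTS(potential_fn=potential_fn)

    def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
        # delegate state initialization (and warmup adaptation setup) to NUTS
        return self._inner.init(rng_key, num_warmup, init_params, model_args, model_kwargs)

    def sample(self, state, model_args, model_kwargs):
        # 1) custom symmetric-proposal Metropolis step on the current position
        rng_key, key_prop, key_accept = random.split(state.rng_key, 3)
        z, pe = state.z, state.potential_energy
        z_prop = self._translation_fn(z, key_prop, self._step_size)
        pe_prop = self._potential_fn(z_prop)
        accept = dist.Uniform().sample(key_accept) < jnp.exp(pe - pe_prop)
        z_new = jax.lax.cond(accept, lambda _: z_prop, lambda _: z, None)
        # keep the cached potential energy and gradient consistent with the new position
        pe_new, grad_new = jax.value_and_grad(self._potential_fn)(z_new)
        state = state._replace(z=z_new, potential_energy=pe_new, z_grad=grad_new, rng_key=rng_key)
        # 2) inner NUTS step starting from the (possibly) translated position
        return self._inner.sample(state, model_args, model_kwargs)

Usage would mirror your first snippet: MCMC(TranslationThenNUTS(potential_fn=..., translation_fn=...), ...). This only conveys the shape of the wrapper; the details (for example how the extra move interacts with NUTS step-size and mass-matrix adaptation during warmup) need care, which is part of why composing kernels is still an open issue (#898).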

fehiepsi avatar Aug 28 '23 22:08 fehiepsi

Closed. Please feel free to follow up on the discussion at the forum: https://forum.pyro.ai/

fehiepsi avatar May 12 '24 12:05 fehiepsi