[FR] SGLD and SG-MCMC inference in numpyro?
Hi there,
Thanks for creating numpyro!
Is there any interest in supporting SG-MCMC methods (such as SGLD, IASG, etc.) in NumPyro?
I searched around a bit, and it seems like the Pyro community doesn't find SGLD-style methods particularly useful (pyro-ppl/pyro#1921). I'd be happy to contribute, though, if there's any interest at all. Personally, I'd like to use NumPyro for one of my projects, which requires SG-MCMC. I have a custom implementation of IASG in JAX, but it would be much nicer to have it as yet another infer.mcmc.MCMCKernel. What do y'all think?
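For reference, the core SGLD update (Welling & Teh, 2011) is small: take half a step along a minibatch estimate of the log-posterior gradient, then add Gaussian noise whose variance equals the step size. The sketch below is plain JAX and is not an MCMCKernel, nor any NumPyro or SGMCMCJax API; `sgld_step`, the toy Gaussian-mean model, the batch size, and the step size are all hypothetical choices for illustration.

```python
import jax
import jax.numpy as jnp


def sgld_step(key, theta, batch, grad_logprior, grad_loglik, n_data, step_size):
    """One SGLD update using a minibatch `batch` drawn from `n_data` observations."""
    batch_size = batch.shape[0]
    # Minibatch estimate of the full-data log-posterior gradient,
    # rescaled by n_data / batch_size so that it is unbiased.
    lik_grads = jax.vmap(grad_loglik, in_axes=(None, 0))(theta, batch)
    grad_est = grad_logprior(theta) + (n_data / batch_size) * jnp.sum(lik_grads, axis=0)
    # Langevin update: half a step along the gradient plus N(0, step_size) noise.
    noise = jnp.sqrt(step_size) * jax.random.normal(key, jnp.shape(theta))
    return theta + 0.5 * step_size * grad_est + noise


# Toy example (illustrative only): mean of Gaussian data with a N(0, 10^2) prior.
data = 2.0 + jax.random.normal(jax.random.PRNGKey(0), (1000,))
grad_logprior = jax.grad(lambda th: -0.5 * th**2 / 100.0)
grad_loglik = jax.grad(lambda th, x: -0.5 * (x - th) ** 2)


def body(theta, key):
    key_batch, key_noise = jax.random.split(key)
    # Draw a minibatch of 100 observations without replacement.
    idx = jax.random.choice(key_batch, data.shape[0], shape=(100,), replace=False)
    theta = sgld_step(key_noise, theta, data[idx], grad_logprior, grad_loglik,
                      data.shape[0], step_size=1e-4)
    return theta, theta


keys = jax.random.split(jax.random.PRNGKey(1), 2000)
_, samples = jax.lax.scan(body, jnp.zeros(()), keys)
posterior_mean_estimate = jnp.mean(samples[1000:])  # discard the first half as burn-in
```

Wrapping this as a kernel would mostly mean carrying the minibatch sampling and the step-size schedule in the kernel state; that design is left open here.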
Welcome, @alshedivat! When SG-MCMC methods are available, I would like to see how they perform against HMCECS (contributed by @OlaRonning and his team in the last NumPyro release). :D
Update: check out @jeremiecoullon's SGMCMCJax for SGLD and other SG-MCMC samplers! :rocket:
It would be nice to have a tutorial on converting a NumPyro model to the SGMCMCJax format (logprior / loglikelihood). Currently, we only have the initialize_model utility, which gives us the potential function of a model together with initial values. For the logprior, I guess we can just mask out the likelihood using numpyro.handlers.mask.
Glad that my SGMCMCJax library might be helpful here!
I agree that a tutorial on converting a NumPyro model to the logprior / loglikelihood format would be great. I will try to look into this when I find the time (though anyone else: feel free to try it out!).
I'm also curious to see how the different SG-MCMC samplers perform compared to HMCECS.