nutpie
Document/expose from_pyfunc?
Hi!
I think it is fairly common for a user to have a log density/its gradient outside of a PyMC or Stan model. I was helping one such user on the Stan forums recently and recommended they check out this package.
I was able to write what I think is a working example of how you can do this in nutpie, but
- It's not clear to me whether this is really "supported" functionality that won't be broken by future changes
- I had to basically reverse-engineer how to call the function, due to the lack of documentation
Is this a use case you all are interested in supporting?
Not super helpful, but @aseyboldt helped me get nutpie working with JAX in bayeux: https://github.com/jax-ml/bayeux/blob/main/bayeux/_src/mcmc/nutpie.py
Broadly, bayeux accepts a log density in JAX, then uses function transforms to compute gradients, transforms, and log det jacobians. The implementation I linked does this by:
- Flattening the inputs to bayeux, to make a log density that works on a single vector (rather than some other structure)
- Passing this wrapped (and transformed) log density, along with the gradient, to nutpie
- Untransforming the returned samples
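For anyone following along, here is a minimal sketch of the flatten / gradient / unflatten pattern described above, using only JAX. The toy density and the parameter names (`mu`, `log_sigma`) are made up for illustration, it does not call nutpie at all (the entry point is undocumented, as discussed in this issue), and it leaves out the constraint transforms and log det Jacobians that bayeux also handles.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical structured parameters; the names are illustrative only.
init_params = {"mu": jnp.zeros(2), "log_sigma": jnp.array(0.0)}


def log_density(params):
    # Toy unnormalized log density defined on the structured parameters.
    sigma = jnp.exp(params["log_sigma"])
    return (
        -0.5 * jnp.sum((params["mu"] / sigma) ** 2)
        - 0.5 * params["log_sigma"] ** 2
    )


# Step 1: flatten the structure into a single vector and keep the inverse map.
flat_init, unravel = ravel_pytree(init_params)


def flat_log_density(x):
    # Same density, but taking one flat vector as input.
    return log_density(unravel(x))


# Step 2: one function returning both the log density and its gradient,
# which is the kind of callable a gradient-based sampler needs.
logp_and_grad = jax.jit(jax.value_and_grad(flat_log_density))

# Step 3: map flat draws (stand-in values here, not real samples) back to
# the original structure, one row per draw.
stand_in_draws = jnp.ones((4, flat_init.shape[0]))
structured_draws = jax.vmap(unravel)(stand_in_draws)

logp, grad = logp_and_grad(jnp.ones(flat_init.shape[0]))
```

The sampler only ever sees the flat vector, so `unravel` is the one piece that has to be kept around to interpret the draws afterwards.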
Yes, this is definitely something we want to support. I have not yet exposed it as a public function because I thought I might want to make some minor changes to the API. But the next version should contain a documented version of this function. It might look slightly different, but it should be really easy to adapt to.