Document/expose from_pyfunc?

Open WardBrian opened this issue 1 year ago • 2 comments

Hi!

I think it is fairly common for a user to have a log density and its gradient defined outside of a PyMC or Stan model. I was helping one such user on the Stan forums recently and recommended they check out this package.

I was able to write what I think is a working example of how you can do this in nutpie, but

  1. It's not clear to me whether this is actually "supported" functionality that won't be broken by future changes
  2. I had to basically reverse-engineer how to call the function, due to the lack of documentation

Is this a use case you all are interested in supporting?

WardBrian, Aug 14 '24 19:08

Not super helpful, but @aseyboldt helped me get nutpie working with JAX in bayeux: https://github.com/jax-ml/bayeux/blob/main/bayeux/_src/mcmc/nutpie.py

Broadly, bayeux accepts a log density written in JAX, then uses function transforms to compute gradients, parameter transforms, and log-det-Jacobians. The implementation I linked does this by:

  1. Flattening the inputs to bayeux, to make a log density that works on a single vector (rather than some other structure)
  2. Passing this wrapped (and transformed) log density, along with the gradient, to nutpie
  3. Untransforming the returned samples
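A minimal JAX sketch of those three steps, assuming a hypothetical toy model (a normal likelihood with unknown `mu` and a positive `sigma`). The names `constrain`, `flat_log_density`, `logp_and_grad`, and `untransform` are illustrative only; they are not bayeux's or nutpie's. The actual call into nutpie is deliberately left out, since `from_pyfunc` was not a documented public API at this point; the sketch stops at a flat `(value, gradient)` callable and a way to map flat draws back.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Use double precision; this must be set before any arrays are created.
jax.config.update("jax_enable_x64", True)

# Hypothetical toy data and model: Normal likelihood with unknown mu and sigma > 0.
data = jnp.array([0.3, -1.2, 0.8, 2.1, 0.5])

def log_density(params):
    """Unnormalized log density on the constrained (natural) parameter space."""
    mu, sigma = params["mu"], params["sigma"]
    lp = -0.5 * (mu / 5.0) ** 2   # Normal(0, 5) prior on mu
    lp = lp - sigma               # Exponential(1) prior on sigma
    lp = lp + jnp.sum(-jnp.log(sigma) - 0.5 * ((data - mu) / sigma) ** 2)
    return lp

def constrain(unconstrained):
    """Map unconstrained parameters to the natural space (sigma = exp(log_sigma))."""
    return {"mu": unconstrained["mu"], "sigma": jnp.exp(unconstrained["log_sigma"])}

def log_density_unconstrained(unconstrained):
    # Add the log-det-Jacobian of sigma = exp(log_sigma), which is just log_sigma.
    return log_density(constrain(unconstrained)) + unconstrained["log_sigma"]

# Step 1: flatten a template pytree so the sampler only ever sees one vector.
template = {"mu": jnp.zeros(()), "log_sigma": jnp.zeros(())}
flat_init, unravel = ravel_pytree(template)

def flat_log_density(x):
    return log_density_unconstrained(unravel(x))

# Step 2: log density and gradient of the flat, unconstrained function.
# This (value, gradient) callable is what would be handed to the sampler;
# the nutpie call itself is not shown because its signature is not assumed here.
logp_and_grad = jax.jit(jax.value_and_grad(flat_log_density))

# Step 3: map an (n_draws, n_dim) array of flat draws back to named, constrained values.
def untransform(draws):
    return jax.vmap(lambda x: constrain(unravel(x)))(draws)

logp, grad = logp_and_grad(flat_init)
print(logp, unravel(grad))
```

Keeping the sampler-facing function on a flat vector means only `unravel` needs to know the original pytree layout, so the same wrapper works for any model structure.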

ColCarroll, Aug 14 '24 20:08

Yes, this is definitely something we want to support. I haven't exposed it as a public function yet because I thought I might want to make some minor changes to the API. The next version should contain a documented version of this function; it might look slightly different, but it should be really easy to adapt.

aseyboldt, Aug 14 '24 20:08