
Expose all `nutpie` compile backends through `pm.sample`

Status: Open · jessegrabowski opened this issue 5 months ago · 3 comments

Description

Nutpie currently has two compile backends, Numba and JAX, with a third (PyTorch) backend on the way. It would be nice to be able to select these directly through `pm.sample`.

Proposal 1: Allow `nutpie.compile_pymc_model` kwargs in `nuts_sampler_kwargs`

  • Pros: It's easy, since there are only two such arguments: `backend` and `gradient_backend`. We just check for and pop them before forwarding all other arguments to `nutpie.sample`.
  • Cons: It might be seen as "unexpected" behavior, since some keywords go to one function and some to another. Also, the `nuts_sampler_kwargs` argument isn't very beautiful in the first place.
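A minimal sketch of the check-and-pop step in Proposal 1. The helper names `split_nutpie_kwargs` and `sample_with_nutpie` are hypothetical, and the compile and sample functions are injected as plain callables so the routing logic can be exercised without nutpie installed; inside PyMC they would presumably be `nutpie.compile_pymc_model` and `nutpie.sample`.

```python
from __future__ import annotations

from typing import Any, Callable

# Keywords assumed to belong to nutpie.compile_pymc_model rather than nutpie.sample.
COMPILE_KEYS = ("backend", "gradient_backend")


def split_nutpie_kwargs(
    nuts_sampler_kwargs: dict[str, Any],
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Hypothetical helper: pop compile options out of nuts_sampler_kwargs.

    Returns (compile_kwargs, sample_kwargs) and leaves the caller's dict untouched.
    """
    sample_kwargs = dict(nuts_sampler_kwargs)  # copy so the input is not mutated
    compile_kwargs = {
        key: sample_kwargs.pop(key) for key in COMPILE_KEYS if key in sample_kwargs
    }
    return compile_kwargs, sample_kwargs


def sample_with_nutpie(
    model: Any,
    nuts_sampler_kwargs: dict[str, Any],
    compile_fn: Callable[..., Any],
    sample_fn: Callable[..., Any],
) -> Any:
    """Hypothetical routing: compile options to compile_fn, the rest to sample_fn."""
    compile_kwargs, sample_kwargs = split_nutpie_kwargs(nuts_sampler_kwargs)
    compiled = compile_fn(model, **compile_kwargs)
    return sample_fn(compiled, **sample_kwargs)
```

This makes the "unexpected behavior" con concrete: a single dict is silently split across two functions, so the user has to know which keys end up where.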

Proposal 2: pip-style extras syntax, like `nuts_sampler="nutpie[jax]"` and `nuts_sampler="nutpie[numba]"`

  • Pros: It's quite pretty!
  • Cons: Technically, the forward and backward (gradient) compile modes can be chosen separately, so a user who wanted a mix would still have to import nutpie and do it manually. Maybe that's enough of a corner case that it's OK? Also, it's a different API from the other samplers (although blackjax could benefit from something similar for selecting among its many options; that's beyond the scope here).
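A rough sketch of how a `"nutpie[jax]"`-style string might be parsed. The function `parse_nuts_sampler`, the accepted backend set, and the assumed default of `"numba"` when no bracket is given are all illustrative assumptions, not existing PyMC code.

```python
from __future__ import annotations

import re

# Matches "name" or "name[option]", e.g. "pymc", "nutpie", "nutpie[jax]".
_SAMPLER_SPEC = re.compile(r"^(?P<name>[a-z_]+)(?:\[(?P<option>[a-z_]+)\])?$")

# Assumptions for illustration only: the default and the set of known backends.
DEFAULT_NUTPIE_BACKEND = "numba"
KNOWN_NUTPIE_BACKENDS = {"numba", "jax"}


def parse_nuts_sampler(spec: str) -> tuple[str, str | None]:
    """Hypothetical parser: split 'nutpie[jax]' into ('nutpie', 'jax')."""
    match = _SAMPLER_SPEC.match(spec.strip().lower())
    if match is None:
        raise ValueError(f"Invalid nuts_sampler specification: {spec!r}")
    name, option = match.group("name"), match.group("option")
    if name == "nutpie":
        option = option or DEFAULT_NUTPIE_BACKEND
        if option not in KNOWN_NUTPIE_BACKENDS:
            raise ValueError(f"Unknown nutpie backend: {option!r}")
    elif option is not None:
        # Other samplers take no bracketed option in this sketch.
        raise ValueError(f"Sampler {name!r} does not accept a bracketed option")
    return name, option
```

The bracket holds a single option, so there is no slot for a separate gradient backend; that is exactly the corner case listed under Cons.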

Proposal 3: Add a new `compile_kwargs` argument to `pm.sample`

  • Pros: It's very clear. It could also be used to forward kwargs to PyTensor compilation, which is a nice side bonus.
  • Cons: It's another argument to an already bloated `pm.sample` function.
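A minimal sketch of the routing Proposal 3 implies. `sample_sketch` is a hypothetical stand-in for `pm.sample`, not its real signature; `compile_fns` and `sample_fns` are injected placeholders (for `"nutpie"` presumably `nutpie.compile_pymc_model` and `nutpie.sample`, and for the default sampler an assumed wrapper around PyTensor compilation). Unlike Proposal 1, no keyword needs to be inspected or popped.

```python
from __future__ import annotations

from typing import Any, Callable


def sample_sketch(
    model: Any,
    *,
    nuts_sampler: str = "pymc",
    nuts_sampler_kwargs: dict[str, Any] | None = None,
    compile_kwargs: dict[str, Any] | None = None,
    compile_fns: dict[str, Callable[..., Any]],
    sample_fns: dict[str, Callable[..., Any]],
) -> Any:
    """Hypothetical pm.sample stand-in: each kwargs dict has exactly one destination."""
    if nuts_sampler not in compile_fns or nuts_sampler not in sample_fns:
        raise ValueError(f"Unknown nuts_sampler: {nuts_sampler!r}")
    # compile_kwargs go only to the compile step ...
    compiled = compile_fns[nuts_sampler](model, **(compile_kwargs or {}))
    # ... and nuts_sampler_kwargs go only to the sampling step.
    return sample_fns[nuts_sampler](compiled, **(nuts_sampler_kwargs or {}))
```

A call would then look like `sample_sketch(model, nuts_sampler="nutpie", compile_kwargs={"backend": "jax", "gradient_backend": "jax"}, ...)`, which also leaves room to pass `backend` and `gradient_backend` independently.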

jessegrabowski · Sep 10 '24 04:09