pymc
Expose all `nutpie` compile backends through `pm.sample`
Description
Nutpie currently has two compile backends, Numba and JAX, with a third (PyTorch) backend on the way. It would be nice if these could be selected directly through `pm.sample`.
Proposal 1: Allow `nutpie.compile_pymc` kwargs in `nuts_sampler_kwargs`
- Pros: It's easy, since there are only two such arguments: `backend` and `gradient_backend`. We just check for and pop them before forwarding all other arguments to `nutpie.sample`.
- Cons: It might be seen as "unexpected" behavior, since some keywords go to one function and some to another. Also, the `nuts_sampler_kwargs` argument isn't very elegant in the first place.
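Under Proposal 1, a user would write something like `pm.sample(nuts_sampler="nutpie", nuts_sampler_kwargs={"backend": "jax"})`, and PyMC would have to split that dict in two. Below is a minimal sketch of the split, assuming the compile-time keys are exactly `backend` and `gradient_backend`; the helper name `split_nutpie_kwargs` is hypothetical and not part of PyMC or nutpie.

```python
from __future__ import annotations

from typing import Any

# Keys assumed to belong to nutpie.compile_pymc rather than nutpie.sample.
COMPILE_KEYS = ("backend", "gradient_backend")


def split_nutpie_kwargs(
    nuts_sampler_kwargs: dict[str, Any] | None,
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Hypothetical helper: return (compile kwargs, sample kwargs)."""
    # Copy so the caller's dict is not mutated by the pops below.
    sample_kwargs = dict(nuts_sampler_kwargs or {})
    # Pop only the compile-time keys that were actually supplied.
    compile_kwargs = {
        key: sample_kwargs.pop(key) for key in COMPILE_KEYS if key in sample_kwargs
    }
    return compile_kwargs, sample_kwargs


# Inside pm.sample the two dicts would then feed two different calls, roughly:
#   compiled = nutpie.compile_pymc(model, **compile_kwargs)
#   idata = nutpie.sample(compiled, **sample_kwargs)
compile_kwargs, sample_kwargs = split_nutpie_kwargs(
    {"backend": "jax", "gradient_backend": "jax", "chains": 4}
)
print(compile_kwargs)  # {'backend': 'jax', 'gradient_backend': 'jax'}
print(sample_kwargs)   # {'chains': 4}
```

The routing is invisible to the user, which is exactly the "unexpected behavior" concern: nothing in the dict itself says which keys go where.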
Proposal 2: pip-style optional arguments, like `nuts_sampler="nutpie[jax]"` and `nuts_sampler="nutpie[numba]"`
- Pros: It's quite pretty!
- Cons: Nutpie technically lets you choose the forward and gradient (backward) compile modes separately, so a user who wanted that would still have to import `nutpie` and do it manually. Maybe that's a rare enough corner case to be acceptable? It's also a different API from the other samplers (although blackjax could benefit from something similar for choosing among its many options, but that's beyond the scope here).
Proposal 3: Add a new `compile_kwargs` argument to `pm.sample`
- Pros: It's very clear. It could also be used to forward compile kwargs to PyTensor, which is a nice side bonus.
- Cons: It's yet another argument to the already bloated `pm.sample` function.
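To illustrate why Proposal 3 is unambiguous, here is a minimal sketch of the routing with a dedicated `compile_kwargs` argument. The wrapper `sample_with_compile_kwargs` is hypothetical, and `fake_compile` / `fake_sample` are stand-ins for something like `nutpie.compile_pymc` and `nutpie.sample`; their defaults are illustrative only and not taken from nutpie.

```python
from __future__ import annotations

from typing import Any, Callable


def sample_with_compile_kwargs(
    model: Any,
    compile_fn: Callable[..., Any],
    sample_fn: Callable[..., Any],
    *,
    compile_kwargs: dict[str, Any] | None = None,
    **sample_kwargs: Any,
) -> Any:
    """Hypothetical routing: compile options and sampler options never mix."""
    compiled = compile_fn(model, **(compile_kwargs or {}))
    return sample_fn(compiled, **sample_kwargs)


# Stand-ins for a compile step and a sampling step (illustrative defaults).
def fake_compile(model, backend="numba", gradient_backend="pytensor"):
    return {"model": model, "backend": backend, "gradient_backend": gradient_backend}


def fake_sample(compiled, chains=4):
    return {"compiled": compiled, "chains": chains}


result = sample_with_compile_kwargs(
    "my_model",
    fake_compile,
    fake_sample,
    compile_kwargs={"backend": "jax", "gradient_backend": "jax"},
    chains=2,
)
print(result["compiled"]["backend"], result["chains"])  # jax 2
```

The same argument could, in principle, be routed to PyTensor's function compilation when one of PyMC's own samplers is used; that path is not shown here.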