pymc
Allow access to different nutpie backends via pip-style syntax
Description
Adds a pip-style syntax to the `nuts_sampler` argument that allows access to alternative compile backends, when relevant. This lets you get the nutpie JAX backend by setting `nuts_sampler='nutpie[jax]'`. For backwards compatibility, `nuts_sampler='nutpie'` is equivalent to `nuts_sampler='nutpie[numba]'`.
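The spec-string handling described above can be sketched as a small parser that splits `name[backend]` into its two parts and falls back to a default backend when none is given. This is a minimal sketch assuming a regex-based approach. `parse_nuts_sampler` and `DEFAULT_BACKENDS` are hypothetical names, not the code in this PR. Rejecting a backend suffix on samplers without one is also an assumption.

```python
from __future__ import annotations

import re

# Hypothetical table of samplers that accept a backend suffix, mapped to the
# backend used when no suffix is given. Only nutpie is covered, matching this PR.
DEFAULT_BACKENDS = {"nutpie": "numba"}

# Matches "name" or "name[backend]", e.g. "nutpie" or "nutpie[jax]".
_SPEC = re.compile(r"^(?P<name>[A-Za-z_]\w*)(?:\[(?P<backend>[A-Za-z_]\w*)\])?$")


def parse_nuts_sampler(spec: str) -> tuple[str, str | None]:
    """Split a pip-style sampler spec into (sampler, backend).

    Hypothetical helper for illustration; not PyMC's actual implementation.
    """
    match = _SPEC.match(spec.strip())
    if match is None:
        raise ValueError(f"Invalid nuts_sampler specification: {spec!r}")

    name = match.group("name").lower()
    backend = match.group("backend")

    if backend is None:
        # A bare name uses the default backend, so "nutpie" behaves like
        # "nutpie[numba]"; samplers without a table entry get None.
        return name, DEFAULT_BACKENDS.get(name)

    if name not in DEFAULT_BACKENDS:
        # Assumed behavior: a backend suffix only makes sense for samplers
        # that support several compile backends.
        raise ValueError(f"Sampler {name!r} does not accept a backend suffix")
    return name, backend.lower()
```

With this sketch, `parse_nuts_sampler("nutpie[jax]")` yields `("nutpie", "jax")`, while `parse_nuts_sampler("nutpie")` and `parse_nuts_sampler("nutpie[numba]")` both yield `("nutpie", "numba")`. Keeping the defaults in one table means extending the syntax to other samplers would only require adding entries.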
The current PR only deals with nutpie, but we could easily extend this to the default PyMC sampler, so that it compiles to JAX, Numba, or PyTorch directly without going through nutpie. I'm willing to do that extension in this PR if it is deemed worthwhile.
Related Issue
- [x] Closes #7497
- [ ] Related to #
Checklist
- [x] Checked that the pre-commit linting/style checks pass
- [x] Included tests that prove the fix is effective or that the new feature works
- [x] Added necessary documentation (docstrings and/or example notebooks)
- [x] If you are a pro: each commit corresponds to a relevant logical change
Type of change
- [x] New feature / enhancement
- [ ] Bug fix
- [ ] Documentation
- [ ] Maintenance
- [ ] Other (please specify):
📚 Documentation preview 📚: https://pymc--7498.org.readthedocs.build/en/7498/