
Allow access to different nutpie backends via pip-style syntax

Open • jessegrabowski opened this issue 5 months ago • 17 comments

Description

Adds a pip-style syntax to the nuts_sampler argument for selecting an alternative compile backend when a sampler offers more than one. For example, setting nuts_sampler='nutpie[jax]' gives you the nutpie JAX backend. For backwards compatibility, nuts_sampler='nutpie' is equivalent to nuts_sampler='nutpie[numba]'.
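One way such a spec string could be interpreted is sketched below. This is a minimal illustration only, not the code in this PR: the function parse_nuts_sampler and the SUPPORTED_BACKENDS table are hypothetical names, and the table lists only nutpie with the two backends mentioned above (numba as the default, plus jax).

```python
import re

# Hypothetical table of samplers and the compile backends each accepts.
# The first backend listed is the default used when no brackets are given.
# (Illustrative assumption; not PyMC's actual internal data structure.)
SUPPORTED_BACKENDS = {
    "nutpie": ("numba", "jax"),
}

# Matches "name" or "name[backend]", e.g. "nutpie" or "nutpie[jax]".
_SPEC_RE = re.compile(r"^(?P<name>[A-Za-z_][\w-]*)(?:\[(?P<backend>[A-Za-z_]\w*)\])?$")


def parse_nuts_sampler(spec: str) -> tuple[str, str]:
    """Split a pip-style sampler spec into (sampler, backend)."""
    match = _SPEC_RE.match(spec.strip())
    if match is None:
        raise ValueError(f"Malformed nuts_sampler spec: {spec!r}")

    name = match.group("name").lower()
    if name not in SUPPORTED_BACKENDS:
        raise ValueError(f"Unknown sampler {name!r}")

    allowed = SUPPORTED_BACKENDS[name]
    backend = match.group("backend")
    if backend is None:
        # A bare name keeps the old behaviour: use the default (first) backend.
        return name, allowed[0]

    backend = backend.lower()
    if backend not in allowed:
        raise ValueError(
            f"Sampler {name!r} does not support backend {backend!r}; choose from {allowed}"
        )
    return name, backend


print(parse_nuts_sampler("nutpie"))       # prints ('nutpie', 'numba')
print(parse_nuts_sampler("nutpie[jax]"))  # prints ('nutpie', 'jax')
```

Under this sketch, a user-facing call would look like pm.sample(nuts_sampler="nutpie[jax]"), and supporting backends for another sampler would only mean adding an entry to the lookup table rather than new parsing logic.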

The current PR only handles nutpie, but it could easily be extended to the default PyMC sampler, letting it compile to JAX, numba, or pytorch directly without going through nutpie. I'm willing to make that extension in this PR if it is deemed worthwhile.

Related Issue

  • [x] Closes #7497
  • [ ] Related to #

Checklist

Type of change

  • [x] New feature / enhancement
  • [ ] Bug fix
  • [ ] Documentation
  • [ ] Maintenance
  • [ ] Other (please specify):

📚 Documentation preview 📚: https://pymc--7498.org.readthedocs.build/en/7498/

jessegrabowski • Sep 10 '24 10:09