celerite2

Add support for JAX batching via vmap

Open dfm opened this issue 5 years ago • 3 comments

This will require updating the backend to iterate over the batch dimension, but that shouldn't be too hard. Then we'd need to add a simple batching rule.

One open question is how batching should interact with the terms interface.

https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html#Batching
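Below is a minimal sketch of the "iterate over the batch dimension" approach, following the primitive-batching pattern in the linked JAX notes. It is not celerite2's code. toy_solve_lower_p, _toy_solve_lower_impl and the helper functions are hypothetical stand-ins for an unbatched backend op such as a lower-triangular solve. The batching rule moves each operand's batch axis to the front (broadcasting unbatched operands), calls the unbatched primitive once per batch element, and stacks the results.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.interpreters import batching

try:  # newer JAX releases expose Primitive here
    from jax.extend.core import Primitive
except ImportError:  # older releases
    from jax.core import Primitive

# Hypothetical primitive standing in for an unbatched backend op
# (e.g. a lower-triangular solve): it only handles ONE system at a time.
toy_solve_lower_p = Primitive("toy_solve_lower")


def _toy_solve_lower_impl(L, y):
    # Eager implementation: solve the single system tril(L) @ x = y.
    x = np.linalg.solve(np.tril(np.asarray(L)), np.asarray(y))
    return jnp.asarray(x)


toy_solve_lower_p.def_impl(_toy_solve_lower_impl)


def toy_solve_lower(L, y):
    return toy_solve_lower_p.bind(L, y)


def _move_batch_to_front(x, bdim, size):
    # Put the batch axis first; broadcast operands that are not batched.
    if bdim is batching.not_mapped:
        return jnp.broadcast_to(x, (size,) + x.shape)
    return jnp.moveaxis(x, bdim, 0)


def _toy_solve_lower_batch(batched_args, batch_dims):
    # Batching rule: loop over the batch dimension, calling the unbatched
    # primitive once per element, then stack the results along axis 0.
    size = next(a.shape[d] for a, d in zip(batched_args, batch_dims)
                if d is not batching.not_mapped)
    L, y = (_move_batch_to_front(a, d, size)
            for a, d in zip(batched_args, batch_dims))
    out = jnp.stack([toy_solve_lower_p.bind(L[i], y[i]) for i in range(size)])
    return out, 0  # the output's batch axis is 0


batching.primitive_batchers[toy_solve_lower_p] = _toy_solve_lower_batch
```

With the rule registered, jax.vmap(toy_solve_lower)(Ls, ys) solves a stack of systems, and in_axes=(None, 0) reuses one matrix for many right-hand sides. This sketch only runs eagerly; under jax.jit the primitive would also need an abstract evaluation rule and a lowering, which are omitted here.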

dfm, Oct 29 '20 13:10

Hi @dfm! I'm starting to work with the JAX backend, and I'm hitting this error:

NotImplementedError: Batching rule for 'celerite2_solve_lower_jvp' not implemented

Is that what you're referring to in this issue? Happy to help contribute if you give me some pointers on how to get started.

bmorris3, Jun 26 '21 05:06

Ah, I see: this was because I had left NUTS(..., forward_mode_differentiation=True) in my code. Sorry for the noise!

bmorris3, Jun 26 '21 05:06
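Why would forward-mode differentiation hit a batching rule? jax.jacfwd builds a Jacobian by pushing one tangent vector per input coordinate through the JVP and batching those JVPs with vmap. So any primitive that appears in a JVP (such as celerite2_solve_lower_jvp) needs a batching rule once forward mode is used. We assume, without having checked NumPyro's source, that forward_mode_differentiation=True routes the gradient through jax.jacfwd. The sketch below reproduces jacfwd for a 1-D input with jax.vmap over jax.jvp. The function f and the helper jacfwd_via_vmap are illustrative only.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Arbitrary smooth vector-valued function used only for illustration.
    return jnp.sin(x) * jnp.sum(x ** 2)


def jacfwd_via_vmap(f, x):
    # Forward-mode Jacobian for 1-D x: one JVP per standard basis vector,
    # batched with vmap. This is why forward mode needs batching rules.
    assert x.ndim == 1
    basis = jnp.eye(x.size, dtype=x.dtype)          # row j is e_j
    push = lambda t: jax.jvp(f, (x,), (t,))[1]      # J @ t
    cols = jax.vmap(push)(basis)                    # cols[j] = J[:, j]
    return jnp.moveaxis(cols, 0, -1)                # J[i, j]
```

Reverse mode (jax.grad or jax.jacrev with a scalar output) only needs a single VJP, so it avoids batching the JVP primitive. That is consistent with the error disappearing once forward mode was turned off.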