celerite2
Add support for JAX batching via vmap
This will require updating the backend to iterate over the batch dimension, but that shouldn't be too hard. Then we'd need to add a simple batching rule.
One open question is how batching should interact with the terms interface.
https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html#Batching
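A minimal sketch of the "iterate over the batch dimension" idea. It assumes `jax.custom_batching.custom_vmap` as the hook, not the low-level primitive batching rule described in the linked notebook. `solve_lower` and `_solve_one` are hypothetical names, and `solve_triangular` stands in for celerite2's compiled kernel. The rule broadcasts any unbatched input, then loops over the batch axis with `jax.lax.map`.

```python
import jax
import jax.numpy as jnp
from jax.custom_batching import custom_vmap
from jax.scipy.linalg import solve_triangular


def _solve_one(L, y):
    # Hypothetical single-system kernel: solves L @ x = y for one lower-
    # triangular L (a stand-in for a non-batchable backend op).
    return solve_triangular(L, y, lower=True)


@custom_vmap
def solve_lower(L, y):
    # Unbatched call path: just run the single-system kernel.
    return _solve_one(L, y)


@solve_lower.def_vmap
def solve_lower_vmap_rule(axis_size, in_batched, L, y):
    L_batched, y_batched = in_batched
    # Give every argument a leading batch axis of length axis_size.
    if not L_batched:
        L = jnp.broadcast_to(L, (axis_size,) + L.shape)
    if not y_batched:
        y = jnp.broadcast_to(y, (axis_size,) + y.shape)
    # Iterate over the batch dimension, one kernel call per element.
    out = jax.lax.map(lambda args: _solve_one(*args), (L, y))
    # The output always carries the batch axis at position 0.
    return out, True
```

Because `lax.map` lowers to a sequential scan, this favors simplicity over throughput; a loop inside the compiled backend would likely be faster.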
Hi @dfm! I'm starting to work with the JAX backend, and I'm hitting an error:
NotImplementedError: Batching rule for 'celerite2_solve_lower_jvp' not implemented
Is that what you're referring to in this issue? Happy to help contribute if you give me some pointers on how to get started.
Ah, I see: this was because I had left forward_mode_differentiation=True in my NUTS(...) call. Sorry for the noise!
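Some context on why that flag triggers the error. `jax.jacfwd` is built as `vmap` over `jax.jvp`, so forward-mode gradients push a batch of tangent vectors through each primitive's JVP. That requires a batching rule for the JVP primitive. `jax.grad` instead runs a single reverse pass with no `vmap`. The sketch assumes NumPyro's forward_mode_differentiation option computes gradients through `jacfwd`, and uses a hypothetical `log_prob` in place of a celerite2 likelihood.

```python
import jax
import jax.numpy as jnp


def log_prob(theta):
    # Hypothetical scalar potential standing in for a celerite2 log-likelihood.
    return -0.5 * jnp.sum(theta ** 2) + jnp.sin(theta[0])


theta = jnp.array([0.3, -1.2, 2.0])

# Reverse mode: one backward pass, no vmap involved.
grad_rev = jax.grad(log_prob)(theta)

# Forward mode: one tangent per input dimension, batched with vmap.
grad_fwd = jax.jacfwd(log_prob)(theta)


def jvp_along(v):
    # Directional derivative of log_prob at theta along tangent v.
    return jax.jvp(log_prob, (theta,), (v,))[1]


# The same forward-mode computation written out as vmap-of-jvp over the
# standard basis; this vmap is what needs each JVP primitive's batching rule.
grad_manual = jax.vmap(jvp_along)(jnp.eye(theta.shape[0]))
```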