pymc
BUG: VI can't be used with JAX
### Describe the issue:
Calling `pm.fit` with `fn_kwargs=dict(mode="JAX")` fails because variational inference (VI) does not support the JAX backend.
### Reproducible code example:
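```python
import pymc as pm

with pm.Model():  # minimal illustrative model with one free variable
    pm.Normal("x")
    pm.fit(fn_kwargs=dict(mode="JAX"))
```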
### Error message:
_No response_
### PyMC version information:
5.10.3
### Context for the issue:
_No response_