pymc

`Interpolated` not supported in JAX/Numba backends

Status: Open. FBruzzesi opened this issue 2 years ago (4 comments).

Describe the issue:

I am opening this issue to raise awareness, as I wasn't able to find anything about this error.

While trying to replicate the "Updating priors" notebook, I noticed that switching the NUTS sampler to numpyro raises the following error:

NotImplementedError: No JAX conversion for the given Op: SplineWrapper{spline=}

Reproducible code example:

from pymc import Model, sample

traces = []
for _ in range(10):
    ...
    with Model():
        ...  # model using Interpolated priors (details elided)
        trace = sample(1000, nuts_sampler="numpyro", nuts_sampler_kwargs={"chain_method": "parallel"})
        traces.append(trace)

Error message:

NotImplementedError: No JAX conversion for the given Op: SplineWrapper{spline=}

PyMC version information:

pymc version: 5.6.1
numpyro version: 0.12.1
jax version: 0.4.14

FBruzzesi, Jul 30 '23 08:07

🎉 Welcome to PyMC! 🎉 We're really excited to have your input into the project! 💖
If you haven't done so already, please make sure you check out our Contributing Guidelines and Code of Conduct.

welcome[bot], Jul 30 '23 08:07

Interpolated uses a custom Op (SplineWrapper) that doesn't yet have a JAX (or Numba) implementation.

ricardoV94, Jul 30 '23 11:07

Hello @FBruzzesi, did you find any alternative solution to this problem?

leotercas, Jul 11 '24 07:07

Hey @leotercas, not really, no 😭

FBruzzesi, Jul 12 '24 14:07