MLIR translation rule for primitive 'celerite2_factor' not found for platform cuda
Running the following test script from the docs:
from jax import config
config.update("jax_enable_x64", True)
from jax import random
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
import celerite2.jax
from celerite2.jax import terms as jax_terms
import numpy as np
# Seed NumPy's global RNG so the simulated dataset is reproducible.
np.random.seed(42)


def _signal(x):
    """Noise-free curve: a linear trend plus a sinusoid with a quadratic phase term."""
    return 0.2 * (x - 5) + np.sin(3 * x + 0.1 * (x - 5) ** 2)


# Two separate batches of uniformly drawn times (57 points in [0, 3.8) and
# 68 points in [5.5, 10)), merged and sorted into one coordinate array.
t = np.sort(
    np.concatenate(
        [
            np.random.uniform(0, 3.8, 57),
            np.random.uniform(5.5, 10, 68),
        ]
    )
)  # The input coordinates must be sorted

# Per-point uncertainties, then observations = signal + scaled Gaussian noise.
yerr = np.random.uniform(0.08, 0.22, t.size)
y = _signal(t) + yerr * np.random.randn(t.size)

# Dense grid on which the noise-free curve is evaluated.
true_t = np.linspace(0, 10, 500)
true_y = _signal(true_t)

# Standard deviation shared by the Normal priors used in the model.
prior_sigma = 2.0
def numpyro_model(t, yerr, y=None):
    """NumPyro model: a celerite2 GP whose kernel is the sum of two SHO terms.

    Args:
        t: Input coordinates handed to ``gp.compute``. Sorting is not
            verified here (``check_sorted=False``).
        yerr: Per-point uncertainties; ``yerr**2 + exp(log_jitter)`` is used
            as the diagonal passed to ``gp.compute``.
        y: Observed values for the ``"obs"`` site. Defaults to ``None``.

    Reads ``prior_sigma`` and ``omega`` from the enclosing module scope.
    NOTE(review): ``omega`` is not defined in the visible part of the
    script -- confirm it is set before the model is traced.
    """

    def draw(site):
        # Every sampled parameter shares the same Normal(0, prior_sigma) prior.
        return numpyro.sample(site, dist.Normal(0.0, prior_sigma))

    # Site order is kept exactly as in the original script.
    mean = draw("mean")
    log_jitter = draw("log_jitter")

    # First SHO term: parametrised by sigma, rho and tau.
    first_term = jax_terms.SHOTerm(
        sigma=jnp.exp(draw("log_sigma1")),
        rho=jnp.exp(draw("log_rho1")),
        tau=jnp.exp(draw("log_tau")),
    )

    # Second SHO term: sigma and rho are sampled, Q is fixed at 0.25.
    second_term = jax_terms.SHOTerm(
        sigma=jnp.exp(draw("log_sigma2")),
        rho=jnp.exp(draw("log_rho2")),
        Q=0.25,
    )

    kernel = first_term + second_term
    gp = celerite2.jax.GaussianProcess(kernel, mean=mean)

    # This call binds the 'celerite2_factor' primitive; per the traceback it
    # has no lowering rule for the cuda platform, so it raises
    # NotImplementedError when JAX runs on a GPU backend.
    diagonal = yerr**2 + jnp.exp(log_jitter)
    gp.compute(t, diag=diagonal, check_sorted=False)

    numpyro.sample("obs", gp.numpyro_dist(), obs=y)
    numpyro.deterministic("psd", kernel.get_psd(omega))
# PRNG key that seeds the MCMC run.
rng_key = random.PRNGKey(34923)

# NUTS sampler over the model above, with dense_mass enabled.
nuts_kernel = NUTS(numpyro_model, dense_mass=True)

# Two chains are requested; the logged UserWarning shows that with a single
# available device the chains are drawn sequentially.
mcmc = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000, num_chains=2, progress_bar=False)

# Tracing the model during sampler initialisation is where the reported
# NotImplementedError surfaces.
mcmc.run(rng_key, t, yerr, y=y)
Produces the following error:
Traceback (most recent call last):
File "/computefs/scratch/username/mypackage/notebooks/temp/test.py", line 16, in <module>
import matplotlib.pyplot as plt
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/matplotlib/__init__.py", line 264, in <module>
_check_versions()
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/matplotlib/__init__.py", line 258, in _check_versions
module = importlib.import_module(modname)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/computefs/ixsoftware/python/3.12.6/install/lib/python3.12/importlib/__init__.py", line 90, in import_module
return _bootstrap._gcd_import(name[level:], package, level)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ModuleNotFoundError: No module named 'dateutil'
(.venv_cuda) [username@node10 temp]$ python test.py
Traceback (most recent call last):
File "/computefs/scratch/username/mypackage/notebooks/temp/test.py", line 16, in <module>
import matplotlib.pyplot as plt
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/matplotlib/__init__.py", line 264, in <module>
_check_versions()
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/matplotlib/__init__.py", line 258, in _check_versions
module = importlib.import_module(modname)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/computefs/ixsoftware/python/3.12.6/install/lib/python3.12/importlib/__init__.py", line 90, in import_module
return _bootstrap._gcd_import(name[level:], package, level)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ModuleNotFoundError: No module named 'dateutil'
(.venv_cuda) [username@node10 temp]$
(.venv_cuda) [username@node10 temp]$ python test.py
/computefs/scratch/username/mypackage/notebooks/temp/test.py:64: UserWarning: There are not enough devices to run parallel chains: expected 2 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(2)` at the beginning of your program. You can double-check how many devices are available in your system using `jax.local_device_count()`.
mcmc = MCMC(
Traceback (most recent call last):
File "/computefs/scratch/username/mypackage/notebooks/temp/test.py", line 72, in <module>
mcmc.run(rng_key, t, yerr, y=y)
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/infer/mcmc.py", line 706, in run
states, last_state = _laxmap(partial_map_fn, map_args)
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/infer/mcmc.py", line 177, in _laxmap
ys.append(f(x))
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/infer/mcmc.py", line 465, in _single_chain_mcmc
new_init_state = self.sampler.init(
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/infer/hmc.py", line 749, in init
init_params = self._init_state(
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/infer/hmc.py", line 693, in _init_state
) = initialize_model(
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/infer/util.py", line 688, in initialize_model
) = _get_model_transforms(substituted_model, model_args, model_kwargs)
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/infer/util.py", line 482, in _get_model_transforms
model_trace = trace(model).get_trace(*model_args, **model_kwargs)
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/handlers.py", line 191, in get_trace
self(*args, **kwargs)
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/primitives.py", line 121, in __call__
return self.fn(*args, **kwargs)
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/primitives.py", line 121, in __call__
return self.fn(*args, **kwargs)
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/primitives.py", line 121, in __call__
return self.fn(*args, **kwargs)
File "/computefs/scratch/username/mypackage/notebooks/temp/test.py", line 57, in numpyro_model
gp.compute(t, diag=yerr**2 + jnp.exp(log_jitter), check_sorted=False)
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/celerite2/core.py", line 317, in compute
self._do_compute(quiet)
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/celerite2/jax/celerite2.py", line 34, in _do_compute
self._d, self._W = ops.factor(
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/celerite2/jax/ops.py", line 39, in factor
d, W, S = factor_p.bind(t, c, a, U, V)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: MLIR translation rule for primitive 'celerite2_factor' not found for platform cuda
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/computefs/scratch/username/mypackage/notebooks/temp/test.py", line 72, in <module>
mcmc.run(rng_key, t, yerr, y=y)
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/infer/mcmc.py", line 706, in run
states, last_state = _laxmap(partial_map_fn, map_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/infer/mcmc.py", line 177, in _laxmap
ys.append(f(x))
^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/infer/mcmc.py", line 465, in _single_chain_mcmc
new_init_state = self.sampler.init(
^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/infer/hmc.py", line 749, in init
init_params = self._init_state(
^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/infer/hmc.py", line 693, in _init_state
) = initialize_model(
^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/infer/util.py", line 688, in initialize_model
) = _get_model_transforms(substituted_model, model_args, model_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/infer/util.py", line 482, in _get_model_transforms
model_trace = trace(model).get_trace(*model_args, **model_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/handlers.py", line 191, in get_trace
self(*args, **kwargs)
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/primitives.py", line 121, in __call__
return self.fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/primitives.py", line 121, in __call__
return self.fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/primitives.py", line 121, in __call__
return self.fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/test.py", line 57, in numpyro_model
gp.compute(t, diag=yerr**2 + jnp.exp(log_jitter), check_sorted=False)
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/celerite2/core.py", line 317, in compute
self._do_compute(quiet)
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/celerite2/jax/celerite2.py", line 34, in _do_compute
self._d, self._W = ops.factor(
^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/celerite2/jax/ops.py", line 39, in factor
d, W, S = factor_p.bind(t, c, a, U, V)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/core.py", line 438, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/core.py", line 442, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/core.py", line 948, in process_primitive
return primitive.impl(*tracers, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/dispatch.py", line 90, in apply_primitive
outs = fun(*args)
^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/pjit.py", line 356, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/pjit.py", line 189, in _python_pjit_helper
out_flat = pjit_p.bind(*args_flat, **p.params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/core.py", line 2781, in bind
return self.bind_with_trace(top_trace, args, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/core.py", line 442, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/core.py", line 948, in process_primitive
return primitive.impl(*tracers, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/pjit.py", line 1764, in _pjit_call_impl
return xc._xla.pjit(
^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/pjit.py", line 1739, in call_impl_cache_miss
out_flat, compiled = _pjit_call_impl_python(
^^^^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/pjit.py", line 1661, in _pjit_call_impl_python
compiled = _resolve_and_lower(
^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/pjit.py", line 1628, in _resolve_and_lower
lowered = _pjit_lower(
^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/pjit.py", line 1780, in _pjit_lower
return _pjit_lower_cached(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/pjit.py", line 1801, in _pjit_lower_cached
return pxla.lower_sharding_computation(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/profiler.py", line 333, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 2232, in lower_sharding_computation
nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
^^^^^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 1952, in _cached_lowering_to_hlo
lowering_result = mlir.lower_jaxpr_to_module(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1152, in lower_jaxpr_to_module
lower_jaxpr_to_fun(
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1610, in lower_jaxpr_to_fun
out_vals, tokens_out = jaxpr_subcomp(
^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1825, in jaxpr_subcomp
ans = lower_per_platform(rule_ctx, str(eqn.primitive),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1914, in lower_per_platform
raise NotImplementedError(
NotImplementedError: MLIR translation rule for primitive 'celerite2_factor' not found for platform cuda
$pip freeze
celerite2==0.3.2
contourpy==1.3.1
cycler==0.12.1
fonttools==4.56.0
jax==0.4.34
jax-cuda12-pjrt==0.4.34
jax-cuda12-plugin==0.4.34
jaxlib==0.4.34
jaxopt==0.8.3
kiwisolver==1.4.8
matplotlib==3.10.0
ml_dtypes==0.5.1
multipledispatch==1.0.0
numpy==2.2.3
numpyro==0.17.0
nvidia-cublas-cu12==12.8.3.14
nvidia-cuda-cupti-cu12==12.8.57
nvidia-cuda-nvcc-cu12==12.8.61
nvidia-cuda-runtime-cu12==12.8.57
nvidia-cudnn-cu12==9.7.1.26
nvidia-cufft-cu12==11.3.3.41
nvidia-cusolver-cu12==11.7.2.55
nvidia-cusparse-cu12==12.5.7.53
nvidia-nccl-cu12==2.25.1
nvidia-nvjitlink-cu12==12.8.61
opt_einsum==3.4.0
pillow==11.1.0
pyparsing==3.2.1
scipy==1.15.2
setuptools==75.8.0
tqdm==4.67.1
Hi! This is working as expected. celerite2.jax doesn't support running on a GPU; it only has CPU implementations.
Thanks @dfm - would you expect there to be a benefit to running on the GPU for certain datasets? And is there any plan to implement a GPU-enabled version....😬?
No and no. The algorithms used in celerite are sequential by nature, and they are not suitable for hardware acceleration. You can run the algorithms on a GPU, but they'll be orders of magnitude slower than running on a CPU! It's possible that you could write a version using an associative scan that could be appropriately parallelized, but that would be a research project, and there are serious numerical issues with the naive implementations.
I'll note that you can run the implementation of these algorithms in the tinygp package on GPU, but it's crazy slow.
OK - thanks @dfm. Not what I was hoping to hear! But good to know.
Can you say more about what you were hoping to achieve here? Why do you need to run celerite2 on GPU? Is it because it's part of a larger model that benefits from using a GPU, or were you hoping for some performance improvement compared to CPU?