jax
jax-metal: XlaRuntimeError: INTERNAL: Unable to serialize MPS module
Description
Encountered an XlaRuntimeError while running a basic NumPyro program using jax-metal. The issue arises when I try to run the MCMC sampler.
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from jax import random
import jax.numpy as np
J = 8
y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
def eight_schools(J, sigma, y=None):
    mu = numpyro.sample('mu', dist.Normal(0, 5))
    tau = numpyro.sample('tau', dist.HalfCauchy(5))
    with numpyro.plate('J', J):
        theta = numpyro.sample('theta', dist.Normal(mu, tau))
        numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)
nuts_kernel = NUTS(eight_schools)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, J, sigma, y=y, extra_fields=('potential_energy',))
Running this gives me the following error:
2024-03-22 13:27:18.342904: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M3 Max
systemMemory: 64.00 GB
maxCacheSize: 24.00 GB
0%| | 0/1500 [00:00<?, ?it/s]
Traceback (most recent call last):
File "[...]/test.py", line 23, in <module>
mcmc.run(rng_key, J, sigma, y=y, extra_fields=('potential_energy',))
File "[...]/numpyro/numpyro/infer/mcmc.py", line 644, in run
states_flat, last_state = partial_map_fn(map_args)
^^^^^^^^^^^^^^^^^^^^^^^^
File "[...]/numpyro/numpyro/infer/mcmc.py", line 450, in _single_chain_mcmc
collect_vals = fori_collect(
^^^^^^^^^^^^^
File "[...]/numpyro/numpyro/util.py", line 367, in fori_collect
vals = jit(_body_fn)(i, vals)
^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Unable to serialize MPS module
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
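The filtering mentioned in that message can be turned off either from the shell or from inside the script. A minimal sketch, assuming it sits at the very top of the script so the environment variable is set before anything imports JAX; the `jax.config.update` call is the equivalent runtime switch:

```python
import os

# Assumption: this line runs before `import jax` anywhere in the process,
# so JAX picks the value up when its configuration is first read.
os.environ["JAX_TRACEBACK_FILTERING"] = "off"

import jax

# Equivalent switch through JAX's config API; this one also works after import.
jax.config.update("jax_traceback_filtering", "off")
```

With filtering off, the traceback keeps JAX's internal frames (pjit, pxla, compiler), which shows whether a failure happens during tracing, lowering or backend compilation.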
System info (python version, jaxlib version, accelerator, etc.)
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-03-22 13:36:51.819062: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M3 Max
systemMemory: 64.00 GB
maxCacheSize: 24.00 GB
jax: 0.4.25
jaxlib: 0.4.23
numpy: 1.26.4
python: 3.11.8 (main, Feb 26 2024, 15:36:12) [Clang 14.0.6 ]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='[...]-MacBook-Pro.local', release='23.4.0', version='Darwin Kernel Version 23.4.0: Wed Feb 21 21:44:54 PST 2024; root:xnu-10063.101.15~2/RELEASE_ARM64_T6031', machine='arm64')
macOS: Sonoma Version 14.4
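The version block above appears to be the output of `jax.print_environment_info()`. A sketch for collecting the same report together with the Metal plugin version, which the report does not list; the distribution name `jax-metal` is an assumption based on how the plugin is usually installed:

```python
import importlib.metadata

import jax

# Returns the "jax: / jaxlib: / numpy: / python: / jax.devices ..." block
# as a string instead of printing it.
report = jax.print_environment_info(return_string=True)

# Assumption: the Metal plugin is installed as the "jax-metal" distribution.
try:
    metal_version = importlib.metadata.version("jax-metal")
except importlib.metadata.PackageNotFoundError:
    metal_version = "not installed"

report = report.rstrip("\n") + f"\njax-metal: {metal_version}"
print(report)
```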
Would you be able to reproduce the issue with a smaller module?
Hi, yes, I was able to. Here's a more stripped-down version; let me know if you need something smaller.
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from jax import random
def model():
    mu = numpyro.sample('mu', dist.Normal(0, 5))
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, extra_fields=('potential_energy',))
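The thread does not establish which operation inside NumPyro's jitted `_body_fn` the Metal plugin fails to serialize. As a hypothetical next reduction step, the sketch below drops NumPyro entirely and jit-compiles a few JAX building blocks (element-wise math, PRNG sampling, `fori_loop`, `while_loop`) on the default backend, reporting which ones raise. The `probe` helper and the choice of cases are illustrative assumptions, not part of the original report:

```python
import jax
import jax.numpy as jnp
from jax import lax, random


def probe(name, fn, *args):
    """Hypothetical helper: jit-compile and run `fn`, returning (name, status, error)."""
    try:
        out = jax.jit(fn)(*args)  # compilation errors are raised here
        out.block_until_ready()   # surfaces asynchronous runtime errors
        return name, "ok", None
    except Exception as e:  # e.g. XlaRuntimeError on a failing backend
        return name, "failed", f"{type(e).__name__}: {e}"


x0 = jnp.float32(0.0)
results = [
    probe("elementwise", lambda x: jnp.sin(x) * 2.0, jnp.arange(4.0)),
    probe("prng_normal", lambda k: random.normal(k, (3,)), random.PRNGKey(0)),
    probe("fori_loop", lambda x: lax.fori_loop(0, 10, lambda i, c: c + 1.0, x), x0),
    probe("while_loop", lambda x: lax.while_loop(lambda c: c < 10.0, lambda c: c + 1.0, x), x0),
]

for name, status, err in results:
    print(f"{jax.default_backend():>6} {name:<12} {status}", err or "")
```

If every case passes on METAL, the failing module is likely something more specific to the sampler, and the next step would be bisecting the NUTS step itself.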
I set JAX_TRACEBACK_FILTERING=off and here's the error I get:
Unfiltered stack trace
Metal device set to: Apple M3 Max
systemMemory: 64.00 GB
maxCacheSize: 24.00 GB
0%| | 0/1500 [00:00<?, ?it/s]
Traceback (most recent call last):
File "[...]/pyro-svi/sample-numpyro.py", line 12, in <module>
mcmc.run(rng_key, extra_fields=('potential_energy',))
File "[...]/numpyro/numpyro/infer/mcmc.py", line 644, in run
states_flat, last_state = partial_map_fn(map_args)
^^^^^^^^^^^^^^^^^^^^^^^^
File "[...]/numpyro/numpyro/infer/mcmc.py", line 450, in _single_chain_mcmc
collect_vals = fori_collect(
^^^^^^^^^^^^^
File "[...]/numpyro/numpyro/util.py", line 367, in fori_collect
vals = jit(_body_fn)(i, vals)
^^^^^^^^^^^^^^^^^^^^^^
File "[...]/miniconda3/envs/numpyro-env/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "[...]/miniconda3/envs/numpyro-env/lib/python3.11/site-packages/jax/_src/pjit.py", line 248, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
^^^^^^^^^^^^^^^^^^^^
File "[...]/miniconda3/envs/numpyro-env/lib/python3.11/site-packages/jax/_src/pjit.py", line 143, in _python_pjit_helper
out_flat = pjit_p.bind(*args_flat, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "[...]/miniconda3/envs/numpyro-env/lib/python3.11/site-packages/jax/_src/core.py", line 2727, in bind
return self.bind_with_trace(top_trace, args, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "[...]/miniconda3/envs/numpyro-env/lib/python3.11/site-packages/jax/_src/core.py", line 423, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "[...]/miniconda3/envs/numpyro-env/lib/python3.11/site-packages/jax/_src/core.py", line 913, in process_primitive
return primitive.impl(*tracers, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "[...]/miniconda3/envs/numpyro-env/lib/python3.11/site-packages/jax/_src/pjit.py", line 1415, in _pjit_call_impl
return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums, # type: ignore
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "[...]/miniconda3/envs/numpyro-env/lib/python3.11/site-packages/jax/_src/pjit.py", line 1392, in call_impl_cache_miss
out_flat, compiled = _pjit_call_impl_python(
^^^^^^^^^^^^^^^^^^^^^^^
File "[...]/miniconda3/envs/numpyro-env/lib/python3.11/site-packages/jax/_src/pjit.py", line 1328, in _pjit_call_impl_python
lowering_parameters=mlir.LoweringParameters()).compile()
^^^^^^^^^
File "[...]/miniconda3/envs/numpyro-env/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2271, in compile
executable = UnloadedMeshExecutable.from_hlo(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "[...]/miniconda3/envs/numpyro-env/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2734, in from_hlo
xla_executable, compile_options = _cached_compilation(
^^^^^^^^^^^^^^^^^^^^
File "[...]/miniconda3/envs/numpyro-env/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2591, in _cached_compilation
xla_executable = compiler.compile_or_get_cached(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "[...]/miniconda3/envs/numpyro-env/lib/python3.11/site-packages/jax/_src/compiler.py", line 265, in compile_or_get_cached
return backend_compile(backend, computation, compile_options,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "[...]/miniconda3/envs/numpyro-env/lib/python3.11/site-packages/jax/_src/profiler.py", line 336, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "[...]/miniconda3/envs/numpyro-env/lib/python3.11/site-packages/jax/_src/compiler.py", line 237, in backend_compile
return backend.compile(built_c, compile_options=options)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Unable to serialize MPS module
+1, just ran into this same issue while trying to accelerate inference on my own Apple silicon. I don't know nearly enough to help debug, but I am seeing the same issue with:
jax 0.4.26
jaxlib 0.4.23
numpy 1.26.4
python 3.12.3
@dcalacci there's not much to debug, unfortunately: the Metal plugin is still experimental and very incomplete, so you should expect to run into these kinds of issues when using it. My recommendation would be to switch to non-experimental hardware.
Yeah, no problem! I understand this is very experimental. Just giving a +1 to the issue as another person playing with these new tools. Thanks for all your hard work!