nutpie
nutpie worker panic under jax backend
I'm hitting errors with the jax backend on a model that samples without error under the default backend. Switching to the jax backend via:
pm.sample(1000, tune=1000, chains=2, nuts_sampler='nutpie', nuts_sampler_kwargs={'backend': 'jax'}, random_seed=RANDOM_SEED)
results in the following panic in a nutpie thread:
thread 'nutpie-worker-1' panicked at /home/runner/.cargo/registry/src/index.crates.io-1949cf8c6b5b557f/nuts-rs-0.15.0/src/sampler.rs:635:18:
Could not send sampling results to main thread.: SendError { .. }
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[67], line 32
28 p = pm.Deterministic('p', pm.math.invlogit(logit_p), dims='obs')
30 y = pm.Binomial('y', n=pa, p=p, observed=hr)
---> 32 gp_covariate_trace = pm.sample(1000, tune=1000, chains=2, nuts_sampler='nutpie', nuts_sampler_kwargs={'backend': 'jax'}, random_seed=RANDOM_SEED)
File ~/repos/field_of_play/.pixi/envs/default/lib/python3.12/site-packages/pymc/sampling/mcmc.py:809, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, compile_kwargs, **kwargs)
804 raise ValueError(
805 "Model can not be sampled with NUTS alone. It either has discrete variables or a non-differentiable log-probability."
806 )
808 with joined_blas_limiter():
--> 809 return _sample_external_nuts(
810 sampler=nuts_sampler,
811 draws=draws,
812 tune=tune,
813 chains=chains,
814 target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
815 random_seed=random_seed,
816 initvals=initvals,
817 model=model,
818 var_names=var_names,
819 progressbar=progress_bool,
820 idata_kwargs=idata_kwargs,
821 compute_convergence_checks=compute_convergence_checks,
822 nuts_sampler_kwargs=nuts_sampler_kwargs,
823 **kwargs,
824 )
826 if exclusive_nuts and not provided_steps:
827 # Special path for NUTS initialization
828 if "nuts" in kwargs:
File ~/repos/field_of_play/.pixi/envs/default/lib/python3.12/site-packages/pymc/sampling/mcmc.py:349, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, idata_kwargs, compute_convergence_checks, nuts_sampler_kwargs, **kwargs)
344 compiled_model = nutpie.compile_pymc_model(
345 model,
346 **compile_kwargs,
347 )
348 t_start = time.time()
--> 349 idata = nutpie.sample(
350 compiled_model,
351 draws=draws,
352 tune=tune,
353 chains=chains,
354 target_accept=target_accept,
355 seed=_get_seeds_per_chain(random_seed, 1)[0],
356 progress_bar=progressbar,
357 **nuts_sampler_kwargs,
358 )
359 t_sample = time.time() - t_start
360 # Temporary work-around. Revert once https://github.com/pymc-devs/nutpie/issues/74 is fixed
361 # gather observed and constant data as nutpie.sample() has no access to the PyMC model
File ~/repos/field_of_play/.pixi/envs/default/lib/python3.12/site-packages/nutpie/sample.py:654, in sample(compiled_model, draws, tune, chains, cores, seed, save_warmup, progress_bar, low_rank_modified_mass_matrix, transform_adapt, init_mean, return_raw_trace, blocking, progress_template, progress_style, progress_rate, **kwargs)
651 return sampler
653 try:
--> 654 result = sampler.wait()
655 except KeyboardInterrupt:
656 result = sampler.abort()
File ~/repos/field_of_play/.pixi/envs/default/lib/python3.12/site-packages/nutpie/sample.py:388, in _BackgroundSampler.wait(self, timeout)
378 def wait(self, *, timeout=None):
379 """Wait until sampling is finished and return the trace.
380
381 KeyboardInterrupt will lead to interrupt the waiting.
(...)
386 This resumes the sampler in case it had been paused.
387 """
--> 388 self._sampler.wait(timeout)
389 results = self._sampler.extract_results()
390 return self._extract(results)
RuntimeError: All initialization points failed
Caused by:
Logp function returned error: PyError(PyErr { type: <class 'AttributeError'>, value: AttributeError("module 'jax.lax' has no attribute 'mul_without_zeros'"), traceback: Some(<traceback object at 0x7f813ad75080>) })
Running on the following environment:
Python implementation: CPython
Python version : 3.12.8
IPython version : 8.32.0
numpy : 1.26.4
scipy : 1.12.0
pymc : 5.20.1
preliz : 0.15.0
nutpie : 0.14.2
pandas : 2.2.3
pytensor : 2.27.1
matplotlib: 3.10.0
plotly : 6.0.0
polars : 1.24.0
arviz : 0.20.0
Sounds like pytensor might be doing something strange here? Maybe somewhat related to this? https://github.com/pymc-devs/pytensor/issues/526
Does this also happen with gradient_backend="jax"?
Yeah, that fixes it. Is there any scenario where you'd want both backends not to be 'jax'?
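For anyone landing here later, a minimal sketch of the workaround: set both backends explicitly. The helper name `jax_backend_kwargs` is hypothetical, and it assumes `pm.sample` forwards `backend` and `gradient_backend` from `nuts_sampler_kwargs` to `nutpie.compile_pymc_model` (the `**compile_kwargs` in the traceback above suggests it does). The two commented calls at the bottom are illustrative only and need a model in scope.

```python
# Hypothetical helper: build the keyword arguments that select the jax
# logp backend together with an explicit gradient backend.
VALID_GRADIENT_BACKENDS = ("jax", "pytensor")


def jax_backend_kwargs(gradient_backend="jax"):
    """Return kwargs selecting backend='jax' and the given gradient backend."""
    if gradient_backend not in VALID_GRADIENT_BACKENDS:
        raise ValueError(
            f"gradient_backend must be one of {VALID_GRADIENT_BACKENDS}, "
            f"got {gradient_backend!r}"
        )
    return {"backend": "jax", "gradient_backend": gradient_backend}


kwargs = jax_backend_kwargs()

# Illustrative usage (requires pymc, nutpie and a model):
# compiled_model = nutpie.compile_pymc_model(model, **kwargs)
# idata = pm.sample(1000, tune=1000, chains=2, nuts_sampler="nutpie",
#                   nuts_sampler_kwargs=kwargs)
```

Passing `jax_backend_kwargs("pytensor")` instead keeps the jax logp but takes the gradient from pytensor, which is the combination discussed in the reply below.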
I've seen several examples where the pytensor gradient was significantly faster. JAX also sometimes has extra NaN issues that pytensor avoids with rewrites (https://docs.jax.dev/en/latest/faq.html#gradients-contain-nan-where-using-where). If you can figure out a reproducer, a pytensor issue would be nice :-)
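To make the NaN point concrete, here is a small sketch of the pattern described in the linked JAX FAQ entry. It is not taken from the model in this issue; the function names `naive_log` and `safe_log` are made up for illustration.

```python
import jax
import jax.numpy as jnp


def naive_log(x):
    # The untaken branch jnp.log(x) is still differentiated; at x == 0 its
    # derivative is inf, and 0 * inf in the backward pass gives nan.
    return jnp.where(x > 0.0, jnp.log(x), 0.0)


def safe_log(x):
    # "Double where": feed a harmless value into log when x <= 0, so the
    # untaken branch has a finite derivative.
    safe_x = jnp.where(x > 0.0, x, 1.0)
    return jnp.where(x > 0.0, jnp.log(safe_x), 0.0)


naive_grad = jax.grad(naive_log)(0.0)
safe_grad = jax.grad(safe_log)(0.0)

print("naive:", naive_grad)
print("safe: ", safe_grad)
```

Both functions return the same values in the forward pass; only the gradient at the boundary differs. This is the kind of case where the two gradient backends can disagree.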
Not sure if this is the same issue, or just the same error; certainly changing the gradient backend to jax doesn't solve this one:
import numpy as np
import pymc as pm
import pytensor.tensor as pt
import nutpie
import os
def create_array_with_random_nans(shape, nan_fraction=0.1):
    if isinstance(shape, int):
        shape = (shape,)
    n_total = np.prod(shape)
    n_nan = int(n_total * nan_fraction)
    arr = np.random.rand(n_total)
    nan_indices = np.random.choice(n_total, n_nan, replace=False)
    arr[nan_indices] = np.nan
    return arr.reshape(shape)
n_predictors = 3
n_cases = 100
x_np = create_array_with_random_nans((n_cases,n_predictors), nan_fraction=0.1)
y_np = create_array_with_random_nans(n_cases, nan_fraction=0.1)
with pm.Model() as model:
    mu = pm.Normal("mu", mu=0, sigma=1, shape=n_predictors)
    chol, corr, sds = pm.LKJCholeskyCov(name = "chol", n = n_predictors, eta = 1, sd_dist = pm.Weibull.dist(2,1))
    x = pm.MvNormal("x", mu = mu, chol = chol, observed = x_np)
    x_Q, x_R = pt.nlinalg.qr(x)
    beta = pm.Normal("beta", mu=0, sigma=1, shape=n_predictors)
    sigma = pm.HalfNormal("sigma", sigma=1)
    y = pm.Normal(
        "y"
        # , mu = pm.math.dot(x, beta) # This samples ok
        , mu = pm.math.dot(x_Q, beta) # This doesn't sample
        , sigma = sigma
        , observed = y_np
    )
print('Compiling...')
compiled_model = nutpie.compile_pymc_model(
    model
    , backend = 'jax'
    , gradient_backend = 'jax'
)
# Warning during compilation:
'''
/home/mike/.pyenv/versions/3.13.2/lib/python3.13/functools.py:934: UserWarning: Skipping `CheckAndRaise` Op (assertion: Could not broadcast dimensions. Broadcasting is only allowed along axes that have a statically known length 1. Use `specify_broadcastable` to inform PyTensor of a known shape.) as JAX tracing would remove it.
return dispatch(args[0].__class__)(*args, **kw)
'''
# set the RUST_BACKTRACE environment variable to 'full' to see the complete backtrace
os.environ['RUST_BACKTRACE'] = 'full'
print('Sampling...')
trace = nutpie.sample(compiled_model)
# error during sampling:
'''
Traceback (most recent call last):
File "REDACTED", line 51, in <module>
trace = nutpie.sample(compiled_model)
File "/home/mike/.pyenv/versions/3.13.2/lib/python3.13/site-packages/nutpie/sample.py", line 654, in sample
result = sampler.wait()
File "/home/mike/.pyenv/versions/3.13.2/lib/python3.13/site-packages/nutpie/sample.py", line 388, in wait
self._sampler.wait(timeout)
~~~~~~~~~~~~~~~~~~^^^^^^^^^
RuntimeError: All initialization points failed
Caused by:
Logp function returned error: PyError(PyErr { type: <class 'TypeError'>, value: TypeError('Shapes must be 1D sequences of concrete values of integer type, got (3, Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace>).\nIf using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.\nThe error occurred while tracing the function jax_funcified_fgraph at /tmp/tmpqsl222qm:1 for jit. This value became a tracer due to JAX operations on these lines:\n\n operation a:bool[] = lt b c\n from line /tmp/tmpqsl222qm:275:26 (jax_funcified_fgraph)\n\n operation a:i64[] = pjit[\n name=_where\n jaxpr={ lambda ; b:bool[] c:i64[] d:i64[]. let\n e:i64[] = select_n b d c\n in (e,) }\n] f g h\n from line /tmp/tmpqsl222qm:277:26 (jax_funcified_fgraph)\n\n operation a:i64[] = sub b c\n from line /tmp/tmpqsl222qm:279:26 (jax_funcified_fgraph)'), traceback: Some(<traceback object at 0x7eb6a82a3580>) })
Stack backtrace:
0: <unknown>
1: <unknown>
2: <unknown>
3: <unknown>
4: <unknown>
5: <unknown>
6: <unknown>
7: start_thread
at ./nptl/pthread_create.c:447:8
8: clone3
at ./misc/../sysdeps/unix/sysv/linux/x86_64/clone3.S:78
thread 'nutpie-worker-2' panicked at /home/runner/.cargo/registry/src/index.crates.io-1949cf8c6b5b557f/nuts-rs-0.15.0/src/sampler.rs:635:18:
Could not send sampling results to main thread.: SendError { .. }
stack backtrace:
0: 0x7eb7651553e9 - <unknown>
1: 0x7eb7650930e3 - <unknown>
2: 0x7eb765154982 - <unknown>
3: 0x7eb765155243 - <unknown>
4: 0x7eb7651545f6 - <unknown>
5: 0x7eb76518d8f8 - <unknown>
6: 0x7eb76518d859 - <unknown>
7: 0x7eb76518e00c - <unknown>
8: 0x7eb765091e2f - <unknown>
9: 0x7eb765097845 - <unknown>
10: 0x7eb764f7f2c3 - <unknown>
11: 0x7eb76514a473 - <unknown>
12: 0x7eb765149a6c - <unknown>
13: 0x7eb76514c57e - <unknown>
14: 0x7eb76518e83b - <unknown>
15: 0x7eb7d329caa4 - start_thread
at ./nptl/pthread_create.c:447:8
16: 0x7eb7d3329c3c - clone3
at ./misc/../sysdeps/unix/sysv/linux/x86_64/clone3.S:78
thread 'nutpie-worker-5' panicked at /home/runner/.cargo/registry/src/index.crates.io-1949cf8c6b5b557f/nuts-rs-0.15.0/src/sampler.rs:635:18:
Could not send sampling results to main thread.: SendError { .. }
stack backtrace:
0: 0x7eb7651553e9 - <unknown>
1: 0x7eb7650930e3 - <unknown>
2: 0x7eb765154982 - <unknown>
3: 0x7eb765155243 - <unknown>
4: 0x7eb7651545f6 - <unknown>
5: 0x7eb76518d8f8 - <unknown>
6: 0x7eb76518d859 - <unknown>
7: 0x7eb76518e00c - <unknown>
8: 0x7eb765091e2f - <unknown>
9: 0x7eb765097845 - <unknown>
10: 0x7eb764f7f2c3 - <unknown>
11: 0x7eb76514a473 - <unknown>
12: 0x7eb765149a6c - <unknown>
13: 0x7eb76514c57e - <unknown>
14: 0x7eb76518e83b - <unknown>
15: 0x7eb7d329caa4 - start_thread
at ./nptl/pthread_create.c:447:8
16: 0x7eb7d3329c3c - clone3
at ./misc/../sysdeps/unix/sysv/linux/x86_64/clone3.S:78
thread 'nutpie-worker-0' panicked at /home/runner/.cargo/registry/src/index.crates.io-1949cf8c6b5b557f/nuts-rs-0.15.0/src/sampler.rs:635:18:
Could not send sampling results to main thread.: SendError { .. }
stack backtrace:
0: 0x7eb7651553e9 - <unknown>
1: 0x7eb7650930e3 - <unknown>
2: 0x7eb765154982 - <unknown>
3: 0x7eb765155243 - <unknown>
4: 0x7eb7651545f6 - <unknown>
5: 0x7eb76518d8f8 - <unknown>
6: 0x7eb76518d859 - <unknown>
7: 0x7eb76518e00c - <unknown>
8: 0x7eb765091e2f - <unknown>
9: 0x7eb765097845 - <unknown>
10: 0x7eb764f7f2c3 - <unknown>
11: 0x7eb76514a473 - <unknown>
12: 0x7eb765149a6c - <unknown>
13: 0x7eb76514c57e - <unknown>
14: 0x7eb76518e83b - <unknown>
15: 0x7eb7d329caa4 - start_thread
at ./nptl/pthread_create.c:447:8
16: 0x7eb7d3329c3c - clone3
at ./misc/../sysdeps/unix/sysv/linux/x86_64/clone3.S:78
thread 'nutpie-worker-6' panicked at /home/runner/.cargo/registry/src/index.crates.io-1949cf8c6b5b557f/nuts-rs-0.15.0/src/sampler.rs:635:18:
Could not send sampling results to main thread.: SendError { .. }
stack backtrace:
0: 0x7eb7651553e9 - <unknown>
1: 0x7eb7650930e3 - <unknown>
2: 0x7eb765154982 - <unknown>
3: 0x7eb765155243 - <unknown>
4: 0x7eb7651545f6 - <unknown>
5: 0x7eb76518d8f8 - <unknown>
6: 0x7eb76518d859 - <unknown>
7: 0x7eb76518e00c - <unknown>
8: 0x7eb765091e2f - <unknown>
9: 0x7eb765097845 - <unknown>
10: 0x7eb764f7f2c3 - <unknown>
11: 0x7eb76514a473 - <unknown>
12: 0x7eb765149a6c - <unknown>
13: 0x7eb76514c57e - <unknown>
14: 0x7eb76518e83b - <unknown>
15: 0x7eb7d329caa4 - start_thread
at ./nptl/pthread_create.c:447:8
16: 0x7eb7d3329c3c - clone3
at ./misc/../sysdeps/unix/sysv/linux/x86_64/clone3.S:78
thread 'nutpie-worker-1' panicked at /home/runner/.cargo/registry/src/index.crates.io-1949cf8c6b5b557f/nuts-rs-0.15.0/src/sampler.rs:635:18:
Could not send sampling results to main thread.: SendError { .. }
stack backtrace:
0: 0x7eb7651553e9 - <unknown>
1: 0x7eb7650930e3 - <unknown>
2: 0x7eb765154982 - <unknown>
3: 0x7eb765155243 - <unknown>
4: 0x7eb7651545f6 - <unknown>
5: 0x7eb76518d8f8 - <unknown>
6: 0x7eb76518d859 - <unknown>
7: 0x7eb76518e00c - <unknown>
8: 0x7eb765091e2f - <unknown>
9: 0x7eb765097845 - <unknown>
10: 0x7eb764f7f2c3 - <unknown>
11: 0x7eb76514a473 - <unknown>
12: 0x7eb765149a6c - <unknown>
13: 0x7eb76514c57e - <unknown>
14: 0x7eb76518e83b - <unknown>
15: 0x7eb7d329caa4 - start_thread
at ./nptl/pthread_create.c:447:8
16: 0x7eb7d3329c3c - clone3
at ./misc/../sysdeps/unix/sysv/linux/x86_64/clone3.S:78
'''
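The `TypeError` in the log above is a generic JAX tracing error, and the same class of failure can be reproduced without PyMC. The sketch below is only an illustration of that error class under `jax.jit`; it does not claim to be what the generated `jax_funcified_fgraph` does internally, and `make_zeros` is a made-up function.

```python
import jax
import jax.numpy as jnp


def make_zeros(n):
    # Under jit, n is a tracer, so it cannot be used as an array dimension.
    return jnp.zeros((3, n))


try:
    jax.jit(make_zeros)(2)
    traced_shape_error = None
except TypeError as err:
    # "Shapes must be 1D sequences of concrete values of integer type, ..."
    traced_shape_error = err

# Marking the argument as static makes the shape concrete at trace time.
static_zeros = jax.jit(make_zeros, static_argnums=0)(2)

print(type(traced_shape_error).__name__)
print(static_zeros.shape)
```

In the model above the shape value comes out of a comparison, a `where` and a subtraction (per the error message), so it only becomes known at run time, which JAX cannot trace into a static shape.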