nutpie worker panic under jax backend

Open: fonnesbeck opened this issue 9 months ago • 4 comments

I'm running into problems with the jax backend on a model that samples without error under the default backend. Switching to the jax backend via:

pm.sample(1000, tune=1000, chains=2, nuts_sampler='nutpie', nuts_sampler_kwargs={'backend': 'jax'}, random_seed=RANDOM_SEED)

results in the following panic in a nutpie thread:

thread 'nutpie-worker-1' panicked at /home/runner/.cargo/registry/src/index.crates.io-1949cf8c6b5b557f/nuts-rs-0.15.0/src/sampler.rs:635:18:
Could not send sampling results to main thread.: SendError { .. }
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[67], line 32
     28 p = pm.Deterministic('p', pm.math.invlogit(logit_p), dims='obs')
     30 y = pm.Binomial('y', n=pa, p=p, observed=hr)
---> 32 gp_covariate_trace = pm.sample(1000, tune=1000, chains=2, nuts_sampler='nutpie', nuts_sampler_kwargs={'backend': 'jax'}, random_seed=RANDOM_SEED)

File ~/repos/field_of_play/.pixi/envs/default/lib/python3.12/site-packages/pymc/sampling/mcmc.py:809, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, compile_kwargs, **kwargs)
    804         raise ValueError(
    805             "Model can not be sampled with NUTS alone. It either has discrete variables or a non-differentiable log-probability."
    806         )
    808     with joined_blas_limiter():
--> 809         return _sample_external_nuts(
    810             sampler=nuts_sampler,
    811             draws=draws,
    812             tune=tune,
    813             chains=chains,
    814             target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
    815             random_seed=random_seed,
    816             initvals=initvals,
    817             model=model,
    818             var_names=var_names,
    819             progressbar=progress_bool,
    820             idata_kwargs=idata_kwargs,
    821             compute_convergence_checks=compute_convergence_checks,
    822             nuts_sampler_kwargs=nuts_sampler_kwargs,
    823             **kwargs,
    824         )
    826 if exclusive_nuts and not provided_steps:
    827     # Special path for NUTS initialization
    828     if "nuts" in kwargs:

File ~/repos/field_of_play/.pixi/envs/default/lib/python3.12/site-packages/pymc/sampling/mcmc.py:349, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, idata_kwargs, compute_convergence_checks, nuts_sampler_kwargs, **kwargs)
    344 compiled_model = nutpie.compile_pymc_model(
    345     model,
    346     **compile_kwargs,
    347 )
    348 t_start = time.time()
--> 349 idata = nutpie.sample(
    350     compiled_model,
    351     draws=draws,
    352     tune=tune,
    353     chains=chains,
    354     target_accept=target_accept,
    355     seed=_get_seeds_per_chain(random_seed, 1)[0],
    356     progress_bar=progressbar,
    357     **nuts_sampler_kwargs,
    358 )
    359 t_sample = time.time() - t_start
    360 # Temporary work-around. Revert once https://github.com/pymc-devs/nutpie/issues/74 is fixed
    361 # gather observed and constant data as nutpie.sample() has no access to the PyMC model

File ~/repos/field_of_play/.pixi/envs/default/lib/python3.12/site-packages/nutpie/sample.py:654, in sample(compiled_model, draws, tune, chains, cores, seed, save_warmup, progress_bar, low_rank_modified_mass_matrix, transform_adapt, init_mean, return_raw_trace, blocking, progress_template, progress_style, progress_rate, **kwargs)
    651     return sampler
    653 try:
--> 654     result = sampler.wait()
    655 except KeyboardInterrupt:
    656     result = sampler.abort()

File ~/repos/field_of_play/.pixi/envs/default/lib/python3.12/site-packages/nutpie/sample.py:388, in _BackgroundSampler.wait(self, timeout)
    378 def wait(self, *, timeout=None):
    379     """Wait until sampling is finished and return the trace.
    380 
    381     KeyboardInterrupt will lead to interrupt the waiting.
   (...)
    386     This resumes the sampler in case it had been paused.
    387     """
--> 388     self._sampler.wait(timeout)
    389     results = self._sampler.extract_results()
    390     return self._extract(results)

RuntimeError: All initialization points failed

Caused by:
    Logp function returned error: PyError(PyErr { type: <class 'AttributeError'>, value: AttributeError("module 'jax.lax' has no attribute 'mul_without_zeros'"), traceback: Some(<traceback object at 0x7f813ad75080>) })

Running on the following environment:

Python implementation: CPython
Python version       : 3.12.8
IPython version      : 8.32.0

numpy     : 1.26.4
scipy     : 1.12.0
pymc      : 5.20.1
preliz    : 0.15.0
nutpie    : 0.14.2
pandas    : 2.2.3
pytensor  : 2.27.1
matplotlib: 3.10.0
plotly    : 6.0.0
polars    : 1.24.0
arviz     : 0.20.0

fonnesbeck, Mar 08 '25 20:03

Sounds like pytensor might be doing something strange here. Maybe it's related to https://github.com/pymc-devs/pytensor/issues/526? Does this also happen with gradient_backend="jax"?

aseyboldt, Mar 08 '25 20:03

Yeah, that fixes it. Is there any scenario where you wouldn't want both backends to be 'jax'?

fonnesbeck, Mar 08 '25 22:03

I've seen several examples where the pytensor gradient was significantly faster. JAX also sometimes has extra NaN issues that pytensor avoids with rewrites (https://docs.jax.dev/en/latest/faq.html#gradients-contain-nan-where-using-where). If you can figure out a reproducer, a pytensor issue would be nice :-)
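For context on the NaN remark: in plain JAX, `jnp.where` evaluates both branches, and the gradient of the discarded branch is multiplied by zero, so an infinite derivative there becomes NaN. The workaround described in the linked JAX FAQ is an inner `where` that keeps unsafe inputs away from the discarded branch. A minimal sketch (the function names are made up for illustration):

```python
import jax
import jax.numpy as jnp


def log_or_zero_naive(x):
    # Both branches are evaluated. At x == 0 the log branch has an infinite
    # derivative, and where() scales it by 0, giving 0 * inf = nan.
    return jnp.where(x > 0.0, jnp.log(x), 0.0)


def log_or_zero_safe(x):
    # "Double where": substitute a harmless input (1.0) wherever the log
    # branch will be discarded, so its derivative stays finite.
    safe_x = jnp.where(x > 0.0, x, 1.0)
    return jnp.where(x > 0.0, jnp.log(safe_x), 0.0)


naive_grad = jax.grad(log_or_zero_naive)(0.0)  # nan
safe_grad = jax.grad(log_or_zero_safe)(0.0)    # 0.0
print(naive_grad, safe_grad)
```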

aseyboldt, Mar 08 '25 22:03

Not sure if this is the same issue or just the same error, but changing the gradient backend to jax certainly doesn't solve this one:

import numpy as np
import pymc as pm
import pytensor.tensor as pt
import nutpie
import os

def create_array_with_random_nans(shape, nan_fraction=0.1):
	if isinstance(shape, int):
		shape = (shape,)
	n_total = np.prod(shape)
	n_nan = int(n_total * nan_fraction)
	arr = np.random.rand(n_total)
	nan_indices = np.random.choice(n_total, n_nan, replace=False)
	arr[nan_indices] = np.nan
	return arr.reshape(shape)


n_predictors = 3
n_cases = 100
x_np = create_array_with_random_nans((n_cases,n_predictors), nan_fraction=0.1)
y_np = create_array_with_random_nans(n_cases, nan_fraction=0.1)


with pm.Model() as model:
	mu = pm.Normal("mu", mu=0, sigma=1, shape=n_predictors)
	chol, corr, sds = pm.LKJCholeskyCov(name = "chol", n = n_predictors, eta = 1, sd_dist = pm.Weibull.dist(2,1))
	x = pm.MvNormal("x", mu = mu, chol = chol, observed = x_np)
	x_Q, x_R = pt.nlinalg.qr(x)
	beta = pm.Normal("beta", mu=0, sigma=1, shape=n_predictors)
	sigma = pm.HalfNormal("sigma", sigma=1)
	y = pm.Normal(
		"y"
		# , mu = pm.math.dot(x, beta) # This samples ok
		, mu = pm.math.dot(x_Q, beta) # This doesn't sample
		, sigma = sigma
		, observed = y_np
	)


print('Compiling...')
compiled_model = nutpie.compile_pymc_model(
	model
	, backend = 'jax'
	, gradient_backend = 'jax'
)

# Warning during compilation:
'''
/home/mike/.pyenv/versions/3.13.2/lib/python3.13/functools.py:934: UserWarning: Skipping `CheckAndRaise` Op (assertion: Could not broadcast dimensions. Broadcasting is only allowed along axes that have a statically known length 1. Use `specify_broadcastable` to inform PyTensor of a known shape.) as JAX tracing would remove it.
  return dispatch(args[0].__class__)(*args, **kw)
'''

# Set RUST_BACKTRACE so Rust panics print a backtrace ('full' is more verbose than '1')
os.environ['RUST_BACKTRACE'] = 'full'

print('Sampling...')
trace = nutpie.sample(compiled_model)

# error during sampling:
'''
Traceback (most recent call last):
  File "REDACTED", line 51, in <module>
    trace = nutpie.sample(compiled_model)
  File "/home/mike/.pyenv/versions/3.13.2/lib/python3.13/site-packages/nutpie/sample.py", line 654, in sample
    result = sampler.wait()
  File "/home/mike/.pyenv/versions/3.13.2/lib/python3.13/site-packages/nutpie/sample.py", line 388, in wait
    self._sampler.wait(timeout)
    ~~~~~~~~~~~~~~~~~~^^^^^^^^^
RuntimeError: All initialization points failed

Caused by:
    Logp function returned error: PyError(PyErr { type: <class 'TypeError'>, value: TypeError('Shapes must be 1D sequences of concrete values of integer type, got (3, Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace>).\nIf using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.\nThe error occurred while tracing the function jax_funcified_fgraph at /tmp/tmpqsl222qm:1 for jit. This value became a tracer due to JAX operations on these lines:\n\n  operation a:bool[] = lt b c\n    from line /tmp/tmpqsl222qm:275:26 (jax_funcified_fgraph)\n\n  operation a:i64[] = pjit[\n  name=_where\n  jaxpr={ lambda ; b:bool[] c:i64[] d:i64[]. let\n      e:i64[] = select_n b d c\n    in (e,) }\n] f g h\n    from line /tmp/tmpqsl222qm:277:26 (jax_funcified_fgraph)\n\n  operation a:i64[] = sub b c\n    from line /tmp/tmpqsl222qm:279:26 (jax_funcified_fgraph)'), traceback: Some(<traceback object at 0x7eb6a82a3580>) })

Stack backtrace:
   0: <unknown>
   1: <unknown>
   2: <unknown>
   3: <unknown>
   4: <unknown>
   5: <unknown>
   6: <unknown>
   7: start_thread
             at ./nptl/pthread_create.c:447:8
   8: clone3
             at ./misc/../sysdeps/unix/sysv/linux/x86_64/clone3.S:78

thread 'nutpie-worker-2' panicked at /home/runner/.cargo/registry/src/index.crates.io-1949cf8c6b5b557f/nuts-rs-0.15.0/src/sampler.rs:635:18:
Could not send sampling results to main thread.: SendError { .. }
stack backtrace:
   0:     0x7eb7651553e9 - <unknown>
   1:     0x7eb7650930e3 - <unknown>
   2:     0x7eb765154982 - <unknown>
   3:     0x7eb765155243 - <unknown>
   4:     0x7eb7651545f6 - <unknown>
   5:     0x7eb76518d8f8 - <unknown>
   6:     0x7eb76518d859 - <unknown>
   7:     0x7eb76518e00c - <unknown>
   8:     0x7eb765091e2f - <unknown>
   9:     0x7eb765097845 - <unknown>
  10:     0x7eb764f7f2c3 - <unknown>
  11:     0x7eb76514a473 - <unknown>
  12:     0x7eb765149a6c - <unknown>
  13:     0x7eb76514c57e - <unknown>
  14:     0x7eb76518e83b - <unknown>
  15:     0x7eb7d329caa4 - start_thread
                               at ./nptl/pthread_create.c:447:8
  16:     0x7eb7d3329c3c - clone3
                               at ./misc/../sysdeps/unix/sysv/linux/x86_64/clone3.S:78

thread 'nutpie-worker-5' panicked at /home/runner/.cargo/registry/src/index.crates.io-1949cf8c6b5b557f/nuts-rs-0.15.0/src/sampler.rs:635:18:
Could not send sampling results to main thread.: SendError { .. }
stack backtrace:
   0:     0x7eb7651553e9 - <unknown>
   1:     0x7eb7650930e3 - <unknown>
   2:     0x7eb765154982 - <unknown>
   3:     0x7eb765155243 - <unknown>
   4:     0x7eb7651545f6 - <unknown>
   5:     0x7eb76518d8f8 - <unknown>
   6:     0x7eb76518d859 - <unknown>
   7:     0x7eb76518e00c - <unknown>
   8:     0x7eb765091e2f - <unknown>
   9:     0x7eb765097845 - <unknown>
  10:     0x7eb764f7f2c3 - <unknown>
  11:     0x7eb76514a473 - <unknown>
  12:     0x7eb765149a6c - <unknown>
  13:     0x7eb76514c57e - <unknown>
  14:     0x7eb76518e83b - <unknown>
  15:     0x7eb7d329caa4 - start_thread
                               at ./nptl/pthread_create.c:447:8
  16:     0x7eb7d3329c3c - clone3
                               at ./misc/../sysdeps/unix/sysv/linux/x86_64/clone3.S:78

thread 'nutpie-worker-0' panicked at /home/runner/.cargo/registry/src/index.crates.io-1949cf8c6b5b557f/nuts-rs-0.15.0/src/sampler.rs:635:18:
Could not send sampling results to main thread.: SendError { .. }
stack backtrace:
   0:     0x7eb7651553e9 - <unknown>
   1:     0x7eb7650930e3 - <unknown>
   2:     0x7eb765154982 - <unknown>
   3:     0x7eb765155243 - <unknown>
   4:     0x7eb7651545f6 - <unknown>
   5:     0x7eb76518d8f8 - <unknown>
   6:     0x7eb76518d859 - <unknown>
   7:     0x7eb76518e00c - <unknown>
   8:     0x7eb765091e2f - <unknown>
   9:     0x7eb765097845 - <unknown>
  10:     0x7eb764f7f2c3 - <unknown>
  11:     0x7eb76514a473 - <unknown>
  12:     0x7eb765149a6c - <unknown>
  13:     0x7eb76514c57e - <unknown>
  14:     0x7eb76518e83b - <unknown>
  15:     0x7eb7d329caa4 - start_thread
                               at ./nptl/pthread_create.c:447:8
  16:     0x7eb7d3329c3c - clone3
                               at ./misc/../sysdeps/unix/sysv/linux/x86_64/clone3.S:78

thread 'nutpie-worker-6' panicked at /home/runner/.cargo/registry/src/index.crates.io-1949cf8c6b5b557f/nuts-rs-0.15.0/src/sampler.rs:635:18:
Could not send sampling results to main thread.: SendError { .. }
stack backtrace:
   0:     0x7eb7651553e9 - <unknown>
   1:     0x7eb7650930e3 - <unknown>
   2:     0x7eb765154982 - <unknown>
   3:     0x7eb765155243 - <unknown>
   4:     0x7eb7651545f6 - <unknown>
   5:     0x7eb76518d8f8 - <unknown>
   6:     0x7eb76518d859 - <unknown>
   7:     0x7eb76518e00c - <unknown>
   8:     0x7eb765091e2f - <unknown>
   9:     0x7eb765097845 - <unknown>
  10:     0x7eb764f7f2c3 - <unknown>
  11:     0x7eb76514a473 - <unknown>
  12:     0x7eb765149a6c - <unknown>
  13:     0x7eb76514c57e - <unknown>
  14:     0x7eb76518e83b - <unknown>
  15:     0x7eb7d329caa4 - start_thread
                               at ./nptl/pthread_create.c:447:8
  16:     0x7eb7d3329c3c - clone3
                               at ./misc/../sysdeps/unix/sysv/linux/x86_64/clone3.S:78

thread 'nutpie-worker-1' panicked at /home/runner/.cargo/registry/src/index.crates.io-1949cf8c6b5b557f/nuts-rs-0.15.0/src/sampler.rs:635:18:
Could not send sampling results to main thread.: SendError { .. }
stack backtrace:
   0:     0x7eb7651553e9 - <unknown>
   1:     0x7eb7650930e3 - <unknown>
   2:     0x7eb765154982 - <unknown>
   3:     0x7eb765155243 - <unknown>
   4:     0x7eb7651545f6 - <unknown>
   5:     0x7eb76518d8f8 - <unknown>
   6:     0x7eb76518d859 - <unknown>
   7:     0x7eb76518e00c - <unknown>
   8:     0x7eb765091e2f - <unknown>
   9:     0x7eb765097845 - <unknown>
  10:     0x7eb764f7f2c3 - <unknown>
  11:     0x7eb76514a473 - <unknown>
  12:     0x7eb765149a6c - <unknown>
  13:     0x7eb76514c57e - <unknown>
  14:     0x7eb76518e83b - <unknown>
  15:     0x7eb7d329caa4 - start_thread
                               at ./nptl/pthread_create.c:447:8
  16:     0x7eb7d3329c3c - clone3
                               at ./misc/../sysdeps/unix/sysv/linux/x86_64/clone3.S:78
'''

mike-lawrence, Apr 04 '25 15:04
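The TypeError in the log above is JAX's standard complaint when an array shape depends on a value that `jit` has turned into a tracer. Below is a minimal, PyMC-free sketch of that failure mode and of the `static_argnums` remedy the message suggests (`padded_zeros` is a hypothetical function). In this thread the traced shape arises inside the PyTensor-generated `jax_funcified_fgraph`, so `static_argnums` is presumably not something a nutpie user can set directly; the sketch only illustrates what the message means.

```python
import jax
import jax.numpy as jnp


def padded_zeros(x, n_cols):
    # n_cols is used to build an array shape, so JAX needs a concrete int here.
    return jnp.zeros((3, n_cols)) + x


# By default jit traces every argument, so n_cols becomes a tracer and
# jnp.zeros raises "Shapes must be 1D sequences of concrete values ...".
try:
    jax.jit(padded_zeros)(1.0, 4)
    dynamic_shape_ok = True
except TypeError:
    dynamic_shape_ok = False

# Declaring n_cols static makes it a compile-time constant
# (each distinct value triggers a recompilation).
padded_zeros_static = jax.jit(padded_zeros, static_argnums=1)
out = padded_zeros_static(1.0, 4)
print(out.shape)  # (3, 4)
```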