nutpie
nutpie copied to clipboard
JAX backend fails for a simple `pymc` linear regression model
Minimal example
import time
import arviz as az
import numpy as np
import nutpie
import pandas as pd
import pymc as pm
# Regression coefficients used to simulate the response in generate_data();
# the number of entries also sets the number of predictor columns.
BETA = [1.0, -1.0, 2.0, -2.0]
# Scale applied to the standard-normal noise added to the simulated response.
SIGMA = 10.0
def generate_data(num_samples: int = 1000) -> pd.DataFrame:
    """Simulate a linear-regression data set with Gaussian noise.

    Draws ``num_samples`` rows of standard-normal predictors, one column per
    entry of ``BETA``, and builds the response as the dot product of the
    predictors with ``BETA`` plus standard-normal noise scaled by ``SIGMA``.
    The random generator is seeded with 42, so every call with the same
    ``num_samples`` returns the same frame.

    Args:
        num_samples: Number of rows to simulate.

    Returns:
        A frame with predictor columns ``x_1`` ... ``x_k`` followed by ``y``.
    """
    n_features = len(BETA)
    generator = np.random.default_rng(42)
    # Draw order matters for reproducibility: predictors first, then noise.
    design = generator.normal(size=(num_samples, n_features))
    noise = generator.normal(size=num_samples)
    response = design.dot(BETA) + SIGMA * noise
    column_names = [f"x_{idx}" for idx in range(1, n_features + 1)]
    return pd.DataFrame(data=design, columns=column_names).assign(y=response)
def make_model(frame: pd.DataFrame) -> pm.Model:
    """Build a Bayesian linear-regression model for ``frame``.

    Every column whose name starts with ``"x"`` is used as a predictor and
    the ``"y"`` column is used as the observed response.

    Args:
        frame: Data with predictor columns named ``x...`` and a ``y`` column.

    Returns:
        The PyMC model with data containers ``x`` and ``y``, a
        ``Normal(0, 1)`` prior on ``beta`` (one entry per predictor), a
        ``HalfNormal(10)`` prior on ``sigma`` and a Normal likelihood
        ``y_obs``.
    """
    predictors = [col for col in frame.columns if col.startswith("x")]
    observation_idx = list(range(len(frame)))
    # The coordinate key must match the dimension name used in ``dims=``
    # below ("predictor"); otherwise the declared labels are never attached
    # to any variable.
    coords = {"observation_idx": observation_idx, "predictor": predictors}
    with pm.Model(coords=coords) as model:
        # Data
        # NOTE(review): the container is intentionally still named "x" -- the
        # reported TypeError mentions the argument 'x', so this name is
        # presumably what triggers the failure being reproduced.
        x = pm.Data("x", frame[predictors], dims=["observation_idx", "predictor"])
        y = pm.Data("y", frame["y"], dims="observation_idx")
        # Population level
        beta = pm.Normal("beta", mu=0.0, sigma=1.0, dims="predictor")
        sigma = pm.HalfNormal("sigma", sigma=10.0)
        # Linear model: weight each predictor column and sum across the
        # predictor axis to get one mean per observation.
        mu = (beta * x).sum(axis=-1)
        # Likelihood
        pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y, shape=mu.shape)
    return model
if __name__ == "__main__":
    # Simulate the data set and build the model to be sampled.
    data = generate_data(num_samples=10_000)
    regression = make_model(data)
    kwargs = {"backend": "jax", "gradient_backend": "jax"}

    # The timed region covers both model compilation and sampling.
    t0 = time.time()
    compiled = nutpie.compile_pymc_model(regression, **kwargs)
    trace = nutpie.sample(compiled)
    t = time.time() - t0
    print(f"Time for nutpie (compiled, {kwargs=}) sampling is {t=:.3f}s.")

    # Report only the posterior mean, HDI bounds and convergence diagnostics.
    report_columns = ["mean", "hdi_3%", "hdi_97%", "ess_bulk", "ess_tail", "r_hat"]
    full_summary = az.summary(trace, var_names=["beta", "sigma"], round_to=4)
    summary = full_summary.loc[:, report_columns]
    print(summary)
Error message
thread 'nutpie-worker-3' panicked at /Users/runner/miniforge3/conda-bld/nutpie_1722020254444/_build_env/.cargo/registry/src/index.crates.io-6f17d22bba15001f/nuts-rs-0.12.1/src/sampler.rs:576:18:
Could not send sampling results to main thread.: SendError { .. }
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
Traceback (most recent call last):
File ".../linear_regression.py", line 52, in <module>
thread 'nutpie-worker-0' panicked at /Users/runner/miniforge3/conda-bld/nutpie_1722020254444/_build_env/.cargo/registry/src/index.crates.io-6f17d22bba15001f/nuts-rs-0.12.1/src/sampler.rs:576:18:
Could not send sampling results to main thread.: SendError { .. }
trace = nutpie.sample(nutpie.compile_pymc_model(model, **kwargs))
^^thread 'nutpie-worker-6' panicked at /Users/runner/miniforge3/conda-bld/nutpie_1722020254444/_build_env/.cargo/registry/src/index.crates.io-6f17d22bba15001f/nuts-rs-0.12.1/src/sampler.rs:576:18:
Could not send sampling results to main thread.: SendError { .. }
^^thread 'nutpie-worker-1' panicked at /Users/runner/miniforge3/conda-bld/nutpie_1722020254444/_build_env/.cargo/registry/src/index.crates.io-6f17d22bba15001f/nuts-rs-0.12.1/src/sampler.rs:576:18:
Could not send sampling results to main thread.: SendError { .. }
^^^^^^^^^^thread 'nutpie-worker-2' panicked at /Users/runner/miniforge3/conda-bld/nutpie_1722020254444/_build_env/.cargo/registry/src/index.crates.io-6f17d22bba15001f/nuts-rs-0.12.1/src/sampler.rs:576:18:
Could not send sampling results to main thread.: SendError { .. }
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../lib/python3.12/site-packages/nutpie/sample.py", line 636, in sample
result = sampler.wait()
^^^^^^^^^^^^^^
File ".../lib/python3.12/site-packages/nutpie/sample.py", line 388, in wait
self._sampler.wait(timeout)
RuntimeError: All initialization points failed
Caused by:
Logp function returned error: Python error: TypeError: _compile_pymc_model_jax.<locals>.make_logp_func.<locals>.logp() got multiple values for argument 'x'
For comparison, sampling the same model with backend="numba" and gradient_backend="pytensor" runs successfully.
Version
# packages in environment at .../.miniconda3/envs/pymc:
#
# Name Version Build Channel
absl-py 2.1.0 pyhd8ed1ab_0 conda-forge
accelerate 1.0.0 pyhd8ed1ab_0 conda-forge
arviz 0.20.0 pyhd8ed1ab_0 conda-forge
atk-1.0 2.38.0 hd03087b_2 conda-forge
aws-c-auth 0.7.31 hc27b277_0 conda-forge
aws-c-cal 0.7.4 h41dd001_1 conda-forge
aws-c-common 0.9.28 hd74edd7_0 conda-forge
aws-c-compression 0.2.19 h41dd001_1 conda-forge
aws-c-event-stream 0.4.3 h40a8fc1_2 conda-forge
aws-c-http 0.8.10 hf5a2c8c_0 conda-forge
aws-c-io 0.14.18 hc3cb426_12 conda-forge
aws-c-mqtt 0.10.7 h3acc7b9_0 conda-forge
aws-c-s3 0.6.6 hd16c091_0 conda-forge
aws-c-sdkutils 0.1.19 h41dd001_3 conda-forge
aws-checksums 0.1.20 h41dd001_0 conda-forge
aws-crt-cpp 0.28.3 h433f80b_6 conda-forge
aws-sdk-cpp 1.11.407 h0455a66_0 conda-forge
azure-core-cpp 1.13.0 hd01fc5c_0 conda-forge
azure-identity-cpp 1.8.0 h13ea094_2 conda-forge
azure-storage-blobs-cpp 12.12.0 hfde595f_0 conda-forge
azure-storage-common-cpp 12.7.0 hcf3b6fd_1 conda-forge
azure-storage-files-datalake-cpp 12.11.0 h082e32e_1 conda-forge
blackjax 1.2.4 pyhd8ed1ab_0 conda-forge
blas 2.124 openblas conda-forge
blas-devel 3.9.0 24_osxarm64_openblas conda-forge
brotli 1.1.0 hd74edd7_2 conda-forge
brotli-bin 1.1.0 hd74edd7_2 conda-forge
brotli-python 1.1.0 py312hde4cb15_2 conda-forge
bzip2 1.0.8 h99b78c6_7 conda-forge
c-ares 1.34.1 hd74edd7_0 conda-forge
c-compiler 1.8.0 h2664225_0 conda-forge
ca-certificates 2024.8.30 hf0a4a13_0 conda-forge
cached-property 1.5.2 hd8ed1ab_1 conda-forge
cached_property 1.5.2 pyha770c72_1 conda-forge
cachetools 5.5.0 pyhd8ed1ab_0 conda-forge
cairo 1.18.0 hb4a6bf7_3 conda-forge
cctools 1010.6 hf67d63f_1 conda-forge
cctools_osx-arm64 1010.6 h4208deb_1 conda-forge
certifi 2024.8.30 pyhd8ed1ab_0 conda-forge
cffi 1.17.1 py312h0fad829_0 conda-forge
charset-normalizer 3.4.0 pyhd8ed1ab_0 conda-forge
chex 0.1.87 pyhd8ed1ab_0 conda-forge
clang 17.0.6 default_h360f5da_7 conda-forge
clang-17 17.0.6 default_h146c034_7 conda-forge
clang_impl_osx-arm64 17.0.6 he47c785_21 conda-forge
clang_osx-arm64 17.0.6 h54d7cd3_21 conda-forge
clangxx 17.0.6 default_h360f5da_7 conda-forge
clangxx_impl_osx-arm64 17.0.6 h50f59cd_21 conda-forge
clangxx_osx-arm64 17.0.6 h54d7cd3_21 conda-forge
cloudpickle 3.0.0 pyhd8ed1ab_0 conda-forge
colorama 0.4.6 pyhd8ed1ab_0 conda-forge
compiler-rt 17.0.6 h856b3c1_2 conda-forge
compiler-rt_osx-arm64 17.0.6 h832e737_2 conda-forge
cons 0.4.6 pyhd8ed1ab_0 conda-forge
contourpy 1.3.0 py312h6142ec9_2 conda-forge
cpython 3.12.7 py312hd8ed1ab_0 conda-forge
cxx-compiler 1.8.0 he8d86c4_0 conda-forge
cycler 0.12.1 pyhd8ed1ab_0 conda-forge
etils 1.9.4 pyhd8ed1ab_0 conda-forge
etuples 0.3.9 pyhd8ed1ab_0 conda-forge
expat 2.6.3 hf9b8971_0 conda-forge
fastprogress 1.0.3 pyhd8ed1ab_0 conda-forge
filelock 3.16.1 pyhd8ed1ab_0 conda-forge
font-ttf-dejavu-sans-mono 2.37 hab24e00_0 conda-forge
font-ttf-inconsolata 3.000 h77eed37_0 conda-forge
font-ttf-source-code-pro 2.038 h77eed37_0 conda-forge
font-ttf-ubuntu 0.83 h77eed37_3 conda-forge
fontconfig 2.14.2 h82840c6_0 conda-forge
fonts-conda-ecosystem 1 0 conda-forge
fonts-conda-forge 1 0 conda-forge
fonttools 4.54.1 py312h024a12e_0 conda-forge
freetype 2.12.1 hadb7bae_2 conda-forge
fribidi 1.0.10 h27ca646_0 conda-forge
fsspec 2024.9.0 pyhff2d567_0 conda-forge
gdk-pixbuf 2.42.12 h7ddc832_0 conda-forge
gflags 2.2.2 hf9b8971_1005 conda-forge
glog 0.7.1 heb240a5_0 conda-forge
gmp 6.3.0 h7bae524_2 conda-forge
gmpy2 2.1.5 py312h87fada9_2 conda-forge
graphite2 1.3.13 hebf3989_1003 conda-forge
graphviz 12.0.0 hbf8cc41_0 conda-forge
gtk2 2.24.33 h91d5085_5 conda-forge
gts 0.7.6 he42f4ea_4 conda-forge
h2 4.1.0 pyhd8ed1ab_0 conda-forge
h5netcdf 1.4.0 pyhd8ed1ab_0 conda-forge
h5py 3.11.0 nompi_py312h903599c_102 conda-forge
harfbuzz 9.0.0 h997cde5_1 conda-forge
hdf5 1.14.3 nompi_hec07895_105 conda-forge
hpack 4.0.0 pyh9f0ad1d_0 conda-forge
huggingface_hub 0.25.2 pyh0610db2_0 conda-forge
hyperframe 6.0.1 pyhd8ed1ab_0 conda-forge
icu 75.1 hfee45f7_0 conda-forge
idna 3.10 pyhd8ed1ab_0 conda-forge
importlib-metadata 8.5.0 pyha770c72_0 conda-forge
jax 0.4.31 pyhd8ed1ab_1 conda-forge
jaxlib 0.4.31 cpu_py312h47007b3_1 conda-forge
jaxopt 0.8.3 pyhd8ed1ab_0 conda-forge
jinja2 3.1.4 pyhd8ed1ab_0 conda-forge
joblib 1.4.2 pyhd8ed1ab_0 conda-forge
kiwisolver 1.4.7 py312h6142ec9_0 conda-forge
krb5 1.21.3 h237132a_0 conda-forge
lcms2 2.16 ha0e7c42_0 conda-forge
ld64 951.9 h39a299f_1 conda-forge
ld64_osx-arm64 951.9 hc81425b_1 conda-forge
lerc 4.0.0 h9a09cb3_0 conda-forge
libabseil 20240116.2 cxx17_h00cdb27_1 conda-forge
libaec 1.1.3 hebf3989_0 conda-forge
libarrow 17.0.0 hc6a7651_16_cpu conda-forge
libblas 3.9.0 24_osxarm64_openblas conda-forge
libbrotlicommon 1.1.0 hd74edd7_2 conda-forge
libbrotlidec 1.1.0 hd74edd7_2 conda-forge
libbrotlienc 1.1.0 hd74edd7_2 conda-forge
libcblas 3.9.0 24_osxarm64_openblas conda-forge
libclang-cpp17 17.0.6 default_h146c034_7 conda-forge
libcrc32c 1.1.2 hbdafb3b_0 conda-forge
libcurl 8.10.1 h13a7ad3_0 conda-forge
libcxx 19.1.1 ha82da77_0 conda-forge
libcxx-devel 17.0.6 h86353a2_6 conda-forge
libdeflate 1.22 hd74edd7_0 conda-forge
libedit 3.1.20191231 hc8eb9b7_2 conda-forge
libev 4.33 h93a5062_2 conda-forge
libexpat 2.6.3 hf9b8971_0 conda-forge
libffi 3.4.2 h3422bc3_5 conda-forge
libgd 2.3.3 hac1b3a8_10 conda-forge
libgfortran 5.0.0 13_2_0_hd922786_3 conda-forge
libgfortran5 13.2.0 hf226fd6_3 conda-forge
libglib 2.82.1 h4821c08_0 conda-forge
libgoogle-cloud 2.29.0 hfa33a2f_0 conda-forge
libgoogle-cloud-storage 2.29.0 h90fd6fa_0 conda-forge
libgrpc 1.62.2 h9c18a4f_0 conda-forge
libiconv 1.17 h0d3ecfb_2 conda-forge
libintl 0.22.5 h8414b35_3 conda-forge
libjpeg-turbo 3.0.0 hb547adb_1 conda-forge
liblapack 3.9.0 24_osxarm64_openblas conda-forge
liblapacke 3.9.0 24_osxarm64_openblas conda-forge
libllvm14 14.0.6 hd1a9a77_4 conda-forge
libllvm17 17.0.6 h5090b49_2 conda-forge
libnghttp2 1.58.0 ha4dd798_1 conda-forge
libopenblas 0.3.27 openmp_h517c56d_1 conda-forge
libpng 1.6.44 hc14010f_0 conda-forge
libprotobuf 4.25.3 hc39d83c_1 conda-forge
libre2-11 2023.09.01 h7b2c953_2 conda-forge
librsvg 2.58.4 h40956f1_0 conda-forge
libsqlite 3.46.1 hc14010f_0 conda-forge
libssh2 1.11.0 h7a5bd25_0 conda-forge
libtiff 4.7.0 hfce79cd_1 conda-forge
libtorch 2.4.1 cpu_generic_h123b01e_0 conda-forge
libutf8proc 2.8.0 h1a8c8d9_0 conda-forge
libuv 1.49.0 hd74edd7_0 conda-forge
libwebp-base 1.4.0 h93a5062_0 conda-forge
libxcb 1.17.0 hdb1d25a_0 conda-forge
libxml2 2.12.7 h01dff8b_4 conda-forge
libzlib 1.3.1 h8359307_2 conda-forge
llvm-openmp 19.1.1 h6cdba0f_0 conda-forge
llvm-tools 17.0.6 h5090b49_2 conda-forge
llvmlite 0.43.0 py312ha9ca408_1 conda-forge
logical-unification 0.4.6 pyhd8ed1ab_0 conda-forge
lz4-c 1.9.4 hb7217d7_0 conda-forge
macosx_deployment_target_osx-arm64 11.0 h6553868_1 conda-forge
markdown-it-py 3.0.0 pyhd8ed1ab_0 conda-forge
markupsafe 3.0.1 py312h906988d_1 conda-forge
matplotlib 3.9.2 py312h1f38498_1 conda-forge
matplotlib-base 3.9.2 py312h9bd0bc6_1 conda-forge
mdurl 0.1.2 pyhd8ed1ab_0 conda-forge
minikanren 1.0.3 pyhd8ed1ab_0 conda-forge
ml_dtypes 0.5.0 py312hcd31e36_0 conda-forge
mpc 1.3.1 h8f1351a_1 conda-forge
mpfr 4.2.1 hb693164_3 conda-forge
mpmath 1.3.0 pyhd8ed1ab_0 conda-forge
multipledispatch 0.6.0 pyhd8ed1ab_1 conda-forge
munkres 1.1.4 pyh9f0ad1d_0 conda-forge
ncurses 6.5 h7bae524_1 conda-forge
networkx 3.4 pyhd8ed1ab_0 conda-forge
nomkl 1.0 h5ca1d4c_0 conda-forge
numba 0.60.0 py312h41cea2d_0 conda-forge
numpy 1.26.4 py312h8442bc7_0 conda-forge
numpyro 0.15.3 pyhd8ed1ab_0 conda-forge
nutpie 0.13.2 py312headafe2_0 conda-forge
openblas 0.3.27 openmp_h560b219_1 conda-forge
openjpeg 2.5.2 h9f1df11_0 conda-forge
openssl 3.3.2 h8359307_0 conda-forge
opt-einsum 3.4.0 hd8ed1ab_0 conda-forge
opt_einsum 3.4.0 pyhd8ed1ab_0 conda-forge
optax 0.2.3 pyhd8ed1ab_0 conda-forge
orc 2.0.2 h75dedd0_0 conda-forge
packaging 24.1 pyhd8ed1ab_0 conda-forge
pandas 2.2.3 py312hcd31e36_1 conda-forge
pango 1.54.0 h9ee27a3_2 conda-forge
pcre2 10.44 h297a79d_2 conda-forge
pillow 10.4.0 py312h8609ca0_1 conda-forge
pip 24.2 pyh8b19718_1 conda-forge
pixman 0.43.4 hebf3989_0 conda-forge
psutil 6.0.0 py312h024a12e_1 conda-forge
pthread-stubs 0.4 hd74edd7_1002 conda-forge
pyarrow-core 17.0.0 py312he20ac61_1_cpu conda-forge
pycparser 2.22 pyhd8ed1ab_0 conda-forge
pygments 2.18.0 pyhd8ed1ab_0 conda-forge
pymc 5.17.0 hd8ed1ab_0 conda-forge
pymc-base 5.17.0 pyhd8ed1ab_0 conda-forge
pyparsing 3.1.4 pyhd8ed1ab_0 conda-forge
pysocks 1.7.1 pyha2e5f31_6 conda-forge
pytensor 2.25.5 py312h3f593ad_0 conda-forge
pytensor-base 2.25.5 py312h02baea5_0 conda-forge
python 3.12.7 h739c21a_0_cpython conda-forge
python-dateutil 2.9.0 pyhd8ed1ab_0 conda-forge
python-graphviz 0.20.3 pyhe28f650_1 conda-forge
python-tzdata 2024.2 pyhd8ed1ab_0 conda-forge
python_abi 3.12 5_cp312 conda-forge
pytorch 2.4.1 cpu_generic_py312h40771f0_0 conda-forge
pytz 2024.1 pyhd8ed1ab_0 conda-forge
pyyaml 6.0.2 py312h024a12e_1 conda-forge
qhull 2020.2 h420ef59_5 conda-forge
re2 2023.09.01 h4cba328_2 conda-forge
readline 8.2 h92ec313_1 conda-forge
requests 2.32.3 pyhd8ed1ab_0 conda-forge
rich 13.9.2 pyhd8ed1ab_0 conda-forge
safetensors 0.4.5 py312he431725_0 conda-forge
scikit-learn 1.5.2 py312h387f99c_1 conda-forge
scipy 1.14.1 py312heb3a901_0 conda-forge
setuptools 75.1.0 pyhd8ed1ab_0 conda-forge
sigtool 0.1.3 h44b9a77_0 conda-forge
six 1.16.0 pyh6c4a22f_0 conda-forge
sleef 3.7 h7783ee8_0 conda-forge
snappy 1.2.1 hd02b534_0 conda-forge
sympy 1.13.3 pyh2585a3b_104 conda-forge
tabulate 0.9.0 pyhd8ed1ab_1 conda-forge
tapi 1300.6.5 h03f4b80_0 conda-forge
threadpoolctl 3.5.0 pyhc1e730c_0 conda-forge
tk 8.6.13 h5083fa2_1 conda-forge
toolz 1.0.0 pyhd8ed1ab_0 conda-forge
tornado 6.4.1 py312h024a12e_1 conda-forge
tqdm 4.66.5 pyhd8ed1ab_0 conda-forge
typing-extensions 4.12.2 hd8ed1ab_0 conda-forge
typing_extensions 4.12.2 pyha770c72_0 conda-forge
tzdata 2024b hc8b5060_0 conda-forge
urllib3 2.2.3 pyhd8ed1ab_0 conda-forge
wheel 0.44.0 pyhd8ed1ab_0 conda-forge
xarray 2024.9.0 pyhd8ed1ab_1 conda-forge
xarray-einstats 0.8.0 pyhd8ed1ab_0 conda-forge
xorg-libxau 1.0.11 hd74edd7_1 conda-forge
xorg-libxdmcp 1.1.5 hd74edd7_0 conda-forge
xz 5.2.6 h57fd34a_0 conda-forge
yaml 0.2.5 h3422bc3_2 conda-forge
zipp 3.20.2 pyhd8ed1ab_0 conda-forge
zlib 1.3.1 h8359307_2 conda-forge
zstandard 0.23.0 py312h15fbf35_1 conda-forge
zstd 1.5.6 hb46c0d2_0 conda-forge