nutpie icon indicating copy to clipboard operation
nutpie copied to clipboard

JAX backend fails for a simple `pymc` linear regression model

Open trendelkampschroer opened this issue 1 year ago • 10 comments

Minimal example

import time

import arviz as az
import numpy as np
import nutpie
import pandas as pd
import pymc as pm

BETA = [1.0, -1.0, 2.0, -2.0]
SIGMA = 10.0


def generate_data(num_samples: int = 1000) -> pd.DataFrame:
    rng = np.random.default_rng(42)
    dims = len(BETA)
    X = rng.normal(size=(num_samples, dims))
    y = X.dot(BETA) + SIGMA * rng.normal(size=num_samples)
    frame = pd.DataFrame(data=X, columns=[f"x_{i+1}" for i in range(dims)])
    frame["y"] = y
    return frame


def make_model(frame: pd.DataFrame) -> pm.Model:
    predictors = [col for col in frame.columns if col.startswith("x")]
    observation_idx = [i for i in range(len(frame))]
    coords = {"observation_idx": observation_idx, "predictors": predictors}

    with pm.Model(coords=coords) as model:
        # Data
        x = pm.Data("x", frame[predictors], dims=["observation_idx", "predictor"])
        y = pm.Data("y", frame["y"], dims="observation_idx")

        # Population level
        beta = pm.Normal("beta", mu=0.0, sigma=1.0, dims="predictor")
        sigma = pm.HalfNormal("sigma", sigma=10.0)

        # Linear model
        mu = (beta * x).sum(axis=-1)

        # Likelihood
        pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y, shape=mu.shape)

    return model


if __name__ == "__main__":
    frame = generate_data(num_samples=10_000)
    model = make_model(frame)

    kwargs = dict(backend="jax", gradient_backend="jax")
    t0 = time.time()
    trace = nutpie.sample(nutpie.compile_pymc_model(model, **kwargs))
    t = time.time() - t0
    print(f"Time for nutpie (compiled, {kwargs=}) sampling is {t=:.3f}s.")
    summary = az.summary(trace, var_names=["beta", "sigma"], round_to=4).loc[:, ["mean", "hdi_3%", "hdi_97%", "ess_bulk", "ess_tail", "r_hat"]]
    print(summary)

Error message

thread 'nutpie-worker-3' panicked at /Users/runner/miniforge3/conda-bld/nutpie_1722020254444/_build_env/.cargo/registry/src/index.crates.io-6f17d22bba15001f/nuts-rs-0.12.1/src/sampler.rs:576:18:
Could not send sampling results to main thread.: SendError { .. }
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
Traceback (most recent call last):
  File ".../linear_regression.py", line 52, in <module>
thread 'nutpie-worker-0' panicked at /Users/runner/miniforge3/conda-bld/nutpie_1722020254444/_build_env/.cargo/registry/src/index.crates.io-6f17d22bba15001f/nuts-rs-0.12.1/src/sampler.rs:576:18:
Could not send sampling results to main thread.: SendError { .. }
    trace = nutpie.sample(nutpie.compile_pymc_model(model, **kwargs))
            ^^thread 'nutpie-worker-6' panicked at /Users/runner/miniforge3/conda-bld/nutpie_1722020254444/_build_env/.cargo/registry/src/index.crates.io-6f17d22bba15001f/nuts-rs-0.12.1/src/sampler.rs:576:18:
Could not send sampling results to main thread.: SendError { .. }
^^thread 'nutpie-worker-1' panicked at /Users/runner/miniforge3/conda-bld/nutpie_1722020254444/_build_env/.cargo/registry/src/index.crates.io-6f17d22bba15001f/nuts-rs-0.12.1/src/sampler.rs:576:18:
Could not send sampling results to main thread.: SendError { .. }
^^^^^^^^^^thread 'nutpie-worker-2' panicked at /Users/runner/miniforge3/conda-bld/nutpie_1722020254444/_build_env/.cargo/registry/src/index.crates.io-6f17d22bba15001f/nuts-rs-0.12.1/src/sampler.rs:576:18:
Could not send sampling results to main thread.: SendError { .. }
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.12/site-packages/nutpie/sample.py", line 636, in sample
    result = sampler.wait()
             ^^^^^^^^^^^^^^
  File ".../lib/python3.12/site-packages/nutpie/sample.py", line 388, in wait
    self._sampler.wait(timeout)
RuntimeError: All initialization points failed

Caused by:
    Logp function returned error: Python error: TypeError: _compile_pymc_model_jax.<locals>.make_logp_func.<locals>.logp() got multiple values for argument 'x'

Sampling with backend="numba" and gradient_backend="pytensor" runs successfully.

Version

# packages in environment at .../.miniconda3/envs/pymc:
#
# Name                    Version                   Build  Channel
absl-py                   2.1.0              pyhd8ed1ab_0    conda-forge
accelerate                1.0.0              pyhd8ed1ab_0    conda-forge
arviz                     0.20.0             pyhd8ed1ab_0    conda-forge
atk-1.0                   2.38.0               hd03087b_2    conda-forge
aws-c-auth                0.7.31               hc27b277_0    conda-forge
aws-c-cal                 0.7.4                h41dd001_1    conda-forge
aws-c-common              0.9.28               hd74edd7_0    conda-forge
aws-c-compression         0.2.19               h41dd001_1    conda-forge
aws-c-event-stream        0.4.3                h40a8fc1_2    conda-forge
aws-c-http                0.8.10               hf5a2c8c_0    conda-forge
aws-c-io                  0.14.18             hc3cb426_12    conda-forge
aws-c-mqtt                0.10.7               h3acc7b9_0    conda-forge
aws-c-s3                  0.6.6                hd16c091_0    conda-forge
aws-c-sdkutils            0.1.19               h41dd001_3    conda-forge
aws-checksums             0.1.20               h41dd001_0    conda-forge
aws-crt-cpp               0.28.3               h433f80b_6    conda-forge
aws-sdk-cpp               1.11.407             h0455a66_0    conda-forge
azure-core-cpp            1.13.0               hd01fc5c_0    conda-forge
azure-identity-cpp        1.8.0                h13ea094_2    conda-forge
azure-storage-blobs-cpp   12.12.0              hfde595f_0    conda-forge
azure-storage-common-cpp  12.7.0               hcf3b6fd_1    conda-forge
azure-storage-files-datalake-cpp 12.11.0              h082e32e_1    conda-forge
blackjax                  1.2.4              pyhd8ed1ab_0    conda-forge
blas                      2.124                  openblas    conda-forge
blas-devel                3.9.0           24_osxarm64_openblas    conda-forge
brotli                    1.1.0                hd74edd7_2    conda-forge
brotli-bin                1.1.0                hd74edd7_2    conda-forge
brotli-python             1.1.0           py312hde4cb15_2    conda-forge
bzip2                     1.0.8                h99b78c6_7    conda-forge
c-ares                    1.34.1               hd74edd7_0    conda-forge
c-compiler                1.8.0                h2664225_0    conda-forge
ca-certificates           2024.8.30            hf0a4a13_0    conda-forge
cached-property           1.5.2                hd8ed1ab_1    conda-forge
cached_property           1.5.2              pyha770c72_1    conda-forge
cachetools                5.5.0              pyhd8ed1ab_0    conda-forge
cairo                     1.18.0               hb4a6bf7_3    conda-forge
cctools                   1010.6               hf67d63f_1    conda-forge
cctools_osx-arm64         1010.6               h4208deb_1    conda-forge
certifi                   2024.8.30          pyhd8ed1ab_0    conda-forge
cffi                      1.17.1          py312h0fad829_0    conda-forge
charset-normalizer        3.4.0              pyhd8ed1ab_0    conda-forge
chex                      0.1.87             pyhd8ed1ab_0    conda-forge
clang                     17.0.6          default_h360f5da_7    conda-forge
clang-17                  17.0.6          default_h146c034_7    conda-forge
clang_impl_osx-arm64      17.0.6              he47c785_21    conda-forge
clang_osx-arm64           17.0.6              h54d7cd3_21    conda-forge
clangxx                   17.0.6          default_h360f5da_7    conda-forge
clangxx_impl_osx-arm64    17.0.6              h50f59cd_21    conda-forge
clangxx_osx-arm64         17.0.6              h54d7cd3_21    conda-forge
cloudpickle               3.0.0              pyhd8ed1ab_0    conda-forge
colorama                  0.4.6              pyhd8ed1ab_0    conda-forge
compiler-rt               17.0.6               h856b3c1_2    conda-forge
compiler-rt_osx-arm64     17.0.6               h832e737_2    conda-forge
cons                      0.4.6              pyhd8ed1ab_0    conda-forge
contourpy                 1.3.0           py312h6142ec9_2    conda-forge
cpython                   3.12.7          py312hd8ed1ab_0    conda-forge
cxx-compiler              1.8.0                he8d86c4_0    conda-forge
cycler                    0.12.1             pyhd8ed1ab_0    conda-forge
etils                     1.9.4              pyhd8ed1ab_0    conda-forge
etuples                   0.3.9              pyhd8ed1ab_0    conda-forge
expat                     2.6.3                hf9b8971_0    conda-forge
fastprogress              1.0.3              pyhd8ed1ab_0    conda-forge
filelock                  3.16.1             pyhd8ed1ab_0    conda-forge
font-ttf-dejavu-sans-mono 2.37                 hab24e00_0    conda-forge
font-ttf-inconsolata      3.000                h77eed37_0    conda-forge
font-ttf-source-code-pro  2.038                h77eed37_0    conda-forge
font-ttf-ubuntu           0.83                 h77eed37_3    conda-forge
fontconfig                2.14.2               h82840c6_0    conda-forge
fonts-conda-ecosystem     1                             0    conda-forge
fonts-conda-forge         1                             0    conda-forge
fonttools                 4.54.1          py312h024a12e_0    conda-forge
freetype                  2.12.1               hadb7bae_2    conda-forge
fribidi                   1.0.10               h27ca646_0    conda-forge
fsspec                    2024.9.0           pyhff2d567_0    conda-forge
gdk-pixbuf                2.42.12              h7ddc832_0    conda-forge
gflags                    2.2.2             hf9b8971_1005    conda-forge
glog                      0.7.1                heb240a5_0    conda-forge
gmp                       6.3.0                h7bae524_2    conda-forge
gmpy2                     2.1.5           py312h87fada9_2    conda-forge
graphite2                 1.3.13            hebf3989_1003    conda-forge
graphviz                  12.0.0               hbf8cc41_0    conda-forge
gtk2                      2.24.33              h91d5085_5    conda-forge
gts                       0.7.6                he42f4ea_4    conda-forge
h2                        4.1.0              pyhd8ed1ab_0    conda-forge
h5netcdf                  1.4.0              pyhd8ed1ab_0    conda-forge
h5py                      3.11.0          nompi_py312h903599c_102    conda-forge
harfbuzz                  9.0.0                h997cde5_1    conda-forge
hdf5                      1.14.3          nompi_hec07895_105    conda-forge
hpack                     4.0.0              pyh9f0ad1d_0    conda-forge
huggingface_hub           0.25.2             pyh0610db2_0    conda-forge
hyperframe                6.0.1              pyhd8ed1ab_0    conda-forge
icu                       75.1                 hfee45f7_0    conda-forge
idna                      3.10               pyhd8ed1ab_0    conda-forge
importlib-metadata        8.5.0              pyha770c72_0    conda-forge
jax                       0.4.31             pyhd8ed1ab_1    conda-forge
jaxlib                    0.4.31          cpu_py312h47007b3_1    conda-forge
jaxopt                    0.8.3              pyhd8ed1ab_0    conda-forge
jinja2                    3.1.4              pyhd8ed1ab_0    conda-forge
joblib                    1.4.2              pyhd8ed1ab_0    conda-forge
kiwisolver                1.4.7           py312h6142ec9_0    conda-forge
krb5                      1.21.3               h237132a_0    conda-forge
lcms2                     2.16                 ha0e7c42_0    conda-forge
ld64                      951.9                h39a299f_1    conda-forge
ld64_osx-arm64            951.9                hc81425b_1    conda-forge
lerc                      4.0.0                h9a09cb3_0    conda-forge
libabseil                 20240116.2      cxx17_h00cdb27_1    conda-forge
libaec                    1.1.3                hebf3989_0    conda-forge
libarrow                  17.0.0          hc6a7651_16_cpu    conda-forge
libblas                   3.9.0           24_osxarm64_openblas    conda-forge
libbrotlicommon           1.1.0                hd74edd7_2    conda-forge
libbrotlidec              1.1.0                hd74edd7_2    conda-forge
libbrotlienc              1.1.0                hd74edd7_2    conda-forge
libcblas                  3.9.0           24_osxarm64_openblas    conda-forge
libclang-cpp17            17.0.6          default_h146c034_7    conda-forge
libcrc32c                 1.1.2                hbdafb3b_0    conda-forge
libcurl                   8.10.1               h13a7ad3_0    conda-forge
libcxx                    19.1.1               ha82da77_0    conda-forge
libcxx-devel              17.0.6               h86353a2_6    conda-forge
libdeflate                1.22                 hd74edd7_0    conda-forge
libedit                   3.1.20191231         hc8eb9b7_2    conda-forge
libev                     4.33                 h93a5062_2    conda-forge
libexpat                  2.6.3                hf9b8971_0    conda-forge
libffi                    3.4.2                h3422bc3_5    conda-forge
libgd                     2.3.3               hac1b3a8_10    conda-forge
libgfortran               5.0.0           13_2_0_hd922786_3    conda-forge
libgfortran5              13.2.0               hf226fd6_3    conda-forge
libglib                   2.82.1               h4821c08_0    conda-forge
libgoogle-cloud           2.29.0               hfa33a2f_0    conda-forge
libgoogle-cloud-storage   2.29.0               h90fd6fa_0    conda-forge
libgrpc                   1.62.2               h9c18a4f_0    conda-forge
libiconv                  1.17                 h0d3ecfb_2    conda-forge
libintl                   0.22.5               h8414b35_3    conda-forge
libjpeg-turbo             3.0.0                hb547adb_1    conda-forge
liblapack                 3.9.0           24_osxarm64_openblas    conda-forge
liblapacke                3.9.0           24_osxarm64_openblas    conda-forge
libllvm14                 14.0.6               hd1a9a77_4    conda-forge
libllvm17                 17.0.6               h5090b49_2    conda-forge
libnghttp2                1.58.0               ha4dd798_1    conda-forge
libopenblas               0.3.27          openmp_h517c56d_1    conda-forge
libpng                    1.6.44               hc14010f_0    conda-forge
libprotobuf               4.25.3               hc39d83c_1    conda-forge
libre2-11                 2023.09.01           h7b2c953_2    conda-forge
librsvg                   2.58.4               h40956f1_0    conda-forge
libsqlite                 3.46.1               hc14010f_0    conda-forge
libssh2                   1.11.0               h7a5bd25_0    conda-forge
libtiff                   4.7.0                hfce79cd_1    conda-forge
libtorch                  2.4.1           cpu_generic_h123b01e_0    conda-forge
libutf8proc               2.8.0                h1a8c8d9_0    conda-forge
libuv                     1.49.0               hd74edd7_0    conda-forge
libwebp-base              1.4.0                h93a5062_0    conda-forge
libxcb                    1.17.0               hdb1d25a_0    conda-forge
libxml2                   2.12.7               h01dff8b_4    conda-forge
libzlib                   1.3.1                h8359307_2    conda-forge
llvm-openmp               19.1.1               h6cdba0f_0    conda-forge
llvm-tools                17.0.6               h5090b49_2    conda-forge
llvmlite                  0.43.0          py312ha9ca408_1    conda-forge
logical-unification       0.4.6              pyhd8ed1ab_0    conda-forge
lz4-c                     1.9.4                hb7217d7_0    conda-forge
macosx_deployment_target_osx-arm64 11.0                 h6553868_1    conda-forge
markdown-it-py            3.0.0              pyhd8ed1ab_0    conda-forge
markupsafe                3.0.1           py312h906988d_1    conda-forge
matplotlib                3.9.2           py312h1f38498_1    conda-forge
matplotlib-base           3.9.2           py312h9bd0bc6_1    conda-forge
mdurl                     0.1.2              pyhd8ed1ab_0    conda-forge
minikanren                1.0.3              pyhd8ed1ab_0    conda-forge
ml_dtypes                 0.5.0           py312hcd31e36_0    conda-forge
mpc                       1.3.1                h8f1351a_1    conda-forge
mpfr                      4.2.1                hb693164_3    conda-forge
mpmath                    1.3.0              pyhd8ed1ab_0    conda-forge
multipledispatch          0.6.0              pyhd8ed1ab_1    conda-forge
munkres                   1.1.4              pyh9f0ad1d_0    conda-forge
ncurses                   6.5                  h7bae524_1    conda-forge
networkx                  3.4                pyhd8ed1ab_0    conda-forge
nomkl                     1.0                  h5ca1d4c_0    conda-forge
numba                     0.60.0          py312h41cea2d_0    conda-forge
numpy                     1.26.4          py312h8442bc7_0    conda-forge
numpyro                   0.15.3             pyhd8ed1ab_0    conda-forge
nutpie                    0.13.2          py312headafe2_0    conda-forge
openblas                  0.3.27          openmp_h560b219_1    conda-forge
openjpeg                  2.5.2                h9f1df11_0    conda-forge
openssl                   3.3.2                h8359307_0    conda-forge
opt-einsum                3.4.0                hd8ed1ab_0    conda-forge
opt_einsum                3.4.0              pyhd8ed1ab_0    conda-forge
optax                     0.2.3              pyhd8ed1ab_0    conda-forge
orc                       2.0.2                h75dedd0_0    conda-forge
packaging                 24.1               pyhd8ed1ab_0    conda-forge
pandas                    2.2.3           py312hcd31e36_1    conda-forge
pango                     1.54.0               h9ee27a3_2    conda-forge
pcre2                     10.44                h297a79d_2    conda-forge
pillow                    10.4.0          py312h8609ca0_1    conda-forge
pip                       24.2               pyh8b19718_1    conda-forge
pixman                    0.43.4               hebf3989_0    conda-forge
psutil                    6.0.0           py312h024a12e_1    conda-forge
pthread-stubs             0.4               hd74edd7_1002    conda-forge
pyarrow-core              17.0.0          py312he20ac61_1_cpu    conda-forge
pycparser                 2.22               pyhd8ed1ab_0    conda-forge
pygments                  2.18.0             pyhd8ed1ab_0    conda-forge
pymc                      5.17.0               hd8ed1ab_0    conda-forge
pymc-base                 5.17.0             pyhd8ed1ab_0    conda-forge
pyparsing                 3.1.4              pyhd8ed1ab_0    conda-forge
pysocks                   1.7.1              pyha2e5f31_6    conda-forge
pytensor                  2.25.5          py312h3f593ad_0    conda-forge
pytensor-base             2.25.5          py312h02baea5_0    conda-forge
python                    3.12.7          h739c21a_0_cpython    conda-forge
python-dateutil           2.9.0              pyhd8ed1ab_0    conda-forge
python-graphviz           0.20.3             pyhe28f650_1    conda-forge
python-tzdata             2024.2             pyhd8ed1ab_0    conda-forge
python_abi                3.12                    5_cp312    conda-forge
pytorch                   2.4.1           cpu_generic_py312h40771f0_0    conda-forge
pytz                      2024.1             pyhd8ed1ab_0    conda-forge
pyyaml                    6.0.2           py312h024a12e_1    conda-forge
qhull                     2020.2               h420ef59_5    conda-forge
re2                       2023.09.01           h4cba328_2    conda-forge
readline                  8.2                  h92ec313_1    conda-forge
requests                  2.32.3             pyhd8ed1ab_0    conda-forge
rich                      13.9.2             pyhd8ed1ab_0    conda-forge
safetensors               0.4.5           py312he431725_0    conda-forge
scikit-learn              1.5.2           py312h387f99c_1    conda-forge
scipy                     1.14.1          py312heb3a901_0    conda-forge
setuptools                75.1.0             pyhd8ed1ab_0    conda-forge
sigtool                   0.1.3                h44b9a77_0    conda-forge
six                       1.16.0             pyh6c4a22f_0    conda-forge
sleef                     3.7                  h7783ee8_0    conda-forge
snappy                    1.2.1                hd02b534_0    conda-forge
sympy                     1.13.3           pyh2585a3b_104    conda-forge
tabulate                  0.9.0              pyhd8ed1ab_1    conda-forge
tapi                      1300.6.5             h03f4b80_0    conda-forge
threadpoolctl             3.5.0              pyhc1e730c_0    conda-forge
tk                        8.6.13               h5083fa2_1    conda-forge
toolz                     1.0.0              pyhd8ed1ab_0    conda-forge
tornado                   6.4.1           py312h024a12e_1    conda-forge
tqdm                      4.66.5             pyhd8ed1ab_0    conda-forge
typing-extensions         4.12.2               hd8ed1ab_0    conda-forge
typing_extensions         4.12.2             pyha770c72_0    conda-forge
tzdata                    2024b                hc8b5060_0    conda-forge
urllib3                   2.2.3              pyhd8ed1ab_0    conda-forge
wheel                     0.44.0             pyhd8ed1ab_0    conda-forge
xarray                    2024.9.0           pyhd8ed1ab_1    conda-forge
xarray-einstats           0.8.0              pyhd8ed1ab_0    conda-forge
xorg-libxau               1.0.11               hd74edd7_1    conda-forge
xorg-libxdmcp             1.1.5                hd74edd7_0    conda-forge
xz                        5.2.6                h57fd34a_0    conda-forge
yaml                      0.2.5                h3422bc3_2    conda-forge
zipp                      3.20.2             pyhd8ed1ab_0    conda-forge
zlib                      1.3.1                h8359307_2    conda-forge
zstandard                 0.23.0          py312h15fbf35_1    conda-forge
zstd                      1.5.6                hb46c0d2_0    conda-forge

trendelkampschroer avatar Oct 21 '24 10:10 trendelkampschroer