nutpie
JAX backend fails for a simple `pymc` linear regression model
Minimal example
import time
import arviz as az
import numpy as np
import nutpie
import pandas as pd
import pymc as pm
BETA = [1.0, -1.0, 2.0, -2.0]
SIGMA = 10.0
def generate_data(num_samples: int = 1000) -> pd.DataFrame:
    rng = np.random.default_rng(42)
    dims = len(BETA)
    X = rng.normal(size=(num_samples, dims))
    y = X.dot(BETA) + SIGMA * rng.normal(size=num_samples)
    frame = pd.DataFrame(data=X, columns=[f"x_{i + 1}" for i in range(dims)])
    frame["y"] = y
    return frame

def make_model(frame: pd.DataFrame) -> pm.Model:
    predictors = [col for col in frame.columns if col.startswith("x")]
    observation_idx = list(range(len(frame)))
    coords = {"observation_idx": observation_idx, "predictor": predictors}
    with pm.Model(coords=coords) as model:
        # Data
        x = pm.Data("x", frame[predictors], dims=["observation_idx", "predictor"])
        y = pm.Data("y", frame["y"], dims="observation_idx")
        # Population level
        beta = pm.Normal("beta", mu=0.0, sigma=1.0, dims="predictor")
        sigma = pm.HalfNormal("sigma", sigma=10.0)
        # Linear model
        mu = (beta * x).sum(axis=-1)
        # Likelihood
        pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y, shape=mu.shape)
    return model

if __name__ == "__main__":
    frame = generate_data(num_samples=10_000)
    model = make_model(frame)
    kwargs = dict(backend="jax", gradient_backend="jax")
    t0 = time.time()
    trace = nutpie.sample(nutpie.compile_pymc_model(model, **kwargs))
    t = time.time() - t0
    print(f"Time for nutpie (compiled, {kwargs=}) sampling is {t=:.3f}s.")
    summary = az.summary(trace, var_names=["beta", "sigma"], round_to=4).loc[:, ["mean", "hdi_3%", "hdi_97%", "ess_bulk", "ess_tail", "r_hat"]]
    print(summary)
Error message
thread 'nutpie-worker-3' panicked at /Users/runner/miniforge3/conda-bld/nutpie_1722020254444/_build_env/.cargo/registry/src/index.crates.io-6f17d22bba15001f/nuts-rs-0.12.1/src/sampler.rs:576:18:
Could not send sampling results to main thread.: SendError { .. }
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
(the same panic is printed by workers 0, 1, 2, and 6, interleaved with the Python traceback below)

Traceback (most recent call last):
  File ".../linear_regression.py", line 52, in <module>
    trace = nutpie.sample(nutpie.compile_pymc_model(model, **kwargs))
  File ".../lib/python3.12/site-packages/nutpie/sample.py", line 636, in sample
    result = sampler.wait()
  File ".../lib/python3.12/site-packages/nutpie/sample.py", line 388, in wait
    self._sampler.wait(timeout)
RuntimeError: All initialization points failed

Caused by:
    Logp function returned error: Python error: TypeError: _compile_pymc_model_jax.<locals>.make_logp_func.<locals>.logp() got multiple values for argument 'x'
Sampling with backend="numba" and gradient_backend="pytensor" runs successfully.
Version
# packages in environment at .../.miniconda3/envs/pymc:
#
# Name Version Build Channel
absl-py 2.1.0 pyhd8ed1ab_0 conda-forge
accelerate 1.0.0 pyhd8ed1ab_0 conda-forge
arviz 0.20.0 pyhd8ed1ab_0 conda-forge
atk-1.0 2.38.0 hd03087b_2 conda-forge
aws-c-auth 0.7.31 hc27b277_0 conda-forge
aws-c-cal 0.7.4 h41dd001_1 conda-forge
aws-c-common 0.9.28 hd74edd7_0 conda-forge
aws-c-compression 0.2.19 h41dd001_1 conda-forge
aws-c-event-stream 0.4.3 h40a8fc1_2 conda-forge
aws-c-http 0.8.10 hf5a2c8c_0 conda-forge
aws-c-io 0.14.18 hc3cb426_12 conda-forge
aws-c-mqtt 0.10.7 h3acc7b9_0 conda-forge
aws-c-s3 0.6.6 hd16c091_0 conda-forge
aws-c-sdkutils 0.1.19 h41dd001_3 conda-forge
aws-checksums 0.1.20 h41dd001_0 conda-forge
aws-crt-cpp 0.28.3 h433f80b_6 conda-forge
aws-sdk-cpp 1.11.407 h0455a66_0 conda-forge
azure-core-cpp 1.13.0 hd01fc5c_0 conda-forge
azure-identity-cpp 1.8.0 h13ea094_2 conda-forge
azure-storage-blobs-cpp 12.12.0 hfde595f_0 conda-forge
azure-storage-common-cpp 12.7.0 hcf3b6fd_1 conda-forge
azure-storage-files-datalake-cpp 12.11.0 h082e32e_1 conda-forge
blackjax 1.2.4 pyhd8ed1ab_0 conda-forge
blas 2.124 openblas conda-forge
blas-devel 3.9.0 24_osxarm64_openblas conda-forge
brotli 1.1.0 hd74edd7_2 conda-forge
brotli-bin 1.1.0 hd74edd7_2 conda-forge
brotli-python 1.1.0 py312hde4cb15_2 conda-forge
bzip2 1.0.8 h99b78c6_7 conda-forge
c-ares 1.34.1 hd74edd7_0 conda-forge
c-compiler 1.8.0 h2664225_0 conda-forge
ca-certificates 2024.8.30 hf0a4a13_0 conda-forge
cached-property 1.5.2 hd8ed1ab_1 conda-forge
cached_property 1.5.2 pyha770c72_1 conda-forge
cachetools 5.5.0 pyhd8ed1ab_0 conda-forge
cairo 1.18.0 hb4a6bf7_3 conda-forge
cctools 1010.6 hf67d63f_1 conda-forge
cctools_osx-arm64 1010.6 h4208deb_1 conda-forge
certifi 2024.8.30 pyhd8ed1ab_0 conda-forge
cffi 1.17.1 py312h0fad829_0 conda-forge
charset-normalizer 3.4.0 pyhd8ed1ab_0 conda-forge
chex 0.1.87 pyhd8ed1ab_0 conda-forge
clang 17.0.6 default_h360f5da_7 conda-forge
clang-17 17.0.6 default_h146c034_7 conda-forge
clang_impl_osx-arm64 17.0.6 he47c785_21 conda-forge
clang_osx-arm64 17.0.6 h54d7cd3_21 conda-forge
clangxx 17.0.6 default_h360f5da_7 conda-forge
clangxx_impl_osx-arm64 17.0.6 h50f59cd_21 conda-forge
clangxx_osx-arm64 17.0.6 h54d7cd3_21 conda-forge
cloudpickle 3.0.0 pyhd8ed1ab_0 conda-forge
colorama 0.4.6 pyhd8ed1ab_0 conda-forge
compiler-rt 17.0.6 h856b3c1_2 conda-forge
compiler-rt_osx-arm64 17.0.6 h832e737_2 conda-forge
cons 0.4.6 pyhd8ed1ab_0 conda-forge
contourpy 1.3.0 py312h6142ec9_2 conda-forge
cpython 3.12.7 py312hd8ed1ab_0 conda-forge
cxx-compiler 1.8.0 he8d86c4_0 conda-forge
cycler 0.12.1 pyhd8ed1ab_0 conda-forge
etils 1.9.4 pyhd8ed1ab_0 conda-forge
etuples 0.3.9 pyhd8ed1ab_0 conda-forge
expat 2.6.3 hf9b8971_0 conda-forge
fastprogress 1.0.3 pyhd8ed1ab_0 conda-forge
filelock 3.16.1 pyhd8ed1ab_0 conda-forge
font-ttf-dejavu-sans-mono 2.37 hab24e00_0 conda-forge
font-ttf-inconsolata 3.000 h77eed37_0 conda-forge
font-ttf-source-code-pro 2.038 h77eed37_0 conda-forge
font-ttf-ubuntu 0.83 h77eed37_3 conda-forge
fontconfig 2.14.2 h82840c6_0 conda-forge
fonts-conda-ecosystem 1 0 conda-forge
fonts-conda-forge 1 0 conda-forge
fonttools 4.54.1 py312h024a12e_0 conda-forge
freetype 2.12.1 hadb7bae_2 conda-forge
fribidi 1.0.10 h27ca646_0 conda-forge
fsspec 2024.9.0 pyhff2d567_0 conda-forge
gdk-pixbuf 2.42.12 h7ddc832_0 conda-forge
gflags 2.2.2 hf9b8971_1005 conda-forge
glog 0.7.1 heb240a5_0 conda-forge
gmp 6.3.0 h7bae524_2 conda-forge
gmpy2 2.1.5 py312h87fada9_2 conda-forge
graphite2 1.3.13 hebf3989_1003 conda-forge
graphviz 12.0.0 hbf8cc41_0 conda-forge
gtk2 2.24.33 h91d5085_5 conda-forge
gts 0.7.6 he42f4ea_4 conda-forge
h2 4.1.0 pyhd8ed1ab_0 conda-forge
h5netcdf 1.4.0 pyhd8ed1ab_0 conda-forge
h5py 3.11.0 nompi_py312h903599c_102 conda-forge
harfbuzz 9.0.0 h997cde5_1 conda-forge
hdf5 1.14.3 nompi_hec07895_105 conda-forge
hpack 4.0.0 pyh9f0ad1d_0 conda-forge
huggingface_hub 0.25.2 pyh0610db2_0 conda-forge
hyperframe 6.0.1 pyhd8ed1ab_0 conda-forge
icu 75.1 hfee45f7_0 conda-forge
idna 3.10 pyhd8ed1ab_0 conda-forge
importlib-metadata 8.5.0 pyha770c72_0 conda-forge
jax 0.4.31 pyhd8ed1ab_1 conda-forge
jaxlib 0.4.31 cpu_py312h47007b3_1 conda-forge
jaxopt 0.8.3 pyhd8ed1ab_0 conda-forge
jinja2 3.1.4 pyhd8ed1ab_0 conda-forge
joblib 1.4.2 pyhd8ed1ab_0 conda-forge
kiwisolver 1.4.7 py312h6142ec9_0 conda-forge
krb5 1.21.3 h237132a_0 conda-forge
lcms2 2.16 ha0e7c42_0 conda-forge
ld64 951.9 h39a299f_1 conda-forge
ld64_osx-arm64 951.9 hc81425b_1 conda-forge
lerc 4.0.0 h9a09cb3_0 conda-forge
libabseil 20240116.2 cxx17_h00cdb27_1 conda-forge
libaec 1.1.3 hebf3989_0 conda-forge
libarrow 17.0.0 hc6a7651_16_cpu conda-forge
libblas 3.9.0 24_osxarm64_openblas conda-forge
libbrotlicommon 1.1.0 hd74edd7_2 conda-forge
libbrotlidec 1.1.0 hd74edd7_2 conda-forge
libbrotlienc 1.1.0 hd74edd7_2 conda-forge
libcblas 3.9.0 24_osxarm64_openblas conda-forge
libclang-cpp17 17.0.6 default_h146c034_7 conda-forge
libcrc32c 1.1.2 hbdafb3b_0 conda-forge
libcurl 8.10.1 h13a7ad3_0 conda-forge
libcxx 19.1.1 ha82da77_0 conda-forge
libcxx-devel 17.0.6 h86353a2_6 conda-forge
libdeflate 1.22 hd74edd7_0 conda-forge
libedit 3.1.20191231 hc8eb9b7_2 conda-forge
libev 4.33 h93a5062_2 conda-forge
libexpat 2.6.3 hf9b8971_0 conda-forge
libffi 3.4.2 h3422bc3_5 conda-forge
libgd 2.3.3 hac1b3a8_10 conda-forge
libgfortran 5.0.0 13_2_0_hd922786_3 conda-forge
libgfortran5 13.2.0 hf226fd6_3 conda-forge
libglib 2.82.1 h4821c08_0 conda-forge
libgoogle-cloud 2.29.0 hfa33a2f_0 conda-forge
libgoogle-cloud-storage 2.29.0 h90fd6fa_0 conda-forge
libgrpc 1.62.2 h9c18a4f_0 conda-forge
libiconv 1.17 h0d3ecfb_2 conda-forge
libintl 0.22.5 h8414b35_3 conda-forge
libjpeg-turbo 3.0.0 hb547adb_1 conda-forge
liblapack 3.9.0 24_osxarm64_openblas conda-forge
liblapacke 3.9.0 24_osxarm64_openblas conda-forge
libllvm14 14.0.6 hd1a9a77_4 conda-forge
libllvm17 17.0.6 h5090b49_2 conda-forge
libnghttp2 1.58.0 ha4dd798_1 conda-forge
libopenblas 0.3.27 openmp_h517c56d_1 conda-forge
libpng 1.6.44 hc14010f_0 conda-forge
libprotobuf 4.25.3 hc39d83c_1 conda-forge
libre2-11 2023.09.01 h7b2c953_2 conda-forge
librsvg 2.58.4 h40956f1_0 conda-forge
libsqlite 3.46.1 hc14010f_0 conda-forge
libssh2 1.11.0 h7a5bd25_0 conda-forge
libtiff 4.7.0 hfce79cd_1 conda-forge
libtorch 2.4.1 cpu_generic_h123b01e_0 conda-forge
libutf8proc 2.8.0 h1a8c8d9_0 conda-forge
libuv 1.49.0 hd74edd7_0 conda-forge
libwebp-base 1.4.0 h93a5062_0 conda-forge
libxcb 1.17.0 hdb1d25a_0 conda-forge
libxml2 2.12.7 h01dff8b_4 conda-forge
libzlib 1.3.1 h8359307_2 conda-forge
llvm-openmp 19.1.1 h6cdba0f_0 conda-forge
llvm-tools 17.0.6 h5090b49_2 conda-forge
llvmlite 0.43.0 py312ha9ca408_1 conda-forge
logical-unification 0.4.6 pyhd8ed1ab_0 conda-forge
lz4-c 1.9.4 hb7217d7_0 conda-forge
macosx_deployment_target_osx-arm64 11.0 h6553868_1 conda-forge
markdown-it-py 3.0.0 pyhd8ed1ab_0 conda-forge
markupsafe 3.0.1 py312h906988d_1 conda-forge
matplotlib 3.9.2 py312h1f38498_1 conda-forge
matplotlib-base 3.9.2 py312h9bd0bc6_1 conda-forge
mdurl 0.1.2 pyhd8ed1ab_0 conda-forge
minikanren 1.0.3 pyhd8ed1ab_0 conda-forge
ml_dtypes 0.5.0 py312hcd31e36_0 conda-forge
mpc 1.3.1 h8f1351a_1 conda-forge
mpfr 4.2.1 hb693164_3 conda-forge
mpmath 1.3.0 pyhd8ed1ab_0 conda-forge
multipledispatch 0.6.0 pyhd8ed1ab_1 conda-forge
munkres 1.1.4 pyh9f0ad1d_0 conda-forge
ncurses 6.5 h7bae524_1 conda-forge
networkx 3.4 pyhd8ed1ab_0 conda-forge
nomkl 1.0 h5ca1d4c_0 conda-forge
numba 0.60.0 py312h41cea2d_0 conda-forge
numpy 1.26.4 py312h8442bc7_0 conda-forge
numpyro 0.15.3 pyhd8ed1ab_0 conda-forge
nutpie 0.13.2 py312headafe2_0 conda-forge
openblas 0.3.27 openmp_h560b219_1 conda-forge
openjpeg 2.5.2 h9f1df11_0 conda-forge
openssl 3.3.2 h8359307_0 conda-forge
opt-einsum 3.4.0 hd8ed1ab_0 conda-forge
opt_einsum 3.4.0 pyhd8ed1ab_0 conda-forge
optax 0.2.3 pyhd8ed1ab_0 conda-forge
orc 2.0.2 h75dedd0_0 conda-forge
packaging 24.1 pyhd8ed1ab_0 conda-forge
pandas 2.2.3 py312hcd31e36_1 conda-forge
pango 1.54.0 h9ee27a3_2 conda-forge
pcre2 10.44 h297a79d_2 conda-forge
pillow 10.4.0 py312h8609ca0_1 conda-forge
pip 24.2 pyh8b19718_1 conda-forge
pixman 0.43.4 hebf3989_0 conda-forge
psutil 6.0.0 py312h024a12e_1 conda-forge
pthread-stubs 0.4 hd74edd7_1002 conda-forge
pyarrow-core 17.0.0 py312he20ac61_1_cpu conda-forge
pycparser 2.22 pyhd8ed1ab_0 conda-forge
pygments 2.18.0 pyhd8ed1ab_0 conda-forge
pymc 5.17.0 hd8ed1ab_0 conda-forge
pymc-base 5.17.0 pyhd8ed1ab_0 conda-forge
pyparsing 3.1.4 pyhd8ed1ab_0 conda-forge
pysocks 1.7.1 pyha2e5f31_6 conda-forge
pytensor 2.25.5 py312h3f593ad_0 conda-forge
pytensor-base 2.25.5 py312h02baea5_0 conda-forge
python 3.12.7 h739c21a_0_cpython conda-forge
python-dateutil 2.9.0 pyhd8ed1ab_0 conda-forge
python-graphviz 0.20.3 pyhe28f650_1 conda-forge
python-tzdata 2024.2 pyhd8ed1ab_0 conda-forge
python_abi 3.12 5_cp312 conda-forge
pytorch 2.4.1 cpu_generic_py312h40771f0_0 conda-forge
pytz 2024.1 pyhd8ed1ab_0 conda-forge
pyyaml 6.0.2 py312h024a12e_1 conda-forge
qhull 2020.2 h420ef59_5 conda-forge
re2 2023.09.01 h4cba328_2 conda-forge
readline 8.2 h92ec313_1 conda-forge
requests 2.32.3 pyhd8ed1ab_0 conda-forge
rich 13.9.2 pyhd8ed1ab_0 conda-forge
safetensors 0.4.5 py312he431725_0 conda-forge
scikit-learn 1.5.2 py312h387f99c_1 conda-forge
scipy 1.14.1 py312heb3a901_0 conda-forge
setuptools 75.1.0 pyhd8ed1ab_0 conda-forge
sigtool 0.1.3 h44b9a77_0 conda-forge
six 1.16.0 pyh6c4a22f_0 conda-forge
sleef 3.7 h7783ee8_0 conda-forge
snappy 1.2.1 hd02b534_0 conda-forge
sympy 1.13.3 pyh2585a3b_104 conda-forge
tabulate 0.9.0 pyhd8ed1ab_1 conda-forge
tapi 1300.6.5 h03f4b80_0 conda-forge
threadpoolctl 3.5.0 pyhc1e730c_0 conda-forge
tk 8.6.13 h5083fa2_1 conda-forge
toolz 1.0.0 pyhd8ed1ab_0 conda-forge
tornado 6.4.1 py312h024a12e_1 conda-forge
tqdm 4.66.5 pyhd8ed1ab_0 conda-forge
typing-extensions 4.12.2 hd8ed1ab_0 conda-forge
typing_extensions 4.12.2 pyha770c72_0 conda-forge
tzdata 2024b hc8b5060_0 conda-forge
urllib3 2.2.3 pyhd8ed1ab_0 conda-forge
wheel 0.44.0 pyhd8ed1ab_0 conda-forge
xarray 2024.9.0 pyhd8ed1ab_1 conda-forge
xarray-einstats 0.8.0 pyhd8ed1ab_0 conda-forge
xorg-libxau 1.0.11 hd74edd7_1 conda-forge
xorg-libxdmcp 1.1.5 hd74edd7_0 conda-forge
xz 5.2.6 h57fd34a_0 conda-forge
yaml 0.2.5 h3422bc3_2 conda-forge
zipp 3.20.2 pyhd8ed1ab_0 conda-forge
zlib 1.3.1 h8359307_2 conda-forge
zstandard 0.23.0 py312h15fbf35_1 conda-forge
zstd 1.5.6 hb46c0d2_0 conda-forge
Thank you for the bug report. I had seen something possibly related recently, but didn't manage to find a smaller model that reproduces it. This example should make it much easier to find the problem.
Right now you can work around the issue by freezing the pymc model:
from pymc.model.transform.optimization import freeze_dims_and_data
trace = nutpie.sample(nutpie.compile_pymc_model(freeze_dims_and_data(model), **kwargs))
Well, that was easier than I thought, and it won't be hard to fix. The problem is that the argument name for the point in parameter space is `x`, and if any shared variable (like the data) is also called `x`, this causes an argument name collision.
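A minimal sketch of the collision (hypothetical names, not nutpie's actual internals): the compiled logp function receives the position in parameter space as a positional argument named `x`, while shared variables are forwarded as keyword arguments under their model names, so a shared variable that is itself named "x" clashes:
import numpy as np

# Hypothetical sketch, not nutpie's actual code: the position in parameter
# space arrives as a positional argument named `x`, and shared variables
# (such as pm.Data) are passed as keyword arguments by their model names.
def logp(x, **shared):
    return 0.0

position = np.zeros(5)
shared_vars = {"x": np.ones((10, 4))}  # a shared variable named "x"

logp(position, **shared_vars)
# TypeError: logp() got multiple values for argument 'x'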
@aseyboldt: Thanks a lot for the very helpful reply. I renamed 'x' -> 'X' and now things are working.
This is really good to know, but it should probably still be fixed, either by ensuring that a unique name is used for the point in parameter space, or by forbidding 'x' as a name in the model (the latter would be a bit cumbersome, since predictors are often denoted by 'x').
On the upside, using JAX gives a very nice speedup on my machine (Apple M1) :-).
Yes, definitely needs a fix, I'll push one soon.
Out of curiosity (I don't have an Apple machine), could you do me a small favor and run this with both jax and numba, and tell me the compile time and the runtime in each case?
jax
frame = generate_data(num_samples=10_000)
model = make_model(frame)
kwargs = dict(backend="jax", gradient_backend="jax")
t0 = time.time()
compiled = nutpie.compile_pymc_model(model, **kwargs)
print(f"compile time: {time.time() - t0}")
t0 = time.time()
trace = nutpie.sample(compiled)
t = time.time() - t0
print(f"Time for nutpie (compiled, {kwargs=}) sampling is {t=:.3f}s.")
summary = az.summary(trace, var_names=["beta", "sigma"], round_to=4).loc[:, ["mean", "hdi_3%", "hdi_97%", "ess_bulk", "ess_tail", "r_hat"]]
print(summary)
numba
frame = generate_data(num_samples=10_000)
model = make_model(frame)
kwargs = dict(backend="numba", gradient_backend="jax")
t0 = time.time()
compiled = nutpie.compile_pymc_model(model, **kwargs)
print(f"compile time: {time.time() - t0}")
t0 = time.time()
trace = nutpie.sample(compiled)
t = time.time() - t0
print(f"Time for nutpie (compiled, {kwargs=}) sampling is {t=:.3f}s.")
summary = az.summary(trace, var_names=["beta", "sigma"], round_to=4).loc[:, ["mean", "hdi_3%", "hdi_97%", "ess_bulk", "ess_tail", "r_hat"]]
print(summary)
jax
compile time for kwargs={'backend': 'jax', 'gradient_backend': 'jax'}: 1.016s
Time for nutpie (compiled, kwargs={'backend': 'jax', 'gradient_backend': 'jax'}) sampling is t=2.094s.
numba
compile time for kwargs={'backend': 'numba', 'gradient_backend': 'pytensor'}: 3.835s
Time for nutpie (compiled, kwargs={'backend': 'numba', 'gradient_backend': 'pytensor'}) sampling is t=0.564s.
I hope that helps. I have another follow-up question: while I observe a great speed-up with the JAX backend on my Apple M1 machine, I see significantly slower sampling with the JAX backend compared to Numba/PyTensor when running on a Google Cloud VM with many more cores (32) and much more memory. This is for a hierarchical linear regression with thousands of groups and a couple of predictors.
On the VM, sampling with the "jax" backend is about 30% slower than with the "numba" backend. Specifically, with the "numba" backend I see a few (4-8) CPU bars in htop at 100%, while with the JAX backend all 32 bars show some occupancy below 50%.
If you have any ideas or insights into what could cause this, and how to ensure the best performance, I'd be glad for any suggestions.
Thanks for the numbers :-)
First, I think it is important to distinguish compile time from sampling time. The numbers you just gave me show that the numba backend samples faster on the Mac as well; only the compile time is much larger. As the model gets bigger, the compile time will play less of a role, because it doesn't depend much on the data size.
I think what you observe with the jax backend is an issue with how it currently works: the chains run in different threads that are controlled from the rust code. With the numba backend, the python interpreter is only used to generate the logp function; all sampling happens without any involvement of python. Doing the same with jax is currently much harder (I hope this will change in the not too distant future though). jax compiles the logp function, but I can't easily access that compiled function from rust, so instead I have to call a python function that then calls the compiled jax function. While a bit silly, that wouldn't be too bad if python didn't have the GIL. But the GIL ensures that only one thread (i.e. one chain) can use the python interpreter at a time. So each logp function evaluation does something like the following:
- Acquire the GIL (i.e. wait until no other thread is holding it)
- Start the jax logp function
- Release the GIL
- Do the actual logp function evaluation (while not holding the GIL)
- Acquire the GIL
- Get access to the results
- Release the GIL
If the computation of the logp function takes a long time and there aren't that many threads, then most of the time only one (or even no) thread will hold the GIL, because each thread spends most of its time in the "do the actual logp function evaluation" phase, and all is good. But if the logp function evaluation is relatively quick, then more than one thread will try to acquire the GIL at the same time, which means the threads sit around waiting. I.e., "low occupancy".
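To make that concrete, here is a self-contained toy (plain Python threads, not nutpie's actual code) that mimics the pattern: a GIL-holding section stands in for dispatching the jax call, and time.sleep, which releases the GIL, stands in for the compiled computation itself:
import threading
import time

# Toy model of the pattern above: each "logp evaluation" has a part that
# needs the GIL (plain python work, standing in for dispatching the jax
# call) and a part that releases it (time.sleep, standing in for the
# compiled computation).
def fake_logp(gil_work, compute_seconds):
    sum(range(gil_work))         # runs while holding the GIL
    time.sleep(compute_seconds)  # the GIL is released here

def run_chains(n_threads, gil_work, compute_seconds, n_evals=100):
    def chain():
        for _ in range(n_evals):
            fake_logp(gil_work, compute_seconds)
    threads = [threading.Thread(target=chain) for _ in range(n_threads)]
    t0 = time.time()
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return time.time() - t0

# Expensive logp: threads overlap, wall time barely grows with 8 "chains".
print(run_chains(1, 10_000, 5e-3), run_chains(8, 10_000, 5e-3))
# Cheap logp: the GIL-holding part dominates, and the "chains" serialize.
print(run_chains(1, 500_000, 1e-4), run_chains(8, 500_000, 1e-4))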
There are two things that might make this situation better in the future:
- python 3.13 has a new build without the GIL. That isn't supported by jax and many other libraries at the moment, but I hope this will change quickly.
- Jax might make it easier to call the compiled functions from other languages, so that I wouldn't have to go through python at all, which would avoid the whole problem to begin with.
In the meantime: if the cores of your machine aren't used well, you can at least try to limit the number of threads that run at the same time by setting the `cores` argument of `sample` to something smaller. This can reduce the lock contention and give you a modest speedup. It won't really fix the problem though...
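For instance, limiting 8 chains to 2 running concurrently would look like this (a sketch; `compiled` is the compiled model from above):
# Run 8 chains, but at most 2 at the same time, to reduce GIL contention:
trace = nutpie.sample(compiled, chains=8, cores=2)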
If you are willing to go to some extra lengths: you can start multiple separate processes that sample your model (with different seeds!) and then combine the traces. This is much more annoying, but should completely avoid the lock contention. In that case you can run into other issues however, for instance if each process tries to use all available cores on the machine. Fixing that would then require threadpoolctl and/or taskset or some jax environment variable flags.
I hope that helps to clear it up a bit :-)
Thanks @aseyboldt, this is really helpful. Do you know of a minimal example for the "start multiple separate processes" approach? I have seen https://discourse.pymc.io/t/harnessing-multiple-cores-to-speed-up-fits-with-small-number-of-chains/7669 where the idea is to concatenate multiple smaller chains to harness the CPUs on a machine more efficiently.
I'd try e.g. joblib for that, but I am not sure how much it interferes with the PyMC and nutpie internals. If you have any pointers, I'd be very glad to look into it.
Btw, with num_samples = 100_000 the numbers look like this on the Apple M1:
jax
compile time for kwargs={'backend': 'jax', 'gradient_backend': 'jax'}: 0.864s
Time for nutpie (compiled, kwargs={'backend': 'jax', 'gradient_backend': 'jax'}) sampling is t=5.842s.
numba
compile time for kwargs={'backend': 'numba', 'gradient_backend': 'pytensor'}: 3.874s
Time for nutpie (compiled, kwargs={'backend': 'numba', 'gradient_backend': 'pytensor'}) sampling is t=19.923s.
So JAX is a lot faster for sampling here, which also matches my observation for a hierarchical linear model.
For sampling in separate processes:
# At the very start...
import os
os.environ["JOBLIB_START_METHOD"] = "forkserver"

import joblib
from joblib import parallel_config, Parallel, delayed
import arviz

def run_chain(data, idx, seed):
    model = make_model(data)
    seeds = np.random.SeedSequence(seed)
    seed = np.random.default_rng(seeds.spawn(idx + 1)[-1]).integers(2 ** 63)
    compiled = nutpie.compile_pymc_model(model, backend="jax", gradient_backend="jax")
    trace = nutpie.sample(compiled, seed=seed, chains=1, progress_bar=False)
    return trace.assign_coords(chain=[idx])

with parallel_config(n_jobs=10, prefer="processes"):
    traces = Parallel()(delayed(run_chain)(frame, i, 123) for i in range(10))

trace = arviz.concat(traces, dim="chain")
This comes with quite a bit of overhead (mostly constant though), so probably not worth it for smaller models.
Funnily enough, I see big differences between
# Option 1
mu = (beta * x).sum(axis=-1)
# Option 2
mu = x @ beta
And jax and numba react quite differently to the two. Maybe an issue with the BLAS config? What BLAS implementation are you using?
(On conda-forge you can choose it as explained here: https://conda-forge.org/docs/maintainer/knowledge_base/#switching-blas-implementation. I think on M1 Accelerate is usually the fastest.)
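For reference, and assuming the standard conda-forge mechanism from those docs, switching to Accelerate should look roughly like:
conda install "libblas=*=*accelerate"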
Thanks a lot, again, for the very helpful suggestions. I will benchmark the two versions of the dot product to see whether I observe a performance difference.
Regarding BLAS
On the Apple M1 I have:
blas 2.124 openblas conda-forge
blas-devel 3.9.0 24_osxarm64_openblas conda-forge
libblas 3.9.0 24_osxarm64_openblas conda-forge
libcblas 3.9.0 24_osxarm64_openblas conda-forge
liblapack 3.9.0 24_osxarm64_openblas conda-forge
liblapacke 3.9.0 24_osxarm64_openblas conda-forge
libopenblas 0.3.27 openmp_h517c56d_1 conda-forge
openblas 0.3.27 openmp_h560b219_1 conda-forge
and on the VM:
blas 2.120 mkl conda-forge
blas-devel 3.9.0 20_linux64_mkl conda-forge
libblas 3.9.0 20_linux64_mkl conda-forge
libcblas 3.9.0 20_linux64_mkl conda-forge
I can try the Accelerate BLAS, but right now I am more interested in speeding things up on the VM.