pymc-experimental
sample_smc
Hi there,
I was testing `pymc_experimental/inference/smc/sampling.py` and noticed the following issues:
- inference doesn't seem to work with `pm.Dirichlet`: it raises a shape error at `tmp = logp_fn(*[p.squeeze() for p in particles])[0]`
- `arviz_from_particles` doesn't seem to handle RVs with `shape=(1,)`
- the conversion from InferenceData to netCDF fails because the integrations value is neither an int nor an np.array
- the InferenceData doesn't include the marginal likelihood; do you think it will be implemented in the future, or is it just not possible?
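For the netCDF point above, a possible stopgap until the sampler stores serializable values is to clean each group's attributes before export. The helper below is hypothetical and is only a sketch: it assumes the offending value sits in a group's `attrs` (for example a JAX array or a dict of kernel parameters), keeps strings and numeric NumPy values, converts other numeric containers with `np.asarray`, turns booleans into ints, and stores anything else as a string.

```python
import numpy as np


def sanitize_attrs(attrs):
    """Return a copy of `attrs` holding only netCDF-friendly values (hypothetical helper)."""
    clean = {}
    for key, value in attrs.items():
        if isinstance(value, str):
            clean[key] = value
            continue
        try:
            # JAX arrays, lists of numbers, NumPy scalars, ... become ndarrays
            arr = np.asarray(value)
        except Exception:
            # e.g. ragged sequences that NumPy refuses to convert
            clean[key] = str(value)
            continue
        if arr.dtype.kind == "b":
            # netCDF has no boolean attribute type, so store 0/1 instead
            arr = arr.astype(np.int8)
        if arr.dtype.kind in "iuf":
            # 0-d arrays become plain Python scalars, others stay ndarrays
            clean[key] = arr.item() if arr.ndim == 0 else arr
        else:
            # dicts, None, object arrays, ... are kept as their string form
            clean[key] = str(value)
    return clean
```

It would be applied per group, e.g. `idata.posterior.attrs = sanitize_attrs(idata.posterior.attrs)`, before calling `idata.to_netcdf(...)`.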
Thanks a lot for the blackjax SMC implementation, it's very useful!
Cheers, Vian
PS: here's some code that reproduces the errors:

```python
import numpy as np
import pymc as pm
# assumed import path for the function in pymc_experimental/inference/smc/sampling.py
from pymc_experimental.inference.smc.sampling import sample_smc_blackjax as sample_smc

real_a = 0.2
real_b = 2
x = np.linspace(1, 100)
y = real_a * x + real_b + np.random.normal(0, 2, len(x))

with pm.Model() as model:
    a = pm.Normal("a", mu=10, sigma=10)
    b = pm.Normal("b", mu=10, sigma=10)
    # either of the following lines produces an error
    # c = pm.Normal("c", mu=10, sigma=10, shape=(1,))
    # d = pm.Dirichlet("d", [1, 1])

    trace = sample_smc(
        n_particles=1000,
        kernel="HMC",
        inner_kernel_params={
            "step_size": 0.01,
            "integration_steps": 20,
        },
        iterations_to_diagnose=10,
        target_essn=0.5,
        num_mcmc_steps=10,
    )
```
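On the marginal-likelihood question from the list above: tempered SMC produces an estimate of the log marginal likelihood as a by-product, by summing over tempering stages the log of the mean incremental weight, log Z ≈ Σ_t log((1/N) Σ_i exp((λ_{t+1} − λ_t) · log L(θ_i))). The sketch below is a plain NumPy illustration, not the pymc-experimental or blackjax code; it assumes particles are resampled to equal weights at every stage and that each stage's particle log-likelihoods were recorded (function and argument names are hypothetical). In blackjax 1.x the per-step info object appears to expose a `log_likelihood_increment` field that could be accumulated in the inference loop to the same end, but that should be checked against the installed version.

```python
import numpy as np


def log_mean_exp(x):
    """Numerically stable log(mean(exp(x))) over a 1-D array."""
    m = np.max(x)
    return m + np.log(np.mean(np.exp(x - m)))


def log_marginal_likelihood(lambdas, loglik_per_stage):
    """Estimate log p(y) from a tempered SMC run (hypothetical helper).

    lambdas: tempering exponents [0, lambda_1, ..., 1], length T + 1.
    loglik_per_stage: shape (T, N); row t holds the log-likelihoods of the
        N equally weighted particles before moving from lambdas[t] to lambdas[t + 1].
    """
    lambdas = np.asarray(lambdas, dtype=float)
    loglik = np.asarray(loglik_per_stage, dtype=float)
    total = 0.0
    for t in range(len(lambdas) - 1):
        # incremental log weights: (lambda_{t+1} - lambda_t) * log L(theta_i)
        log_w = (lambdas[t + 1] - lambdas[t]) * loglik[t]
        # log of the average incremental weight at this stage
        total += log_mean_exp(log_w)
    return total
```

As a sanity check, if every particle has the same log-likelihood c at every stage, the estimate collapses to c, since the increments of λ sum to 1.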
Hi, I can tackle this; could someone assign the issue to me?
@myravian I may have a fix on https://github.com/ciguaran/pymc-experimental/tree/ciguaran_fix_smc_bj, could you try running your example from that branch? I'm also very interested to know what you are using SMC for; it would be great if it could become an example notebook on how to use it! Let me know.
Unfortunately I still have the same error message:
```
File "/local/home/vleboute/work/MULTIGRIS/mgris/sampling_smc_ciguaran.py", line 150, in sample_smc_blackjax
total_iterations, particles, diagnosis = inference_loop(
^^^^^^^^^^^^^^^
File "/local/home/vleboute/work/MULTIGRIS/mgris/sampling_smc_ciguaran.py", line 267, in inference_loop
n_iter, final_state, _, diagnosis = jax.lax.while_loop(
^^^^^^^^^^^^^^^^^^^
File "/local/home/vleboute/work/MULTIGRIS/mgris/sampling_smc_ciguaran.py", line 262, in one_step
state, info = kernel.step(subk, state)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/local/home/vleboute/miniconda3/lib/python3.11/site-packages/blackjax/smc/adaptive_tempered.py", line 167, in step_fn
return kernel(
^^^^^^^
File "/local/home/vleboute/miniconda3/lib/python3.11/site-packages/blackjax/smc/adaptive_tempered.py", line 101, in kernel
return tempered_kernel(rng_key, state, num_mcmc_steps, lmbda, mcmc_parameters)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/local/home/vleboute/miniconda3/lib/python3.11/site-packages/blackjax/smc/tempered.py", line 143, in kernel
smc_state, info = smc.base.step(
^^^^^^^^^^^^^^
File "/local/home/vleboute/miniconda3/lib/python3.11/site-packages/blackjax/smc/base.py", line 140, in step
particles, update_info = update_fn(keys, particles)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/local/home/vleboute/miniconda3/lib/python3.11/site-packages/blackjax/smc/tempered.py", line 131, in mcmc_kernel
state = mcmc_init_fn(position, tempered_logposterior_fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/local/home/vleboute/miniconda3/lib/python3.11/site-packages/blackjax/mcmc/hmc.py", line 89, in init
logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/local/home/vleboute/miniconda3/lib/python3.11/site-packages/blackjax/smc/tempered.py", line 126, in tempered_logposterior_fn
logprior = logprior_fn(position)
^^^^^^^^^^^^^^^^^^^^^
File "/local/home/vleboute/work/MULTIGRIS/mgris/sampling_smc_ciguaran.py", line 380, in logp_fn_wrap
return logp_fn(*particles)[0]
^^^^^^^^^^^^^^^^^^^
File "/tmp/tmpoc516ktt", line 29, in jax_funcified_fgraph
tensor_variable_13 = dimshuffle_1(d_simplex_)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/local/home/vleboute/miniconda3/lib/python3.11/site-packages/pytensor/link/jax/dispatch/elemwise.py", line 69, in dimshuffle
res = jnp.transpose(x, op.transposition)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/local/home/vleboute/miniconda3/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 681, in transpose
return lax.transpose(a, axes_)
^^^^^^^^^^^^^^^^^^^^^^^
TypeError: transpose permutation isn't a permutation of operand dimensions, got permutation (0,) for operand shape (1000, 1).
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
```
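The final `TypeError` says that a transpose with permutation `(0,)`, i.e. one written for a 1-D value such as the length-1 transformed Dirichlet `d_simplex_`, received the whole stacked batch of shape `(1000, 1)`. The NumPy sketch below only mimics that shape mismatch; `single_particle_term` is a hypothetical stand-in for a logp graph compiled for one particle, not the actual pymc-experimental code. It also shows why the `p.squeeze()` call quoted in the first bullet is fragile for `shape=(1,)` variables. NumPy raises `ValueError` where JAX raises `TypeError`.

```python
import numpy as np

N_PARTICLES = 1000


def single_particle_term(d_simplex_):
    """Hypothetical stand-in for a logp graph compiled for ONE particle.

    It contains an explicit 1-D transpose with permutation (0,), like the
    dimshuffle in the traceback, so it only accepts inputs of shape (k,).
    """
    return np.transpose(d_simplex_, (0,)).sum()


# All particles stacked along a leading axis, shape (1000, 1)
batch = np.zeros((N_PARTICLES, 1))

# Feeding the whole batch to the single-particle function fails
try:
    single_particle_term(batch)
    batch_failed = False
except ValueError:
    batch_failed = True

# Evaluating particle by particle works (in JAX, jax.vmap does this mapping)
per_particle = np.array([single_particle_term(p) for p in batch])

# .squeeze() drops the length-1 axis a shape=(1,) variable needs: () instead of (1,)
squeezed_single = np.zeros((1,)).squeeze()
```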
I don't have a straightforward, simple illustration of how I use SMC, but the gist is this: I ran an astrophysical code to compute predictions for several hundred thousand parameter sets, and, based on a set of observables, I try to infer the parameters. There are of course various issues, such as the regularity/completeness of the grid and the interpolation, but the main one is the complex, multi-modal posterior distributions that we expect. In all the proof-of-concept and validation tests we ran, SMC has been a great way to probe the prior space and to handle such difficult posteriors (provided the kernel parameters are well tuned). I'm by no means an expert in statistics and rely a lot on empirical knowledge, so I'm sure I'm not doing everything right though...!
Could you share a full Python file that reproduces the error? I've run the example you posted at the very beginning and it does work for me 🤔.
Here is the script:

```python
import pymc as pm
from sampling_smc_ciguaran import sample_smc_blackjax as sample_smc

with pm.Model() as model:
    c = pm.Normal("c", mu=10, sigma=10, shape=(1,))
    d = pm.Dirichlet("d", [1, 1])

    trace = sample_smc(
        n_particles=1000,
        kernel="HMC",
        inner_kernel_params={
            "step_size": 0.01,  # small values better
            "integration_steps": 20,
        },
        iterations_to_diagnose=10,
        target_essn=0.5,
        num_mcmc_steps=10,
    )
```

Maybe it has to do with the blackjax/jax versions (1.1.0/0.4.21 on my system).
Hi! So I was able to run the example you just shared by installing pymc-experimental from the branch:

```
pip install git+https://github.com/ciguaran/pymc-experimental@ciguaran_fix_smc_bj
```

Is it possible that you are still using pymc-experimental from master?
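To check which build is actually being imported, a small sketch using only the standard library can report each package's installed version and, when pip recorded it in the PEP 610 `direct_url.json` file, the URL and revision it was installed from (a `git+https://...@branch` install should show the branch name there). The helper name `describe_install` is hypothetical.

```python
import json
from importlib import metadata


def describe_install(package):
    """Return {"version", "source"} for an installed package, or None if missing.

    "source" is filled from direct_url.json (PEP 610) when pip wrote one,
    e.g. for installs from a git URL or a local path; otherwise it is None.
    """
    try:
        dist = metadata.distribution(package)
    except metadata.PackageNotFoundError:
        return None
    info = {"version": dist.version, "source": None}
    direct_url = dist.read_text("direct_url.json")  # None when the file is absent
    if direct_url:
        data = json.loads(direct_url)
        url = data.get("url", "")
        vcs = data.get("vcs_info") or {}
        rev = vcs.get("requested_revision") or vcs.get("commit_id")
        info["source"] = f"{url}@{rev}" if rev else url
    return info


for name in ("pymc-experimental", "blackjax", "jax"):
    print(name, describe_install(name))
```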
You're right, I was not using the proper versions. I just tested it and it seems to work fine, thanks for the modifications!