Incorrect results when sampling from the prior
While going through Statistical Rethinking I wanted to run a prior-predictive simulation, but the results did not match the textbook example (see below).
What's more, I played with some other synthetic examples, and they also give unintuitive results (see further down).
Examples
Example from Statistical Rethinking
Code
import seaborn as sns
import matplotlib.pyplot as plt
import jax

import mcx
from mcx import distributions as dist
from mcx import sample_joint


@mcx.model
def model():
    μ <~ dist.Normal(178, 20)
    σ <~ dist.Uniform(0, 50)
    h <~ dist.Normal(μ, σ)
    return h


rng_key = jax.random.PRNGKey(0)
prior_predictive = sample_joint(
    rng_key=rng_key,
    model=model,
    model_args=(),
    num_samples=10_000
)

fig, axes = plt.subplots(2, 2, figsize=(7, 5), dpi=128)
axes = axes.reshape(-1)

sns.kdeplot(prior_predictive["μ"], ax=axes[0])
sns.kdeplot(prior_predictive["σ"], ax=axes[1])
sns.kdeplot(prior_predictive["h"], ax=axes[2])

plt.tight_layout()
Result (plot omitted)
Expected (plot omitted)
Synthetic example 1
In this example I sample an offset from Uniform(0, 1), then sample from Uniform(12 - offset, 12 + offset).
So I expect the samples to be distributed in the range [11, 13], but I get samples in the range [-15, 15].
Code
import seaborn as sns
import matplotlib.pyplot as plt
import jax

import mcx
from mcx import distributions as dist
from mcx import sample_joint


@mcx.model
def example_1():
    center = 12
    offset <~ dist.Uniform(0, 1)
    low = (center - offset)
    high = (center + offset)
    outcome <~ dist.Uniform(low, high)


rng_key = jax.random.PRNGKey(0)
prior_predictive = sample_joint(
    rng_key=rng_key,
    model=example_1,
    model_args=(),
    num_samples=10_000
)

ax = sns.kdeplot(prior_predictive["outcome"])
ax.set_title("Outcome")
Result (plot omitted)
Synthetic example 2
This is the same example as above, but the center variable is passed as an argument instead of being hardcoded. The results are different, although still not in the range [11, 13].
Code
import seaborn as sns
import matplotlib.pyplot as plt
import jax

import mcx
from mcx import distributions as dist
from mcx import sample_joint


@mcx.model
def example_2(center):
    offset <~ dist.Uniform(0, 1)
    low = (center - offset)
    high = (center + offset)
    outcome <~ dist.Uniform(low, high)


rng_key = jax.random.PRNGKey(0)
prior_predictive = sample_joint(
    rng_key=rng_key,
    model=example_2,
    model_args=(12,),
    num_samples=10_000
)

ax = sns.kdeplot(prior_predictive["outcome"])
ax.set_title("Outcome")
Result (plot omitted)
Expectation
For examples 1 and 2, here's what I'd expect to get: samples concentrated around the center, inside the range [11, 13].
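For reference, a minimal sketch of the intended generative process written directly with jax.random (not mcx, and with each sampling site given its own split key); its kdeplot is concentrated inside [11, 13]:
import numpy as np
import seaborn as sns
import jax

rng_key = jax.random.PRNGKey(0)
key_offset, key_outcome = jax.random.split(rng_key)

# offset ~ Uniform(0, 1), with its own key
offset = jax.random.uniform(key_offset, shape=(10_000,))

# outcome ~ Uniform(12 - offset, 12 + offset), which always lies inside [11, 13]
u = jax.random.uniform(key_outcome, shape=(10_000,))
outcome = (12 - offset) + u * (2 * offset)

sns.kdeplot(np.asarray(outcome))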
Environment
Linux-5.8.0-44-generic-x86_64-with-glibc2.10
Python 3.8.5 (default, Sep 4 2020, 07:30:14) [GCC 7.3.0]
JAX 0.2.8
NetworkX 2.5
JAXlib 0.1.58
mcx 2a2b94801e68d94d86826863eeee80f0b84c390d
Hi @rlouf,
I've looked into this a bit more and identified two issues:
- In the mcx models, the same random key seems to be used for multiple distributions, giving incorrect results (see the sketch after this list).
- The subtraction Op and negation Op seem broken.
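To illustrate the first point, a minimal sketch in plain JAX (independent of mcx) showing why reusing a single key matters: two calls to a sampler with the same key return exactly the same draw, while splitting the key first gives independent draws.
import jax

key = jax.random.PRNGKey(0)

# Same key used twice: the two "samples" are identical
a = jax.random.normal(key)
b = jax.random.normal(key)
print(bool(a == b))  # True

# Splitting the key first gives independent draws
key_a, key_b = jax.random.split(key)
print(bool(jax.random.normal(key_a) == jax.random.normal(key_b)))  # False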
Please find the examples of the two issues in the notebook: https://gist.github.com/elanmart/9ab0ba21f282f6b24d972cbfb76b4578
Hope this is helpful
Hi @elanmart,
Thank you for taking the time to share this with me! Regarding what you identified:
- Indeed, I just noticed that recently. It is indeed problematic if you use the same distribution more than once in the model. This should be corrected soon.
- In what sense? Would you mind pasting the output of print(example_1.sample_joint_src) and print(example_2.sample_joint_src)?
Thanks for the answer! I was wondering how I could inspect the models; sample_joint_src indeed reveals what goes wrong!
The following model
@mcx.model
def example_2_mcx_v1():
    offset <~ dist.Uniform(0, 5)
    low = 12 - offset
    outcome <~ dist.Uniform(low, 12)
    return outcome
is transformed into
def example_2_mcx_v1_sample_forward(rng_key):
    offset = dist.Uniform(0, 5).sample(rng_key)
    low = offset - 12
    outcome = dist.Uniform(low, 12).sample(rng_key)
    forward_samples = {'offset': offset, 'outcome': outcome}
    return forward_samples
Notice how low = 12 - offset became low = offset - 12.
EDIT: The issue is not limited to constants. The arguments in a subtraction are switched to match the order in which they were defined, so with
A <~ ...
B <~ ...
the expression B - A becomes A - B,
and so the model here
@mcx.model
def example():
    A <~ dist.Normal(0, 1)
    B <~ dist.Normal(0, 2)
    μ = B - A
    Y <~ dist.Normal(μ, 1)
    return Y
becomes
def example_sample_forward(rng_key):
    B = dist.Normal(0, 2).sample(rng_key)
    A = dist.Normal(0, 1).sample(rng_key)
    μ = A - B
    Y = dist.Normal(μ, 1).sample(rng_key)
    forward_samples = {'A': A, 'B': B, 'Y': Y}
    return forward_samples
Ah, and also regarding point 1. (same rng_key used many times): is there any simple workaround I could use as a temporary solution, however hacky?
That's strange regarding A - B; I identified the problem 10 days ago and I thought I'd fixed it. Are you running the latest version (latest commit)?
Unfortunately there is no workaround for the rng_key issue, but I can try to push a fix next week. I'll make sure it works on these examples. In the meantime you can keep moving forward, checking the compiled source code each time there's something weird. You'd just have to re-run your notebooks once the fixes are made.
Now I see how convenient compiling to a Python function is for debugging 😄 Thank you for dealing with the teething problems here; it is really helpful for us.
OK, so my poetry.lock file indicated that I had the latest commit, but after a clean re-install the issue is indeed resolved...
I'm really sorry for generating noise :cry:
Do you want me to close this ticket and open a clean one for the rng_key topic? Out of curiosity -- what is the fix you envision there? Adding a _, key = jax.random.split(key) statement to the graph after each sample() call? Or is there a nicer solution?
Thank you for dealing with the teething problems
No worries, I would love to understand the compiler a bit better to be able to debug similar issues myself.
I'm really sorry for generating noise :cry:
No worries, you're really helpful :)
Do you want me to close this ticket and open a clean one for the rng_key topic?
Yes please! Leave this one open until we solve the issue completely.
Out of curiosity -- what is the fix you envision there? Adding a _, key = jax.random.split(key) statement to the graph after each sample() call? Or is there a nicer solution?
So that would be the quick and dirty solution. I think that I might instead generate as many keys as needed at the beginning of the function.
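For illustration only (a hypothetical sketch of that idea, not the actual planned compiler output), the forward sampler from the earlier example could split the incoming key once, up front, so each random variable gets its own subkey:
# hypothetical compiled output; assumes `import jax` and
# `from mcx import distributions as dist` as in the snippets above
def example_2_mcx_v1_sample_forward(rng_key):
    # split once at the top: one subkey per random variable
    key_offset, key_outcome = jax.random.split(rng_key, 2)
    offset = dist.Uniform(0, 5).sample(key_offset)
    low = 12 - offset
    outcome = dist.Uniform(low, 12).sample(key_outcome)
    forward_samples = {'offset': offset, 'outcome': outcome}
    return forward_samples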
No worries, I would love to understand the compiler a bit better to be able to debug similar issues myself.
Well, now you know that you can at least print the code generated by the compiler. It's a good starting point.
@elanmart - with regards to the compiler - I made some (not really organized) notes here, some of which I suppose are (or will soon be) invalid after the <~ operator is phased out. In any case maybe they would be helpful to you!
some of which I suppose are (or will soon be) invalid after the <~ operator is phased out.
Actually the general principle will stay exactly the same.
Thank you @tblazina ! This looks extremely useful, will go through it over the weekend!
Just to let you know, I'll make some time to work on this and the other issue once my NUTS PR is merged into BlackJAX (which means MCX will support NUTS). How is the implementation of the SR examples going?
Thanks for the update! Looking forward to the NUTS sampler as well.
I've decided to first go through the theory, and then make a second pass implementing the examples.
I've just finished the book, so I'm going back to the code, which hopefully should go faster now.
There were a few places in the book where some advanced Stan features were used. I'm a bit worried about those, but we'll see how it goes.
Great! If you remember which ones don't hesitate to open issues now.