pymc4
Error in transformed variables with manual model vectorization
There seems to be an error with transformed variables under manual model vectorization when the event shape of the variable is larger than 1. A minimal example:
```python
import numpy as np
import pymc4 as pm
import tensorflow as tf

# Three groups with different means, ten noisy observations each
means = np.random.random((3, 1)) * 5 + 5
noise = np.random.random((3, 10))
data = means + noise
data = data.astype('float32')

# We want to infer the means of the data:
def model():
    # One positive mean per group, stacked into an event of size 3
    means = yield pm.HalfNormal(name='means', loc=0, scale=10, event_stack=3)
    # Repeat each mean across its 10 observations -> shape (..., 3, 10)
    means = tf.repeat(tf.expand_dims(means, axis=-1), axis=-1, repeats=10)
    likelihood = yield pm.Normal(name='likeli', loc=means, scale=5, observed=data)

trace = pm.sample(model(), num_samples=50, burn_in=200, use_auto_batching=False, num_chains=2)
print(np.median(trace.posterior['model/means'], axis=(0, 1)))
```
which leads to the following shape error:
ValueError: Dimensions must be equal, but are 2 and 3 for '{{node add_2}} = AddV2[T=DT_FLOAT](add_1, mul)' with input shapes: [2], [2,3].
When `pm.HalfNormal` is replaced by `pm.Normal`, it works without problems.
If I understood the organization of the source code correctly, this error occurs because the correct number of event-shape dimensions is not passed to `inverse_log_det_jacobian` and `forward_log_det_jacobian` of TensorFlow Probability, for example exactly here: Somehow the `Transform` class should also have the number of dimensions of the event shape as an attribute, so that the log-determinant of the Jacobian can be computed correctly.
But perhaps I am using PyMC4 incorrectly and there is another way to specify the model...
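The shape mismatch described above can be illustrated without pymc4. The following is a minimal NumPy sketch, assuming the transform applied to the positive `means` behaves like an exp/log bijector. The unconstrained state has shape `(num_chains, 3)` and the event log-density is already reduced to one value per chain, shape `(2,)`, but a log-det-Jacobian computed as if `event_ndims=0` keeps shape `(2, 3)`. Adding the two fails for the same reason as the `[2]` vs `[2,3]` error above, while summing the Jacobian term over the event axis (`event_ndims=1`) gives a compatible `(2,)` result. `ExpTransform` and `half_normal_logpdf` are hypothetical helpers, not pymc4 or TFP code; in TFP the corresponding bijector methods take the number of event dimensions as their `event_ndims` argument.

```python
import numpy as np


class ExpTransform:
    """Hypothetical stand-in for a positivity transform (not the pymc4 or TFP API).

    forward maps unconstrained x to y = exp(x) > 0; log|dy/dx| = x per element.
    """

    def __init__(self, event_ndims=0):
        # Number of trailing axes forming one event; the Jacobian term must be
        # summed over exactly these axes.
        self.event_ndims = event_ndims

    def forward(self, x):
        return np.exp(x)

    def forward_log_det_jacobian(self, x):
        ldj = x  # elementwise log|d exp(x) / dx|
        if self.event_ndims == 0:
            return ldj
        return ldj.sum(axis=tuple(range(-self.event_ndims, 0)))


def half_normal_logpdf(y, scale):
    # Elementwise HalfNormal log-density for y >= 0
    return np.log(2.0) - 0.5 * np.log(2 * np.pi * scale**2) - y**2 / (2 * scale**2)


num_chains, event_size = 2, 3
rng = np.random.default_rng(0)
x = rng.normal(size=(num_chains, event_size))  # unconstrained state, one row per chain

# Event log-density, already reduced over the event axis -> shape (2,)
y = ExpTransform().forward(x)
logp = half_normal_logpdf(y, scale=10.0).sum(axis=-1)

wrong = ExpTransform(event_ndims=0).forward_log_det_jacobian(x)  # shape (2, 3)
right = ExpTransform(event_ndims=1).forward_log_det_jacobian(x)  # shape (2,)

try:
    logp + wrong  # (2,) + (2, 3): trailing dims 2 vs 3 do not broadcast
    mismatch_raised = False
except ValueError as err:
    mismatch_raised = True
    print("event_ndims=0:", err)

print("event_ndims=1:", (logp + right).shape)
```

The change-of-variables density of the unconstrained state is `logp + right`; storing `event_ndims` on the transform, as suggested, is one way to make that reduction happen over the right axes.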

Thanks for reporting this @jdehning. I can reproduce the exception you are running into and can confirm that it's coming from the gradient computation. I still have to investigate more before confirming that the culprits are `inverse_log_det_jacobian` and `forward_log_det_jacobian`, but what you posted was very helpful for pinpointing the cause so far. I'm not sure when I'll get a chance to fix this. Maybe @ferrine has some time to look into the problem a bit more.

Hmm, I can have a look after the weekend.

Looks like I failed to fix the error quickly. I've tried to figure out whether it comes from a wrong usage of `reinterpreted_batch_ndims=2` and similar arguments, but without success.
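As background for the `reinterpreted_batch_ndims` argument: in TFP, `tfd.Independent(distribution, reinterpreted_batch_ndims=n)` treats the last `n` batch dimensions as event dimensions, so its `log_prob` is summed over them. Below is a NumPy sketch of that reduction on an array shaped like `data` in the report (3 groups × 10 observations). `reinterpret_batch_ndims` is a hypothetical helper, and it is an assumption here that pymc4 forwards its keyword of the same name to `tfd.Independent`.

```python
import numpy as np


def normal_logpdf(x, loc, scale):
    # Elementwise Normal log-density
    return -0.5 * np.log(2 * np.pi * scale**2) - (x - loc) ** 2 / (2 * scale**2)


def reinterpret_batch_ndims(log_prob, ndims):
    """Hypothetical helper: sum log-probs over the last `ndims` axes,
    the reduction tfd.Independent applies to reinterpreted batch dimensions."""
    if ndims == 0:
        return log_prob
    return log_prob.sum(axis=tuple(range(-ndims, 0)))


rng = np.random.default_rng(0)
data = (rng.random((3, 10)) + 5).astype("float32")  # 3 groups x 10 observations
lp = normal_logpdf(data, loc=5.0, scale=5.0)  # per-element log-probs, shape (3, 10)

for n in (0, 1, 2):
    print(n, np.shape(reinterpret_batch_ndims(lp, n)))
# 0 (3, 10)
# 1 (3,)
# 2 ()
```

The value of `n` decides the shape of the log-density term, which must line up with every other term it is added to, including the log-det-Jacobian of transformed variables.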
I think I have some ideas on how to fix it. Probably next week I will have some time to look deeper into it and open a PR, if you are okay with it...