log_prob of blocked programs
Hello,
First, thanks for developing such an amazing package. I am a newbie to oryx and was playing around with its functionality. Perhaps naively, I had been attempting to evaluate the log_probs of blocked programs, as below:
from jax.random import split
from oryx.core import ppl
import tensorflow_probability.substrates.jax.distributions as tfd
def latent_normal(key):
    z_key,x_key= split(key)
    z=ppl.random_variable(tfd.Normal(0,1),name="z")(z_key)
    return ppl.random_variable(tfd.Normal(z,1e-1),name="x")(x_key)
blocked=ppl.block(latent_normal,names=["z"])
ppl.joint_log_prob(blocked)({"x":10})
However, it raises the following error:
ValueError                                Traceback (most recent call last)
Cell In[2], line 12
      8     return ppl.random_variable(tfd.Normal(z,1e-1),name="x")(x_key)
     11 blocked=ppl.block(latent_normal,names=["z"])
---> 12 ppl.joint_log_prob(blocked)({"x":10})
File /opt/conda/lib/python3.11/site-packages/oryx/core/interpreters/log_prob.py:71, in log_prob.
File /opt/conda/lib/python3.11/site-packages/oryx/core/interpreters/log_prob.py:128, in log_prob_jaxpr(jaxpr, constcells, flat_incells, flat_outcells)
    118 _, final_log_prob = propagate.propagate(
    119     InverseAndILDJ,
    120     log_prob_rules,
   (...)
    125     reducer=reducer,
    126     initial_state=0.)
    127 if final_log_prob is failed_log_prob:
--> 128   raise ValueError('Cannot compute log_prob of function.')
    129 return final_log_prob
ValueError: Cannot compute log_prob of function.
Am I missing something?
Thanks again
Very best,
Giovanni
Hi @jhn-nt. I'm not a contributor or maintainer, but I think you might be missing the following. If you look at this description of block, it says:
"The block transformation takes in a program and a sequence of names and returns a program that behaves identically except that in downstream transformations (likejoint_sample), the provided names are ignored."
When you run blocked=ppl.block(latent_normal,names=["z"]), you are telling downstream transformations (like joint_log_prob) to ignore the variable named 'z'. Because the variable 'x' depends on 'z', which you blocked, joint_log_prob cannot "see" the value of 'z'; without it, it cannot build the distribution over 'x' and compute the log_prob of 10 under that distribution. That's why you're seeing the error.
In general, I would guess you are unable to compute log probabilities of a program with respect to a variable that depends on variables you previously blocked. @sharadmv, let me know if I missed anything. I am not sure whether there is a principled way to catch this sort of error and display a more helpful error message; I may look into it. And @jhn-nt, let me know if you have any further questions! Hope that helps.
Thanks very much! This really helped.
Giovanni