targetdiff

Code error in q_v_posterior when computing q(v_t | v_{t-1})?

Status: Open. FeilongWuHaa opened this issue 1 year ago • 0 comments

Hi guanjq,

I think I found an error in the call `self.q_v_pred_one_timestep(log_vt, t, batch)` inside the `q_v_posterior` function. Isn't this term supposed to compute q(v_t | v_{t-1})?

Source code on GitHub:

```python
# atom type generative process
def q_v_posterior(self, log_v0, log_vt, t, batch):
    # q(vt-1 | vt, v0) = q(vt | vt-1, x0) * q(vt-1 | x0) / q(vt | x0)
    t_minus_1 = t - 1
    # Remove negative values, will not be used anyway for final decoder
    t_minus_1 = torch.where(t_minus_1 < 0, torch.zeros_like(t_minus_1), t_minus_1)
    log_qvt1_v0 = self.q_v_pred(log_v0, t_minus_1, batch)
    unnormed_logprobs = log_qvt1_v0 + self.q_v_pred_one_timestep(log_vt, t, batch)
    log_vt1_given_vt_v0 = unnormed_logprobs - torch.logsumexp(unnormed_logprobs, dim=-1, keepdim=True)
    return log_vt1_given_vt_v0
```
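The comment in the quoted code states Bayes' rule for the posterior. One way to test the question numerically is a small NumPy sketch that compares the expression used in `q_v_posterior` against an explicit Bayes computation. It assumes uniform-noise categorical transitions, Q_t = (1 - beta_t) I + (beta_t / K) 1 1^T (as in multinomial diffusion), and works in probability space rather than log space. `K`, `BETAS`, `step_matrix`, `cumulative_matrix`, `posterior_as_in_code`, and `posterior_bayes` are hypothetical toy stand-ins, not targetdiff's API.

```python
import numpy as np

K = 4  # toy number of atom types (assumption)
BETAS = [None, 0.1, 0.2, 0.3]  # hypothetical per-step noise beta_t for t = 1..3

def step_matrix(t):
    # Assumed uniform-noise transition: Q_t[j, k] = q(v_t = k | v_{t-1} = j)
    return (1 - BETAS[t]) * np.eye(K) + BETAS[t] / K * np.ones((K, K))

def cumulative_matrix(t):
    # Qbar_t = Q_1 @ ... @ Q_t, so Qbar_t[i, j] = q(v_t = j | v_0 = i); Qbar_0 = I
    q = np.eye(K)
    for s in range(1, t + 1):
        q = q @ step_matrix(s)
    return q

def posterior_as_in_code(v0, vt, t):
    # Mirrors q_v_posterior: q(v_{t-1} | v_0) times one forward step applied to onehot(v_t)
    prior = cumulative_matrix(t - 1)[v0]               # q(v_{t-1} = j | v_0), indexed by j
    one_step_of_vt = np.eye(K)[vt] @ step_matrix(t)    # analogue of q_v_pred_one_timestep(log_vt, t)
    with np.errstate(divide="ignore"):                 # log(0) -> -inf is expected when t - 1 == 0
        unnormed = np.log(prior) + np.log(one_step_of_vt)
    m = unnormed.max()
    log_norm = m + np.log(np.exp(unnormed - m).sum())  # log-space normalizer, like logsumexp
    return np.exp(unnormed - log_norm)

def posterior_bayes(v0, vt, t):
    # Explicit Bayes: q(v_{t-1}=j | v_t, v_0) ∝ q(v_t | v_{t-1}=j) * q(v_{t-1}=j | v_0)
    p = cumulative_matrix(t - 1)[v0] * step_matrix(t)[:, vt]  # column vt = likelihood over j
    return p / p.sum()

# Compare the two for every toy combination of t, v_0, v_t
for t in (1, 2, 3):
    for v0 in range(K):
        for vt in range(K):
            assert np.allclose(posterior_as_in_code(v0, vt, t), posterior_bayes(v0, vt, t))
print("matches explicit Bayes for all toy cases")
```

Under this assumption the two agree because Q_t is symmetric: pushing onehot(v_t) forward one step gives row v_t of Q_t, which equals column v_t, i.e. q(v_t | v_{t-1} = j) as a function of j. For a non-symmetric transition matrix this shortcut would not hold in general.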

Is there an error in `unnormed_logprobs = log_qvt1_v0 + self.q_v_pred_one_timestep(log_vt, t, batch)`? Should it be changed to `unnormed_logprobs = log_qvt1_v0 + self.q_v_pred_one_timestep(log_qvt1_v0, t, batch)`?
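One observation that may help evaluate the proposed change: in the quoted code, `log_qvt1_v0` is computed only from `log_v0` and `t_minus_1`, so the suggested expression would no longer use `log_vt` at all. A toy sketch under the same assumed uniform-noise step, with made-up probabilities and hypothetical names (`one_step`, `prior`, `current_posterior`, `proposed_posterior`), illustrates this:

```python
import numpy as np

K = 4         # toy number of atom types (assumption)
BETA_T = 0.2  # hypothetical noise level for step t

def one_step(p):
    # Assumed uniform-noise forward step in probability space,
    # a stand-in for q_v_pred_one_timestep: p -> (1 - beta_t) * p + beta_t / K
    return (1 - BETA_T) * p + BETA_T / K

def normalize(p):
    return p / p.sum()

prior = np.array([0.7, 0.1, 0.1, 0.1])  # toy q(v_{t-1} | v_0)

def current_posterior(vt):
    # Current line: forward step applied to onehot(v_t)
    return normalize(prior * one_step(np.eye(K)[vt]))

def proposed_posterior(vt):
    # Proposed line: forward step applied to q(v_{t-1} | v_0); vt is never used
    return normalize(prior * one_step(prior))

for vt in range(K):
    print(vt, current_posterior(vt).round(3), proposed_posterior(vt).round(3))
```

In this toy example the proposed form returns the same distribution over v_{t-1} for every observed v_t, whereas the current form shifts probability mass toward the observed type.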

Best,

FeilongWuHaa (Sep 18 '24, 08:09)