Possible code error in q_v_posterior when computing q(v_t | v_{t-1})?
Hi guanjq,
I think I may have found an error in the call self.q_v_pred_one_timestep(log_vt, t, batch) inside the q_v_posterior function, which is supposed to compute q(v_t | v_{t-1}).
Source code on GitHub:

    # atom type generative process
    def q_v_posterior(self, log_v0, log_vt, t, batch):
        # q(vt-1 | vt, v0) = q(vt | vt-1, x0) * q(vt-1 | x0) / q(vt | x0)
        t_minus_1 = t - 1
        # Remove negative values, will not be used anyway for final decoder
        t_minus_1 = torch.where(t_minus_1 < 0, torch.zeros_like(t_minus_1), t_minus_1)
        log_qvt1_v0 = self.q_v_pred(log_v0, t_minus_1, batch)
        unnormed_logprobs = log_qvt1_v0 + self.q_v_pred_one_timestep(log_vt, t, batch)
        log_vt1_given_vt_v0 = unnormed_logprobs - torch.logsumexp(unnormed_logprobs, dim=-1, keepdim=True)
        return log_vt1_given_vt_v0

Is the line "unnormed_logprobs = log_qvt1_v0 + self.q_v_pred_one_timestep(log_vt, t, batch)" wrong? If so, should it be changed to "unnormed_logprobs = log_qvt1_v0 + self.q_v_pred_one_timestep(log_qvt1_v0, t, batch)"?
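One way to compare the two variants is a small numerical check against an explicit Bayes-rule posterior. The sketch below is not the repository's code. It works in probability space rather than log space, and it assumes that q_v_pred_one_timestep applies the uniform-noise categorical step alpha_t * v + (1 - alpha_t) / K, which is not visible in the snippet quoted above. K, alpha_t, alpha_bar_prev and all helper names are hypothetical and chosen only for illustration.

```python
# Assumed setup (hypothetical values, not taken from the repository).
K = 4                  # number of atom types
alpha_t = 0.9          # one-step keep probability at step t
alpha_bar_prev = 0.6   # cumulative keep probability up to step t-1


def one_hot(i, k=K):
    return [1.0 if j == i else 0.0 for j in range(k)]


def one_step(p, alpha, k=K):
    # Assumed form of the noising step: alpha * p + (1 - alpha) / K.
    return [alpha * pj + (1.0 - alpha) / k for pj in p]


def normalize(p):
    s = sum(p)
    return [x / s for x in p]


def max_abs_diff(a, b):
    return max(abs(x - y) for x, y in zip(a, b))


def posterior_bayes(v0, vt):
    # Explicit Bayes rule: for each candidate j = v_{t-1},
    # multiply q(v_t | v_{t-1}=j) by q(v_{t-1}=j | v_0), then normalize.
    prior = one_step(one_hot(v0), alpha_bar_prev)
    lik = [one_step(one_hot(j), alpha_t)[vt] for j in range(K)]
    return normalize([l * p for l, p in zip(lik, prior)])


def posterior_with_vt(v0, vt):
    # Mirrors the quoted line: the one-step function is applied to v_t.
    prior = one_step(one_hot(v0), alpha_bar_prev)
    term = one_step(one_hot(vt), alpha_t)
    return normalize([a * b for a, b in zip(term, prior)])


def posterior_with_prior(v0, vt):
    # Mirrors the proposed change: the one-step function is applied to
    # q(v_{t-1} | v_0). Note that vt is never used here.
    prior = one_step(one_hot(v0), alpha_bar_prev)
    term = one_step(prior, alpha_t)
    return normalize([a * b for a, b in zip(term, prior)])


for v0 in range(K):
    for vt in range(K):
        d_vt = max_abs_diff(posterior_bayes(v0, vt), posterior_with_vt(v0, vt))
        d_prior = max_abs_diff(posterior_bayes(v0, vt), posterior_with_prior(v0, vt))
        print(v0, vt, d_vt, d_prior)
```

Under this assumption the one-step transition is symmetric in v_t and v_{t-1}, so applying it to v_t gives q(v_t | v_{t-1}) as a function of v_{t-1}, and the first difference printed should be zero. The variant that passes log_qvt1_v0 does not depend on v_t at all in this sketch. If the real q_v_pred_one_timestep differs from the assumed form, the check would need to be adapted.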
Best,