MolDiff
MolDiff copied to clipboard
posterior calculation
Hi, I have a question regarding the posterior sampling code. In my understanding, in the posterior sampling process (during both training and generation),
q(x_{t-1}|x_{t}, x_0)
x_0 is the predicted type (provided as a log-probability during generation) and x_{t} is the current state x_{t} (a type sampled from the previously predicted probability).
However, it seems that the code in MolDiff takes x_t as the probability obtained from the previous generation step, rather than as a sampled type. Could you please clarify whether there are any additional processing steps that I may have misunderstood? Thank you.
log_node_type = self.node_transition.q_v_posterior(log_node_recon, log_node_type, time_step, batch_node, v0_prob=True)
node_type_prev = log_sample_categorical(log_node_type)
def q_v_posterior(self, log_v0, log_vt, t, batch, v0_prob):
    """Compute the log posterior q(v_{t-1} | v_t, v_0) of the discrete diffusion.

    Uses the factorization
        q(v_{t-1} | v_t, v_0) ∝ q(v_t | v_{t-1}, v_0) * q(v_{t-1} | v_0),
    with both factors multiplied in log space and renormalized over the
    class dimension.

    Args:
        log_v0: log-probabilities (or log one-hot) for the predicted v_0,
            shape (num_nodes, N) or (num_nodes, N, N).
        log_vt: log-probabilities of the current state v_t, same shape
            family as log_v0.
            NOTE(review): a full probability vector is used here rather
            than a sampled one-hot v_t — fact1 mixes the rows of the
            transposed one-step transition matrix weighted by exp(log_vt).
        t: per-graph time steps.
        batch: node-to-graph assignment used to broadcast t to nodes.
        v0_prob: if True, log_v0 carries a full class distribution;
            if False, it is collapsed to its argmax class (one-hot).

    Returns:
        Log-probabilities of v_{t-1}, same shape as log_v0. Entries with
        t == 0 return the predicted log_v0 unchanged (final decoder step).
    """
    # Clamp t-1 at 0: the clamped entries are overwritten by the t == 0
    # branch at the end, so the placeholder value is never actually used.
    t_minus_1 = (t - 1).clamp(min=0)

    # fact1 = q(v_t | v_{t-1}) viewed as a function of v_{t-1}: rows of
    # the transposed one-step transition matrix, weighted by p(v_t).
    fact1 = extract(self.transpopse_q_onestep_mats, t, batch, ndim=1)
    fact1 = torch.einsum('bj,bjk->bk', torch.exp(log_vt), fact1)  # (batch, N)

    # fact2 = q(v_{t-1} | v_0) from the cumulative transition matrices.
    fact2 = extract(self.q_mats, t_minus_1, batch, ndim=1)  # (batch, N, N)
    if v0_prob:
        # log_v0 contains the full predicted distribution over classes.
        fact2 = torch.einsum('bj,bjk->bk', torch.exp(log_v0), fact2)  # (batch, N)
    else:
        # log_v0 is collapsed to its argmax class (treated as one-hot).
        class_v0 = log_v0.argmax(dim=-1)
        fact2 = fact2[torch.arange(len(class_v0)), class_v0]

    # Broadcast the per-graph time step to nodes with matching rank so it
    # can be compared element-wise against the output below.
    ndim = log_v0.ndim
    if ndim == 2:
        t_expand = t[batch].unsqueeze(-1)
    elif ndim == 3:
        t_expand = t[batch].unsqueeze(-1).unsqueeze(-1)
    else:
        raise NotImplementedError('ndim not supported')

    # Multiply the two factors in log space (eps + clamp for numerical
    # stability), then normalize over the class dimension.
    out = torch.log(fact1 + self.eps).clamp_min(-32.) \
        + torch.log(fact2 + self.eps).clamp_min(-32.)
    out = out - torch.logsumexp(out, dim=-1, keepdim=True)

    # At t == 0 there is no earlier step: return the predicted v_0 as-is.
    out_t0 = log_v0
    out = torch.where(t_expand == 0, out_t0, out)
    return out