MolDiff icon indicating copy to clipboard operation
MolDiff copied to clipboard

posterior calculation

Open oneoftwo opened this issue 4 months ago • 0 comments

Hi, I have a question regarding the posterior sampling code. In my understanding, in the posterior sampling process (during both training and generation),

q(x_{t-1}|x_{t}, x_0)

x_0 is the predicted type (provided as a log-probability during generation) and x_{t} is the current state — a type sampled from the previously predicted probability.

However, it seems that the code in MolDiff takes x_t as the probability obtained from the previous generation step, rather than a sampled type. Could you please clarify whether there are any different processing steps that I may have misunderstood? Thank you.

log_node_type = self.node_transition.q_v_posterior(log_node_recon, log_node_type, time_step, batch_node, v0_prob=True)
node_type_prev = log_sample_categorical(log_node_type)

def q_v_posterior(self, log_v0, log_vt, t, batch, v0_prob):
    """Compute the log of the categorical diffusion posterior q(v_{t-1} | v_t, v_0).

    Uses the identity q(v_{t-1} | v_t, v_0) ∝ q(v_t | v_{t-1}, v_0) * q(v_{t-1} | v_0),
    multiplying the two factors in log space and renormalizing with logsumexp.

    Args:
        log_v0: log-probabilities (or log of a one-hot, see `v0_prob`) of the
            predicted clean state v_0; shape (batch, N) or (batch, ..., N).
        log_vt: log-probabilities of the current noisy state v_t.
        t: per-graph timestep indices.
        batch: node-to-graph assignment used to broadcast `t` to nodes.
        v0_prob: if True, `log_v0` carries full probability information and is
            mixed via an expectation; if False, it is collapsed to its argmax
            class before indexing the transition matrices.

    Returns:
        Normalized log-probabilities of v_{t-1}; at t == 0 the prediction
        `log_v0` is returned unchanged.
    """
    # Clamp t-1 at zero: the negative entries are masked out by the
    # t == 0 branch at the end and never influence the result.
    t_prev = t - 1
    t_prev = torch.where(t_prev < 0, torch.zeros_like(t_prev), t_prev)

    # Factor 1: q(v_t | v_{t-1}) read backwards through the transposed
    # one-step matrices, as an expectation over the v_t distribution.
    # NOTE(review): attribute name `transpopse_q_onestep_mats` looks like a
    # typo of "transpose" carried over from the class definition — confirm there.
    prior_term = extract(self.transpopse_q_onestep_mats, t, batch, ndim=1)
    prior_term = torch.einsum('bj,bjk->bk', torch.exp(log_vt), prior_term)  # (batch, N)

    # Factor 2: q(v_{t-1} | v_0) from the cumulative transition matrices.
    x0_term = extract(self.q_mats, t_prev, batch, ndim=1)  # (batch, N, N)
    if v0_prob:
        # log_v0 carries probability information: take the expectation.
        x0_term = torch.einsum('bj,bjk->bk', torch.exp(log_v0), x0_term)  # (batch, N)
    else:
        # log_v0 is effectively one-hot: index the row of its argmax class.
        v0_idx = log_v0.argmax(dim=-1)
        x0_term = x0_term[torch.arange(len(v0_idx)), v0_idx]

    # Broadcast t to match the rank of log_v0 for the final masking step.
    rank = log_v0.ndim
    if rank == 2:
        t_bc = t[batch].unsqueeze(-1)
    elif rank == 3:
        t_bc = t[batch].unsqueeze(-1).unsqueeze(-1)
    else:
        raise NotImplementedError('ndim not supported')

    # Multiply factors in log space (clamped for numerical safety),
    # then renormalize over the class dimension.
    log_posterior = (
        torch.log(prior_term + self.eps).clamp_min(-32.)
        + torch.log(x0_term + self.eps).clamp_min(-32.)
    )
    log_posterior = log_posterior - torch.logsumexp(log_posterior, dim=-1, keepdim=True)

    # At t == 0 there is no earlier step: return the v_0 prediction directly.
    return torch.where(t_bc == 0, log_v0, log_posterior)

oneoftwo avatar Feb 29 '24 06:02 oneoftwo