DiGress
Reporting KL divergence loss for training step
Thank you for releasing the code. I am using a custom dataset of 10k graphs, and I updated the code to report the KL divergence during training so I can check for overfitting on the smaller dataset. While PosMSE looks fine, E_kl and X_kl are always NaN on the training samples. Could you please tell me if there is something wrong with my approach?
The training metrics are registered as:
self.train_metrics = torchmetrics.MetricCollection([custom_metrics.PosMSE(), custom_metrics.XKl(), custom_metrics.EKl()])
In my training_step, I invoke
nll, log_dict = self.compute_train_nll_loss(pred, z_t, clean_data=dense_data)
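and then log the returned dictionary through Lightning. Roughly, the surrounding training_step looks like this (simplified sketch; the forward pass that produces dense_data, z_t and pred is unchanged from the original training step):
def training_step(self, data, i):
    # ... original forward pass producing dense_data, z_t and pred (omitted here, unchanged) ...
    nll, log_dict = self.compute_train_nll_loss(pred, z_t, clean_data=dense_data)
    # Log the KL/NLL terms returned by the method below, per step and per epoch
    self.log_dict(log_dict, on_step=True, on_epoch=True)
    return {'loss': nll}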
Finally, the method definition is
def compute_train_nll_loss(self, pred, z_t, clean_data):
    node_mask = z_t.node_mask
    t_int = z_t.t_int
    s_int = t_int - 1
    logger_metric = self.train_metrics

    # 1. Log-probability of the number of nodes N under the node-count distribution
    N = node_mask.sum(1).long()
    log_pN = self.node_dist.log_prob(N)

    # 2. The KL between q(z_T | x) and p(z_T) = Uniform(1/num_classes). Should be close to zero.
    kl_prior = self.kl_prior(clean_data, node_mask)

    # 3. Diffusion loss, summed over all timesteps
    loss_all_t = self.compute_Lt(clean_data, pred, z_t, s_int, node_mask, logger_metric)

    # Combine terms
    nlls = -log_pN + kl_prior + loss_all_t

    # Update the NLL metric object and return the batch NLL
    nll = self.train_nll(nlls)  # Average over the batch
    log_dict = {"train kl prior": kl_prior.mean(),
                "Estimator loss terms": loss_all_t.mean(),
                "log_pn": log_pN.mean(),
                'train_nll': nll}
    return nll, log_dict
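As a sanity check outside the model, I also looked at how a plain categorical KL turns into NaN in torch. This is only a toy snippet with made-up distributions p and q, not the XKl/EKl metric code:
import torch

# Toy example, not the DiGress metric: a class with exact zero probability
# makes the naive per-class KL term 0 * log(0) = nan.
p = torch.tensor([0.0, 0.6, 0.4])   # "true" categorical distribution
q = torch.tensor([0.2, 0.5, 0.3])   # "predicted" categorical distribution

kl_terms = p * (p / q).log()
print(kl_terms)                      # tensor([nan, 0.1094, 0.1151])

# Masking out the zero-probability classes keeps the sum finite.
safe = torch.where(p > 0, kl_terms, torch.zeros_like(kl_terms))
print(safe.sum())                    # tensor(0.2245)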
Any help would be highly appreciated.
Best, Chinmay