GON
Where can I obtain the validation reconstruction loss?
Hi,
In the paper, I see there is a comparison between different models. Where can I find the corresponding code that calculates the validation reconstruction loss? Thank you!
I am not an expert on this model, but here is my take. I think you can calculate it with the following line of code:
((g - x)**2).sum() / x.shape[0]
where x is the ground-truth batch and g is the output of the model after feeding the negative gradient as the latent. Dividing by x.shape[0] gives the summed squared error per sample. I use the following piece of code to calculate the summed squared error and the mean squared error on the test set.
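To make the normalization concrete, here is a small pure-Python sketch on toy numbers of my own (not from the paper or the repository). The helper names `sse_per_sample` and `mse_per_element` are hypothetical. It only illustrates how dividing by the batch size differs from dividing by the total number of elements.

```python
def squared_error_total(g, x):
    """Sum of squared differences over every element of the batch."""
    return sum(
        (gi - xi) ** 2
        for g_row, x_row in zip(g, x)
        for gi, xi in zip(g_row, x_row)
    )


def sse_per_sample(g, x):
    """Total squared error divided by the number of samples (batch size)."""
    return squared_error_total(g, x) / len(x)


def mse_per_element(g, x):
    """Total squared error divided by the number of elements in the batch."""
    n_elements = sum(len(row) for row in x)
    return squared_error_total(g, x) / n_elements


# Toy batch: 2 samples with 2 values each (illustrative numbers only)
x = [[1.0, 2.0], [3.0, 4.0]]  # ground truth
g = [[1.0, 1.0], [2.0, 2.0]]  # reconstruction

print(sse_per_sample(g, x))   # 3.0
print(mse_per_element(g, x))  # 1.5
```

The two numbers differ by a constant factor (the number of elements per sample), so when comparing against values reported elsewhere it is worth checking which normalization was used.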
import torch  # F, nz, device and tst_loader are assumed to come from the training script

epoch_mse_loss_tst = 0.0  # Mean Squared Error, initially zero
epoch_sse_loss_tst = 0.0  # Sum Squared Error, initially zero
for j, (x, t) in enumerate(tst_loader):  # Run over the test set
    x = x.to(device)
    # Use the actual batch size so a smaller final batch does not cause a shape mismatch
    z = torch.zeros(x.shape[0], nz, 1, 1).to(device).requires_grad_()
    g = F(z)
    L_inner = ((g - x)**2).sum(1).mean()
    grad = torch.autograd.grad(L_inner, [z], create_graph=True, retain_graph=True)[0]
    z = (-grad)  # The negative gradient becomes the latent
    g = F(z)
    L_outer = ((g - x)**2).sum(1).mean()  # Squared error summed over dim 1, averaged over the remaining dims
    sse = ((g - x)**2).sum() / x.shape[0]  # Summed squared error per sample
    epoch_mse_loss_tst += L_outer.item()  # Update epoch MSE loss
    epoch_sse_loss_tst += sse.item()  # Update epoch SSE loss
for w in F.parameters():  # Zero any grads after inference (they do not contribute to training)
    if w.grad is not None:  # .grad is None if backward() was never called
        w.grad.data.zero_()
# enumerate starts at 0, so the number of batches is len(tst_loader), not j
epoch_sse_loss_tst = epoch_sse_loss_tst / len(tst_loader)
epoch_mse_loss_tst = epoch_mse_loss_tst / len(tst_loader)
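A note on turning per-batch losses into one epoch value: the divisor should be the number of batches, not the last index returned by `enumerate` (which is one less). If the final batch is smaller than the others, a sample-weighted average is slightly more accurate. The sketch below uses made-up numbers, and the helper name `epoch_average` is hypothetical.

```python
def epoch_average(batch_means, batch_sizes):
    """Sample-weighted average of per-batch mean losses."""
    total = sum(m * n for m, n in zip(batch_means, batch_sizes))
    return total / sum(batch_sizes)


# Illustrative per-batch mean losses and batch sizes (last batch is smaller)
batch_means = [2.0, 4.0, 6.0]
batch_sizes = [4, 4, 2]

# Dividing by the last enumerate index (number of batches minus one) overestimates
last_index = len(batch_means) - 1
off_by_one = sum(batch_means) / last_index

# Dividing by the number of batches treats every batch equally
unweighted = sum(batch_means) / len(batch_means)

# Weighting by batch size accounts for the smaller final batch
weighted = epoch_average(batch_means, batch_sizes)

print(off_by_one, unweighted, weighted)
```

With equal batch sizes the unweighted and weighted values coincide, so the weighting only matters when the test set size is not a multiple of the batch size.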
I hope this helps.