
Where can I obtain the validation reconstruction loss?

Open Nianzhen-GU opened this issue 3 years ago • 1 comments

Hi,

In the paper, I see there is a comparison between different models. Where can I find the corresponding code that calculates the validation reconstruction loss? Thank you!

[Screenshot: model comparison table from the paper]

Nianzhen-GU avatar Dec 06 '21 16:12 Nianzhen-GU

I am not an expert on this model, but I will offer my own take. I think you can calculate it with the following line of code:

((g - x)**2).sum() / x.shape[0]

where x is the ground-truth batch and g is the output of the model (after feeding the negative gradient back in as the latent). Building on that, I use the following piece of code to calculate the test-set summed squared error (SSE) and mean squared error (MSE):

epoch_mse_loss_tst = 0 # Mean Squared Error accumulator
epoch_sse_loss_tst = 0 # Summed Squared Error accumulator
num_batches = 0
for j, (x, t) in enumerate(tst_loader): # Run over the test set
    x = x.to(device)
    # Use x.shape[0] rather than batch_size so a smaller final batch still works
    z = torch.zeros(x.shape[0], nz, 1, 1).to(device).requires_grad_()
    g = F(z)
    L_inner = ((g - x)**2).sum(1).mean()
    # create_graph/retain_graph are only needed when training through the inference step
    grad = torch.autograd.grad(L_inner, [z], create_graph=True, retain_graph=True)[0]
    z = -grad # the negative gradient becomes the latent

    g = F(z)
    L_outer = ((g - x)**2).sum(1).mean() # Batch mean squared error
    sse = ((g - x)**2).sum() / x.shape[0] # Batch summed squared error (per sample)
    epoch_mse_loss_tst += L_outer.item() # Update epoch MSE loss
    epoch_sse_loss_tst += sse.item() # Update epoch SSE loss
    num_batches += 1
    F.zero_grad() # Clear any accumulated parameter grads (inference should not contribute to training)

# Divide by the number of batches (the original code divided by j, which
# undercounts by one since enumerate starts at 0)
epoch_sse_loss_tst = epoch_sse_loss_tst / num_batches
epoch_mse_loss_tst = epoch_mse_loss_tst / num_batches
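Since the snippet above relies on variables from the repository (F, tst_loader, device, nz), here is a self-contained toy sketch of the same evaluation loop, with a made-up linear decoder and random data standing in for the real GON model and test set (all names here are illustrative assumptions, not the repository's actual code):

```python
import torch

torch.manual_seed(0)
nz, C, H, W = 8, 1, 4, 4
batch_size = 4

# Stand-in decoder: maps a latent vector to a (C, H, W) "image".
lin = torch.nn.Linear(nz, C * H * W)
def F_toy(z):
    return lin(z.view(z.shape[0], -1)).view(-1, C, H, W)

# Stand-in test loader: two batches of random images.
tst_loader = [torch.randn(batch_size, C, H, W) for _ in range(2)]

epoch_mse, epoch_sse = 0.0, 0.0
for x in tst_loader:
    # GON inference: start from a zero latent and take one gradient step.
    z = torch.zeros(x.shape[0], nz, 1, 1, requires_grad=True)
    g = F_toy(z)
    L_inner = ((g - x) ** 2).sum(1).mean()
    grad = torch.autograd.grad(L_inner, [z])[0]
    z = -grad  # the negative gradient is used as the latent

    # Decode the inferred latent and accumulate both losses.
    with torch.no_grad():
        g = F_toy(z)
        epoch_mse += ((g - x) ** 2).sum(1).mean().item()
        epoch_sse += (((g - x) ** 2).sum() / x.shape[0]).item()

num_batches = len(tst_loader)
print("MSE:", epoch_mse / num_batches, "SSE:", epoch_sse / num_batches)
```

With C = 1 and 4×4 images, the per-sample SSE is exactly H·W = 16 times the MSE here, which is a quick sanity check that both accumulators are wired up consistently.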

I hope it is helpful.

BariscanBozkurt avatar Dec 14 '21 09:12 BariscanBozkurt