TimeEval-algorithms
(lstm_vae) Loss function arguments don't match in training
Thank you for this wonderful and extensive research and for sharing it publicly. However, I have encountered an issue in how the loss function of the LSTM VAE is calculated.
Here you wrote:
loss = self.loss_function(x, logvar, mu, logvar, 'mean')
Meanwhile, here:
def loss_function(self, x, x_hat, mean, log_var, reduction_type):
The input arguments don't match up. Is this intended?
Thanks for checking out our research and code.
Do I understand you correctly that you are confused because we use logvar twice as an argument to the loss function? The number of arguments matches and most argument names match as well:
self.loss_function(
x=x,
x_hat=logvar,
mean=mu,
log_var=logvar,
reduction_type='mean'
)
Yes.
In training, the loss function is called with these arguments:
loss = self.loss_function(x, logvar, mu, logvar, 'mean')
The loss function is defined as:
def loss_function(self, x, x_hat, mean, log_var, reduction_type):
reproduction_loss = nn.functional.mse_loss(x_hat, x, reduction=reduction_type)
KLD = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
return KLD + reproduction_loss
The MSE is then calculated as the reconstruction error between x and x_hat, but logvar is passed as x_hat. Is that right?
Please correct me if I'm wrong. Thank you.
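To make the concern concrete, here is a minimal plain-Python mirror of the quoted `loss_function` with `reduction='mean'`. It stands in for `nn.functional.mse_loss` and the `torch.sum` KL line so it runs without PyTorch. The input values, and treating `x_hat` as the decoder's reconstruction, are hypothetical and only for illustration. The KL term is identical in both calls; only the reconstruction term changes when `logvar` occupies the `x_hat` slot.

```python
import math


def loss_function(x, x_hat, mean, log_var):
    """Plain-Python sketch of the quoted loss with reduction='mean'."""
    if len(x) != len(x_hat):
        raise ValueError("x and x_hat must have the same length")
    # Reconstruction term: mean squared error between x_hat and x.
    reproduction_loss = sum((a - b) ** 2 for a, b in zip(x_hat, x)) / len(x)
    # KL term, summed over latent dimensions as in the torch.sum(...) line.
    kld = -0.5 * sum(1 + lv - m ** 2 - math.exp(lv) for m, lv in zip(mean, log_var))
    return kld + reproduction_loss


# Hypothetical values for illustration only.
x = [1.0, 2.0]
x_hat = [1.0, 2.0]   # assumed decoder output: a perfect reconstruction
mu = [1.0, 0.0]
logvar = [0.0, 0.0]

intended = loss_function(x, x_hat, mu, logvar)    # KL term only
as_called = loss_function(x, logvar, mu, logvar)  # KL + MSE(logvar, x)
print(intended, as_called)
```

With a perfect reconstruction the loss reduces to the KL term (0.5). When `logvar` is passed as `x_hat`, an extra 2.5 appears, which measures the distance between `x` and `logvar` rather than between `x` and a reconstruction.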
Thank you for the clarification.
I am not familiar enough with this implementation to judge this. @wenig, can you take a look?