
(lstm_vae) Loss function arguments don't match in training

Open Vhunon opened this issue 2 years ago • 3 comments

Thank you for this wonderful and extensive research and for sharing it publicly. However, I have encountered an issue in the loss-function call of the LSTM VAE.

Here you wrote loss = self.loss_function(x, logvar, mu, logvar, 'mean'),

meanwhile here the function is defined as def loss_function(self, x, x_hat, mean, log_var, reduction_type). The input arguments don't match up, or is this intended?

Vhunon avatar Jul 27 '23 18:07 Vhunon

Thanks for checking out our research and code.

Do I understand you correctly that you are confused because we use logvar twice as an argument to the loss function? The number of arguments matches and most argument names match as well:

self.loss_function(
  x=x,
  x_hat=logvar,
  mean=mu,
  log_var=logvar,
  reduction_type='mean'
)

SebastianSchmidl avatar Jul 28 '23 07:07 SebastianSchmidl

Yes.

In training, the loss function is called with these arguments: loss = self.loss_function(x, logvar, mu, logvar, 'mean')

The loss function is defined as:

def loss_function(self, x, x_hat, mean, log_var, reduction_type):
    reproduction_loss = nn.functional.mse_loss(x_hat, x, reduction=reduction_type)
    KLD = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
    return KLD + reproduction_loss

The MSE is then calculated as the reconstruction error between x and x_hat, but logvar is passed as the argument for x_hat?

Please correct me if I'm wrong. Thank you.

Vhunon avatar Jul 28 '23 12:07 Vhunon
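To make the concern above concrete, here is a minimal, self-contained sketch of the loss function from the thread, called both as written in the training loop and as one might presume was intended. The tensor shapes, the x_hat placeholder, and the "presumed" call are assumptions for illustration only and are not confirmed by the repository authors:

```python
import torch
import torch.nn as nn

# The loss function as defined in the lstm_vae implementation (from this thread):
def loss_function(x, x_hat, mean, log_var, reduction_type):
    reproduction_loss = nn.functional.mse_loss(x_hat, x, reduction=reduction_type)
    KLD = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
    return KLD + reproduction_loss

torch.manual_seed(0)
x = torch.randn(4, 8)       # toy input batch (hypothetical shape)
x_hat = torch.randn(4, 8)   # stand-in for the decoder's reconstruction (assumption)
mu = torch.zeros(4, 8)      # encoder mean
logvar = torch.zeros(4, 8)  # encoder log-variance

# Call as written in the training loop: logvar is passed in the x_hat
# position, so the MSE compares x against logvar, not the reconstruction.
loss_as_written = loss_function(x, logvar, mu, logvar, 'mean')

# Presumed intended call (an assumption, not confirmed by the authors):
loss_presumed = loss_function(x, x_hat, mu, logvar, 'mean')
```

With mu and logvar at zero the KLD term vanishes, so loss_as_written reduces to the MSE between x and logvar, which illustrates why passing logvar as the reconstruction looks suspicious.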

Thank you for the clarification.

I am not familiar enough with this implementation to judge this. @wenig can you take a look?

SebastianSchmidl avatar Jul 28 '23 14:07 SebastianSchmidl