ferattention
The variable "loss_att" is overwritten in the forward function of the class AttMSEloss. It is first assigned:
loss_att = ((( (x_org*y_mask[:,1,...].unsqueeze(dim=1)) - att ) ** 2)).mean()
This result is then overwritten by the next statement:
loss_att = ((( x_org - att ) ** 2)).mean()
Is this a bug? As written, the masked loss from the first statement is computed and then discarded.
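Here is a minimal sketch that reproduces the pattern outside the class. It assumes x_org and att have shape (B, C, H, W) and y_mask has shape (B, K, H, W) with K >= 2, where channel 1 is the foreground mask. The helper att_mse_terms and the tensor shapes are hypothetical and only mirror the two statements quoted above; they are not taken from the repository.

```python
import torch


def att_mse_terms(x_org, y_mask, att):
    """Compute both candidate losses from AttMSEloss.forward side by side.

    Hypothetical helper for illustration. Assumes x_org and att have shape
    (B, C, H, W) and y_mask has shape (B, K, H, W) with K >= 2, where
    channel 1 is the foreground mask.
    """
    # First statement: the target is x_org restricted to the channel-1 mask.
    # y_mask[:, 1, ...] is (B, H, W); unsqueeze(dim=1) makes it (B, 1, H, W)
    # so it broadcasts across the C channels of x_org.
    masked = ((x_org * y_mask[:, 1, ...].unsqueeze(dim=1) - att) ** 2).mean()
    # Second statement: the target is the full, unmasked x_org.
    unmasked = ((x_org - att) ** 2).mean()
    return masked, unmasked


torch.manual_seed(0)
x_org = torch.rand(2, 3, 4, 4)
att = torch.rand(2, 3, 4, 4)
y_mask = torch.zeros(2, 2, 4, 4)
y_mask[:, 1, :2, :] = 1.0  # hypothetical mask: top half is foreground

masked, unmasked = att_mse_terms(x_org, y_mask, att)

# Reproducing the reported pattern: the second assignment wins,
# so the masked value never reaches the returned loss.
loss_att = masked
loss_att = unmasked
print(loss_att.item() == unmasked.item())  # True
print(masked.item(), unmasked.item())
```

The two values only coincide when the mask is all ones, so with a real foreground mask the first statement has no effect on training. Whether the fix is to delete the second line, delete the first, or combine the two terms (for example by summing them) depends on what the authors intended the attention target to be.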