SWEP
Some doubts about the total loss
Sorry to bother you. I found that your code differs from your paper in how the loss is calculated. Why is the final loss NLL + beta * KL? Your paper says the KL term is part of the noise.
The gradient of the KL term with respect to \theta (the parameters of BERT) is zero, since we do not backpropagate it to BERT (`torch.no_grad()` in line 42). So the implementation is consistent with our equation.
Thanks for your reply!
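For readers following the thread, here is a minimal sketch of the `torch.no_grad()` point made in the reply. It is not the repository's code: `encoder`, `noise_head`, `classifier`, the tensor shapes, and the Gaussian KL formula are all hypothetical stand-ins chosen for illustration. It only shows that a KL term computed from features produced under `torch.no_grad()` contributes no gradient to the encoder, so adding `beta * KL` to the NLL leaves the encoder's update unchanged.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-ins, not the modules from the repository.
encoder = nn.Linear(8, 4)      # plays the role of BERT (parameters theta)
noise_head = nn.Linear(4, 8)   # predicts mean and log-variance of a toy Gaussian
classifier = nn.Linear(4, 2)

x = torch.randn(16, 8)
y = torch.randint(0, 2, (16,))
beta = 0.1

# NLL branch: gradients flow into the encoder as usual.
nll = F.cross_entropy(classifier(encoder(x)), y)

# KL branch: the features are computed without building a graph,
# so nothing downstream of them can backpropagate into the encoder.
with torch.no_grad():
    h_const = encoder(x)
mu, log_var = noise_head(h_const).chunk(2, dim=-1)
# Toy KL( N(mu, sigma^2) || N(0, I) ), assumed here only for illustration.
kl = 0.5 * (mu.pow(2) + log_var.exp() - log_var - 1).sum(dim=-1).mean()

enc_params = list(encoder.parameters())
# The KL term is not connected to the encoder: autograd reports None.
kl_grads = torch.autograd.grad(kl, enc_params, retain_graph=True, allow_unused=True)
# Gradient of the NLL alone, kept for comparison.
nll_grads = torch.autograd.grad(nll, enc_params, retain_graph=True)

loss = nll + beta * kl
loss.backward()

print(all(g is None for g in kl_grads))
print(torch.allclose(encoder.weight.grad, nll_grads[0]))
```

In this sketch the KL term still trains `noise_head`, whose parameters receive a gradient from `beta * kl`; only the encoder is shielded from it.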