pytorch-domain-adaptation

Question: wdgrl.py - gradient_penalty(critic, h_s, h_t)

Open • fschmid56 opened this issue 4 years ago • 1 comment

In wdgrl.py, `gradient_penalty(critic, h_s, h_t)`: the interpolates created in line 29 have shape (3 x batch_size x feature_size), so the gradients computed from them also have shape (3 x batch_size x feature_size). When `gradients.norm(2, dim=1)` is calculated in line 35, dimension 1 therefore refers to the batch_size dimension. Is this correct? Intuitively, I would have taken the norm across the feature dimension, so that each sample gets its own gradient norm.

fschmid56, Aug 13 '21 13:08
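The shape issue raised in the question can be reproduced in isolation. This sketch assumes the interpolates are built with `torch.stack([interpolates, h_s, h_t])`, which matches the (3 x batch_size x feature_size) shape described above; the sizes and the `nn.Linear` critic are hypothetical stand-ins, not the repository's code. It compares what `norm(2, dim=1)` reduces over for the stacked layout versus a `torch.cat` layout.

```python
import torch
from torch import nn
from torch.autograd import grad

torch.manual_seed(0)
batch_size, feature_size = 4, 8          # hypothetical sizes for illustration
critic = nn.Linear(feature_size, 1)      # hypothetical stand-in for the WDGRL critic

h_s = torch.randn(batch_size, feature_size)  # source features
h_t = torch.randn(batch_size, feature_size)  # target features
alpha = torch.rand(batch_size, 1)
interpolates = h_s + alpha * (h_t - h_s)


def input_gradients(x):
    """Gradient of the summed critic output with respect to the critic input x."""
    preds = critic(x)
    return grad(preds, x, grad_outputs=torch.ones_like(preds))[0]


# torch.stack adds a new leading axis: (3, batch_size, feature_size)
stacked = torch.stack([interpolates, h_s, h_t]).requires_grad_()
# torch.cat joins along the existing batch axis: (3 * batch_size, feature_size)
concatenated = torch.cat([interpolates, h_s, h_t]).requires_grad_()

g_stacked = input_gradients(stacked)
g_cat = input_gradients(concatenated)

print(g_stacked.shape)                 # torch.Size([3, 4, 8])
print(g_cat.shape)                     # torch.Size([12, 8])
print(g_stacked.norm(2, dim=1).shape)  # torch.Size([3, 8]): reduced over the batch axis
print(g_cat.norm(2, dim=1).shape)      # torch.Size([12]): one norm per sample
```

With the stacked layout, `dim=1` collapses the batch axis and leaves one norm per (copy, feature) pair; with the concatenated layout, it collapses the feature axis and leaves one norm per sample, which is the quantity the WGAN-GP penalty is defined on.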

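I think the code is wrong. In my opinion, line 29 should be replaced with `interpolates = torch.cat([interpolates, h_s, h_t]).requires_grad_()`. Concatenating along the existing batch dimension gives shape (3 * batch_size x feature_size), so `gradients.norm(2, dim=1)` is then taken over the feature dimension.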

shilianghe007, Nov 20 '22 12:11
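A minimal sketch of the change proposed in the comment, assuming the rest of `gradient_penalty` follows the standard WGAN-GP recipe (random interpolation between source and target features, gradients of the critic output with respect to its input, penalty `(||grad||_2 - 1)^2`). The body is a reconstruction for illustration, not the repository's exact code: `device=h_s.device` stands in for whatever device handling wdgrl.py uses, and the `nn.Linear` critic in the demo is hypothetical.

```python
import torch
from torch import nn
from torch.autograd import grad


def gradient_penalty(critic, h_s, h_t):
    """WGAN-GP-style penalty with one gradient norm per sample.

    Sketch assuming h_s and h_t are (batch_size, feature_size) tensors and
    that `critic` maps (N, feature_size) -> (N, 1).
    """
    # One random interpolation weight per sample, on the same device as the inputs.
    alpha = torch.rand(h_s.size(0), 1, device=h_s.device)
    interpolates = h_s + alpha * (h_t - h_s)
    # torch.cat joins along the batch axis: (3 * batch_size, feature_size).
    interpolates = torch.cat([interpolates, h_s, h_t]).requires_grad_()

    preds = critic(interpolates)
    gradients = grad(preds, interpolates,
                     grad_outputs=torch.ones_like(preds),
                     retain_graph=True, create_graph=True)[0]
    # dim=1 is now the feature axis, so this yields one L2 norm per sample.
    gradient_norm = gradients.norm(2, dim=1)
    return ((gradient_norm - 1) ** 2).mean()


if __name__ == "__main__":
    # Hypothetical toy critic: a linear layer, whose input gradient is its
    # weight vector for every sample, so the penalty equals (||w||_2 - 1)^2.
    torch.manual_seed(0)
    critic = nn.Linear(8, 1)
    h_s, h_t = torch.randn(4, 8), torch.randn(4, 8)
    print(gradient_penalty(critic, h_s, h_t).item())
    print(((critic.weight.norm() - 1) ** 2).item())
```

The linear critic makes a convenient sanity check for the reduction axis: every per-sample input gradient is the same weight vector, so the two printed values should match. With the `torch.stack` layout and the same `dim=1` reduction, they generally would not.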