pytorch-domain-adaptation
Question wdgrl.py - gradient_penalty(critic, h_s, h_t)
In wdgrl.py, gradient_penalty(critic, h_s, h_t): the interpolates created in line 29 have size (3 x batch_size x feature_size), so the gradients are also of size (3 x batch_size x feature_size). When gradients.norm(2, dim=1) is calculated in line 35, dimension 1 therefore refers to the batch_size dimension. Is this correct? Intuitively, I would have computed the norm across the feature dimension, giving one gradient norm per sample.
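A quick shape check makes the question concrete. This is a sketch that assumes line 29 builds the tensor with torch.stack, the call that adds a new leading dimension of size 3. Batch size 4 and feature size 8 are arbitrary, and random tensors stand in for h_s and h_t. Because the gradient with respect to interpolates has the same shape as interpolates, taking the norm of the stacked tensor reduces over the same dimension that the norm of the gradients would.

```python
import torch

batch_size, feature_size = 4, 8  # arbitrary sizes, chosen only for illustration
h_s = torch.randn(batch_size, feature_size)  # stand-in for source features
h_t = torch.randn(batch_size, feature_size)  # stand-in for target features
alpha = torch.rand(batch_size, 1)
interpolates = h_s + alpha * (h_t - h_s)

# torch.stack adds a new leading dimension: (3, batch_size, feature_size)
stacked = torch.stack([interpolates, h_s, h_t])
# dim=1 is now the batch dimension, so this gives one value per (group, feature)
norm_stacked = stacked.norm(2, dim=1)

# torch.cat joins along the existing batch dimension: (3 * batch_size, feature_size)
concatenated = torch.cat([interpolates, h_s, h_t])
# dim=1 is the feature dimension, so this gives one norm per sample
norm_concatenated = concatenated.norm(2, dim=1)

print(stacked.shape, norm_stacked.shape)            # torch.Size([3, 4, 8]) torch.Size([3, 8])
print(concatenated.shape, norm_concatenated.shape)  # torch.Size([12, 8]) torch.Size([12])
```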
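I think the code is wrong. In my opinion, line 29 should be replaced with "interpolates = torch.cat([interpolates, h_s, h_t]).requires_grad_()". Concatenating along the existing batch dimension gives a tensor of size (3 * batch_size x feature_size), so dim=1 in line 35 then refers to the feature dimension.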
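Below is a minimal sketch of the penalty with the proposed torch.cat change, modeled loosely on the function discussed above rather than copied from the repository. It assumes h_s and h_t are (batch_size x feature_size) tensors and that critic maps each row to a single score. The two-layer critic and the random features at the bottom are hypothetical and exist only to exercise the function.

```python
import torch
from torch import nn
from torch.autograd import grad


def gradient_penalty(critic, h_s, h_t):
    # One random interpolation weight per sample, on the same device as the features
    alpha = torch.rand(h_s.size(0), 1, device=h_s.device)
    differences = h_t - h_s
    interpolates = h_s + alpha * differences
    # Concatenate along the batch dimension: (3 * batch_size, feature_size)
    interpolates = torch.cat([interpolates, h_s, h_t]).requires_grad_()

    preds = critic(interpolates)
    gradients = grad(preds, interpolates,
                     grad_outputs=torch.ones_like(preds),
                     retain_graph=True, create_graph=True)[0]
    # dim=1 is now the feature dimension: one gradient norm per sample
    gradient_norm = gradients.norm(2, dim=1)
    return ((gradient_norm - 1) ** 2).mean()


# Hypothetical critic and features, used only to exercise the function
critic = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
h_s = torch.randn(4, 8)
h_t = torch.randn(4, 8)
penalty = gradient_penalty(critic, h_s, h_t)
print(penalty.item())
```

As a sanity check, with a purely linear critic every per-sample gradient equals the weight vector w, so this version returns (||w|| - 1)^2 regardless of batch size, whereas taking the norm across the batch dimension would make the result depend on batch_size.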