wassdistance
Added device option to Wasserstein
Updated wasserstein submodule to include device
@sAbhay thank you for your contribution! Sorry for the long delay in my response. To keep compatibility with distributed training, where computations could run on different devices, I think it would be better to grab the device during the forward pass rather than fixing it during initialization. For example,
```python
def forward(self, x, y):
    device = x.device
    ...
    mu = torch.empty(batch_size, x_points, dtype=torch.float,
                     requires_grad=False, device=device).fill_(1.0 / x_points).squeeze()
    ...
```
This way, computations will run on whatever device x happens to be. What do you think?
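For anyone following along, here is a minimal, self-contained sketch of that pattern (the class name `SinkhornDistance`, the shape handling, and the `nu` marginal are illustrative assumptions, not the repository's exact code): the device is read from the input inside `forward`, so the same module instance follows its inputs between CPU and GPU without any device argument at construction time.

```python
import torch
import torch.nn as nn

class SinkhornDistance(nn.Module):
    """Sketch only: shows the marginals being allocated on the input's device."""

    def forward(self, x, y):
        # Read the device from the input instead of fixing it at init time,
        # so the module follows the data in distributed / multi-device runs.
        device = x.device
        batch_size = x.shape[0] if x.dim() == 3 else 1
        x_points, y_points = x.shape[-2], y.shape[-2]

        # Uniform marginals created directly on the input's device
        # (the nu line mirrors mu and is assumed here for illustration).
        mu = torch.empty(batch_size, x_points, dtype=torch.float,
                         requires_grad=False, device=device).fill_(1.0 / x_points).squeeze()
        nu = torch.empty(batch_size, y_points, dtype=torch.float,
                         requires_grad=False, device=device).fill_(1.0 / y_points).squeeze()

        # Placeholder return; the real forward would run the Sinkhorn iterations.
        return mu, nu

# Usage: the same instance works wherever the inputs live.
sinkhorn = SinkhornDistance()
x = torch.randn(4, 100, 2)
y = torch.randn(4, 100, 2)
mu, nu = sinkhorn(x, y)
assert mu.device == x.device
```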