
Added device option to Wasserstein

Open · sAbhay opened this issue 3 years ago · 1 comment

Updated wasserstein submodule to include device

sAbhay avatar Jun 20 '22 23:06 sAbhay

@sAbhay thank you for your contribution! Sorry for the long delay in my response. To maintain compatibility with distributed training, where computations can run on different devices, I think it would be better to grab the device during the forward pass rather than fixing it at initialization. For example,

def forward(self, x, y):
    # Infer the device from the input tensor instead of a fixed attribute
    device = x.device
    ...

    mu = torch.empty(batch_size, x_points, dtype=torch.float,
                     requires_grad=False, device=device).fill_(1.0 / x_points).squeeze()
    ...

This way, computations will run on whatever device x happens to be. What do you think?
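As a self-contained sketch of this pattern (the `UniformWeights` module name and shapes are illustrative, not part of the repository's code), the weight tensor is created on the input's device inside `forward`, so the same module instance works on CPU or GPU without any device argument at construction time:

```python
import torch
import torch.nn as nn

class UniformWeights(nn.Module):
    """Toy module illustrating the pattern suggested above:
    infer the device inside forward() from the input tensor."""

    def forward(self, x, y):
        batch_size, x_points = x.shape[0], x.shape[1]
        device = x.device  # follow the input, not a device fixed in __init__

        # Uniform weights 1/n over x's points, allocated on x's device
        mu = torch.empty(batch_size, x_points, dtype=torch.float,
                         requires_grad=False, device=device).fill_(1.0 / x_points).squeeze()
        return mu

# Works on CPU; moving x to a GPU (if available) moves mu there too,
# with no change to the module itself.
m = UniformWeights()
mu = m(torch.randn(4, 5, 2), torch.randn(4, 6, 2))
print(mu.shape, mu.device)
```

Because nothing is registered on a fixed device, the same instance can be wrapped for distributed training and each replica's computation stays on its own inputs' device.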

dfdazac avatar Aug 25 '22 09:08 dfdazac