Add GPU support for model and flows

Open Baukebrenninkmeijer opened this issue 5 years ago • 4 comments

I added GPU support for all forward/inverse passes of the flows that support it, and for NormalizingFlowModel. Tensors created inside these passes are moved to the same device as the input tensor (i.e., wherever x or z lives). For now, the result of sampling is moved to the CPU, since I expect that to be the most common use case. Let me know if you have any other suggestions.
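To illustrate the device-placement pattern described above, here is a minimal sketch. It assumes a flow interface where `forward(x)` returns `(z, log_det)` and `inverse(z)` returns `(x, log_det)`; `ElementwiseAffineFlow` is a hypothetical layer, not one of the repository's flows. The point is only that freshly allocated tensors, such as the `log_det` accumulator, take their device from the input, so the same layer runs on CPU or GPU without hard-coding either.

```python
import torch
import torch.nn as nn


class ElementwiseAffineFlow(nn.Module):
    """Hypothetical flow z = x * exp(s) + t, used only to illustrate device handling."""

    def __init__(self, dim):
        super().__init__()
        # Registered parameters follow model.to(device) automatically.
        self.s = nn.Parameter(torch.zeros(dim))
        self.t = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        # New tensors are allocated on the input's device (and dtype),
        # so nothing is implicitly created on the CPU.
        log_det = torch.zeros(x.shape[0], device=x.device, dtype=x.dtype)
        z = x * torch.exp(self.s) + self.t
        log_det = log_det + self.s.sum()
        return z, log_det

    def inverse(self, z):
        log_det = torch.zeros(z.shape[0], device=z.device, dtype=z.dtype)
        x = (z - self.t) * torch.exp(-self.s)
        log_det = log_det - self.s.sum()
        return x, log_det


device = 'cuda' if torch.cuda.is_available() else 'cpu'
flow = ElementwiseAffineFlow(2).to(device)
x = torch.randn(5, 2, device=device)
z, log_det = flow(x)
x_rec, inv_log_det = flow.inverse(z)
print(z.device, log_det.shape)
```

Because the device is read from the input at call time, the layer needs no device argument of its own.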

To get it working, put your x, prior and model on the GPU:

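```python
import torch
from torch.distributions import MultivariateNormal

# NormalizingFlowModel, flows, gen_data and args are defined elsewhere in the repo / example script
device = 'cuda' if torch.cuda.is_available() else 'cpu'
prior = MultivariateNormal(torch.zeros(1).to(device), torch.eye(1).to(device))
model = NormalizingFlowModel(prior, flows).to(device)
x = torch.Tensor(gen_data(args.n)).to(device)
```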

From there on, it should work without any problems.

I didn't change the examples, but I can make them use the GPU when available as well. Let me know your preference.

Baukebrenninkmeijer, Apr 22 '20 09:04

Moved the device placement to tensor initialization instead of moving tensors afterwards. I also removed the CPU as the default location for sampling; samples are now created on the same device as the prior. Should be good now :).
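A minimal sketch of sampling that stays on the prior's device, assuming each flow exposes `inverse(z)` returning `(x, log_det)`. `ShiftFlow` and `sample` are hypothetical names used only for illustration, not the repository's code. `MultivariateNormal.sample` allocates on the device of the distribution's parameters, so creating the prior's tensors with `device=...` at initialization is enough and no `.cpu()` call is involved.

```python
import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal


class ShiftFlow(nn.Module):
    """Hypothetical flow z = x + t, included only so there is something to invert."""

    def __init__(self, dim):
        super().__init__()
        self.t = nn.Parameter(torch.ones(dim))

    def inverse(self, z):
        # A shift has a unit Jacobian, so the log-determinant is zero;
        # it is allocated on z's device rather than on the CPU.
        log_det = torch.zeros(z.shape[0], device=z.device, dtype=z.dtype)
        return z - self.t, log_det


def sample(prior, flows, n_samples):
    """Hypothetical sampler: draw z from the prior and undo the flows in reverse order."""
    # prior.sample() allocates on the device of the prior's loc/covariance,
    # so the samples stay there.
    z = prior.sample((n_samples,))
    x = z
    for flow in reversed(flows):
        x, _ = flow.inverse(x)
    return x


device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Create the prior's tensors directly on the target device.
prior = MultivariateNormal(torch.zeros(2, device=device), torch.eye(2, device=device))
flows = [ShiftFlow(2).to(device)]
samples = sample(prior, flows, 10)
print(samples.device, samples.shape)
```

Callers who want CPU samples can still call `.cpu()` themselves; leaving that choice to the caller avoids a device transfer when the samples are consumed on the GPU.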

Baukebrenninkmeijer, Apr 23 '20 12:04

The tensors in OneByOneConv are not registered as parameters, so they are not moved to the correct device when we call model.to(device). For now, I'm calling .to(device) in the forward and inverse passes individually.
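To see why the per-call `.to()` is needed, here is a minimal sketch with a hypothetical `PlainAttributeLayer` (not the repository's `OneByOneConv`). `Module.to()` only converts tensors registered as parameters or buffers; a tensor assigned as a plain attribute stays where it was created. Since `Module.to()` applies dtype conversions through the same mechanism as device moves, the sketch uses a float64 conversion so it runs without a GPU.

```python
import torch
import torch.nn as nn


class PlainAttributeLayer(nn.Module):
    """Hypothetical layer that keeps a fixed matrix as a plain tensor attribute."""

    def __init__(self, dim):
        super().__init__()
        self.W = nn.Parameter(torch.eye(dim))  # registered: converted by Module.to()
        self.P = torch.eye(dim)                # plain attribute: ignored by Module.to()

    def forward(self, x):
        # Per-call workaround: move P to the input's device (and dtype) on every call.
        P = self.P.to(device=x.device, dtype=x.dtype)
        return x @ P @ self.W


# float64 conversion as a CPU-only stand-in for model.to('cuda').
layer = PlainAttributeLayer(2).to(torch.float64)
print(layer.W.dtype, layer.P.dtype)
print([name for name, _ in layer.named_parameters()])
out = layer(torch.ones(3, 2, dtype=torch.float64))
```

`W` is converted while `P` keeps its original dtype and does not show up in `named_parameters()`, which is the same reason `P` is left on the CPU by `model.to(device)`.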

Baukebrenninkmeijer, Apr 23 '20 13:04

Thanks for making the requested changes.


I think we can fix the OneByOneConv device issue by replacing the line:

```python
self.P = torch.tensor(P, dtype=torch.float)
```

with:

```python
self.P = nn.Parameter(torch.tensor(P, dtype=torch.float), requires_grad=False)
```

Could you take a stab at this and let me know if this fixes the issue?
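The suggested fix can be checked without a GPU using a minimal sketch like the one below; the module names and the 3x3 permutation matrix are hypothetical. It also shows `register_buffer`, a standard PyTorch alternative that was not proposed in this thread. Both options are converted by `Module.to()` and saved in `state_dict()`, but only the `nn.Parameter` version is returned by `.parameters()`. As before, a float64 conversion stands in for `model.to(device)`.

```python
import torch
import torch.nn as nn


class FixedMatrixAsParameter(nn.Module):
    """Hypothetical module storing a fixed matrix P as a frozen parameter."""

    def __init__(self, P):
        super().__init__()
        # Registered as a parameter, but excluded from gradient computation.
        self.P = nn.Parameter(torch.tensor(P, dtype=torch.float), requires_grad=False)


class FixedMatrixAsBuffer(nn.Module):
    """Alternative (not from the thread): register P as a buffer."""

    def __init__(self, P):
        super().__init__()
        # Buffers are converted/moved by .to() and saved in state_dict,
        # but are not returned by .parameters().
        self.register_buffer("P", torch.tensor(P, dtype=torch.float))


P = [[0.0, 0.0, 1.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]  # hypothetical permutation matrix
as_param = FixedMatrixAsParameter(P).to(torch.float64)
as_buffer = FixedMatrixAsBuffer(P).to(torch.float64)
print(as_param.P.dtype, as_buffer.P.dtype)
print(len(list(as_param.parameters())), len(list(as_buffer.parameters())))
```

With the buffer, an optimizer built from `model.parameters()` never sees `P`; with `requires_grad=False`, it receives `P` but standard optimizers such as SGD and Adam skip it because its `.grad` stays `None`.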

tonyduan, Apr 23 '20 19:04

Yes, I'll have a look when I have time, hopefully sometime later this week.

Baukebrenninkmeijer, Apr 28 '20 12:04