GPU support for normalizing flows
I am interested in using the Pyro implementation of normalizing flows for my research. However, I cannot find instructions anywhere in the docs on how to enable GPU support. The example page makes no mention of GPUs, and if I modify that code from
dataset = torch.tensor(X, dtype=torch.float)
...
spline_transform = T.spline_coupling(2, count_bins=16)
to
dataset = torch.tensor(X, dtype=torch.float).to('cuda')
...
spline_transform = T.spline_coupling(2, count_bins=16).to('cuda')
and attempt to run, I get this error: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!. How should I move the TransformedDistribution class entirely onto the GPU device? Tagging @stefanwebb since he is the person responsible for normalizing flow support according to this issue.
As a quick workaround, you could set CUDA as the default tensor type before creating your spline_coupling object:
torch.set_default_tensor_type(torch.cuda.FloatTensor)
spline_transform = T.spline_coupling(...)
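For context, a rough end-to-end sketch of what that workaround might look like, loosely following the tutorial example (the shapes and sample data below are placeholders, not your original code):

import torch
import pyro.distributions as dist
import pyro.distributions.transforms as T

# Workaround: newly created float tensors (including the spline's parameters)
# will default to the GPU. Note that this mutates global state.
torch.set_default_tensor_type(torch.cuda.FloatTensor)

base_dist = dist.Normal(torch.zeros(2), torch.ones(2))   # created on cuda
spline_transform = T.spline_coupling(2, count_bins=16)   # parameters on cuda
flow_dist = dist.TransformedDistribution(base_dist, [spline_transform])

x = torch.randn(100, 2)          # also defaults to cuda
log_p = flow_dist.log_prob(x)    # all tensors on the same device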
Another not-ideal workaround is to save-and-load with map_location:
torch.save(any_python_object, "temp_file.pt")
any_python_object = torch.load("temp_file.pt", map_location="cuda:0")
Hi @fritzo, I tried the first solution (torch.set_default_tensor_type(torch.cuda.FloatTensor)) and it doesn't seem to work. I get this error when using torch.utils.data.DataLoader and pytorch_lightning.LightningDataModule:
RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'
Is there some other workaround to try? Also, what is the bottleneck here? Looking at the traceback, it fails during the log_prob() evaluation.
Details
200 z = posterior.sample()
201 log_qzx = posterior.log_prob(z)
--> 202 log_pz = prior_trans.log_prob(z)
203
204 kl = log_pz - log_qzx.sum(-1)
~/miniconda3/envs/torch/lib/python3.8/site-packages/torch/distributions/transformed_distribution.py in log_prob(self, value)
141 y = value
142 for transform in reversed(self.transforms):
--> 143 x = transform.inv(y)
144 event_dim += transform.domain.event_dim - transform.codomain.event_dim
145 log_prob = log_prob - _sum_rightmost(transform.log_abs_det_jacobian(x, y),
~/miniconda3/envs/torch/lib/python3.8/site-packages/torch/distributions/transforms.py in __call__(self, x)
340 def __call__(self, x):
341 for part in self.parts:
--> 342 x = part(x)
343 return x
344
~/miniconda3/envs/torch/lib/python3.8/site-packages/torch/distributions/transforms.py in __call__(self, x)
247 def __call__(self, x):
248 assert self._inv is not None
--> 249 return self._inv._inv_call(x)
250
251 def log_abs_det_jacobian(self, x, y):
~/miniconda3/envs/torch/lib/python3.8/site-packages/torch/distributions/transforms.py in _inv_call(self, y)
159 if y is y_old:
160 return x_old
--> 161 x = self._inverse(y)
162 self._cached_x_y = x, y
163 return x
~/miniconda3/envs/torch/lib/python3.8/site-packages/pyro/distributions/transforms/permute.py in _inverse(self, y)
91 Inverts y => x.
92 """
---> 93 return y.index_select(self.dim, self.inv_permutation)
94
95 def log_abs_det_jacobian(self, x, y):
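If it helps narrow things down: the last frame is Permute._inverse, so my guess is that the permutation indices and the tensor being inverted ended up on different devices. A hypothetical two-tensor snippet that reproduces the same kind of error (not my actual model, and assuming a CUDA device is available):

import torch
from pyro.distributions.transforms import Permute

# Hypothetical reproduction: the permutation indices live on cpu
# while the value being inverted lives on cuda.
perm = Permute(torch.randperm(2))        # inv_permutation stays on cpu
y = torch.randn(10, 2, device="cuda")    # data on cuda
x = perm.inv(y)                          # index_select -> cpu/cuda device RuntimeError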
Hi @gshartnett, it's difficult to diagnose where the stray CPU tensor is coming from without actually diving into a debugger. I'd recommend running under pdb and inspecting each tensor's device. You might also try updating PyTorch, since its device handling seems to be getting more consistent over time.
One thing you could try is to use torch.save and torch.load, but you'd only be able to do that once before training:
x = my_complex_data_structure()
torch.save(x, "temp.pt")
x = torch.load("temp.pt", map_location="cuda:0")
Thanks for the very prompt reply @fritzo, I was able to send some flows to CUDA by calling .to(device) on them.
E.g.
AffineAutoregressive(
    AutoRegressiveNN(
        latent_dim,
        [hidden_units for _ in range(n_hidden)],
        skip_connections=True,
    )
).to(device)
Among the ones I tried, this also works for BatchNorm and AffineCoupling, but it does not work for Permute (which raises an error along the lines of "Permute does not have a to() method"). From a very high-level inspection, Permute is the only one of the transforms I tried that inherits from torch.distributions.transforms.Transform instead of pyro.distributions.torch_transform.TransformModule (the latter inherits from nn.Module). I might be off track, but I wanted to report this anyway.
EDIT:
I was wrong; subclassing TransformModule doesn't solve the problem. The issue seems to be in
return y.index_select(self.dim, self.inv_permutation)
SOLUTION: simply send the permutation index to the device:
Permute(torch.LongTensor(perm).to(device))
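For anyone landing here later, a sketch of how the pieces might fit together; the latent_dim, hidden sizes, and batch size below are arbitrary placeholders rather than my actual model:

import torch
import pyro.distributions as dist
from pyro.distributions.transforms import AffineAutoregressive, Permute
from pyro.nn import AutoRegressiveNN

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
latent_dim = 4

# base distribution parameters created directly on the target device
base_dist = dist.Normal(
    torch.zeros(latent_dim, device=device),
    torch.ones(latent_dim, device=device),
)

# nn.Module-based transforms can simply be moved with .to(device)
flow = AffineAutoregressive(
    AutoRegressiveNN(latent_dim, [32, 32], skip_connections=True)
).to(device)

# Permute is not an nn.Module, so its index tensor must be created on the device
perm = Permute(torch.randperm(latent_dim, device=device))

flow_dist = dist.TransformedDistribution(base_dist, [flow, perm])
z = flow_dist.rsample(torch.Size([8]))   # samples come out on `device`
log_p = flow_dist.log_prob(z)            # no cpu/cuda mismatch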
Thanks again for the help!