pyro

GPU support for normalizing flows

Open gshartnett opened this issue 4 years ago • 4 comments

I am interested in using the Pyro implementation of normalizing flows for my research. However, I cannot find instructions anywhere in the docs on how to enable GPU support. The example page makes no mention of GPUs, and if I modify its code from

dataset = torch.tensor(X, dtype=torch.float)
...
spline_transform = T.spline_coupling(2, count_bins=16)

to

dataset = torch.tensor(X, dtype=torch.float).to(device).to('cuda')
...
spline_transform = T.spline_coupling(2, count_bins=16).to('cuda')

and then run it, I get this error: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! How should I move a TransformedDistribution entirely onto the GPU? Tagging @stefanwebb, since according to this issue he is the person responsible for normalizing flow support.

gshartnett, Feb 05 '21 17:02

As a quick workaround, you could set CUDA as the default tensor type before creating your spline_coupling object:

import torch
import pyro.distributions.transforms as T
torch.set_default_tensor_type(torch.cuda.FloatTensor)
spline_transform = T.spline_coupling(...)

Another workaround, though not an ideal one, is to save the object and reload it with map_location:

import torch
torch.save(any_python_object, "temp_file.pt")
# Load from the same path; map_location remaps every saved tensor onto the GPU
any_python_object = torch.load("temp_file.pt", map_location="cuda:0")
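The sketch below shows what map_location does with a nested object, using an in-memory buffer instead of a temporary file. The dict of tensors is a hypothetical stand-in for "any_python_object". For arbitrary objects such as a whole flow, note that recent PyTorch releases default torch.load to weights_only=True, so loading a trusted file containing custom classes may require passing weights_only=False.

```python
import io

import torch

# Hypothetical nested structure standing in for "any_python_object".
obj = {
    "loc": torch.zeros(2),
    "layers": [torch.ones(3), {"perm": torch.tensor([1, 0])}],
}

target = "cuda:0" if torch.cuda.is_available() else "cpu"

# Serialize into memory rather than onto disk.
buffer = io.BytesIO()
torch.save(obj, buffer)
buffer.seek(0)

# map_location relocates every tensor found anywhere in the saved object.
moved = torch.load(buffer, map_location=target)
```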

fritzo, Feb 05 '21 17:02

Hi @fritzo, I tried the first solution (torch.set_default_tensor_type(torch.cuda.FloatTensor)) and it doesn't seem to work. I get this error when using torch.utils.data.DataLoader and pytorch_lightning.LightningDataModule:

RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'

Is there another workaround I could try? Also, what is the underlying cause here? Looking at the traceback, it fails during log_prob() evaluation.
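One way to sidestep the generator error is to leave the default device alone: let the DataLoader shuffle and batch on the CPU, and move each batch to the GPU inside the training loop. This sketch assumes the data fits in a CPU tensor; the dataset is a hypothetical stand-in. Another commonly suggested option, not shown here, is passing generator=torch.Generator(device="cuda") to the DataLoader.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-in for the real dataset; it stays on the CPU.
data = TensorDataset(torch.randn(100, 2))

# Shuffling uses the DataLoader's default CPU generator, so there is no
# device clash between the generator and the sampled indices.
loader = DataLoader(data, batch_size=32, shuffle=True)

batches = []
for (batch,) in loader:
    # Move each batch to the device that holds the flow's parameters.
    batch = batch.to(device)
    batches.append(batch)
```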

Traceback:
    200         z = posterior.sample()
    201         log_qzx = posterior.log_prob(z)
--> 202         log_pz = prior_trans.log_prob(z)
    203 
    204         kl = log_pz - log_qzx.sum(-1)

~/miniconda3/envs/torch/lib/python3.8/site-packages/torch/distributions/transformed_distribution.py in log_prob(self, value)
    141         y = value
    142         for transform in reversed(self.transforms):
--> 143             x = transform.inv(y)
    144             event_dim += transform.domain.event_dim - transform.codomain.event_dim
    145             log_prob = log_prob - _sum_rightmost(transform.log_abs_det_jacobian(x, y),

~/miniconda3/envs/torch/lib/python3.8/site-packages/torch/distributions/transforms.py in __call__(self, x)
    340     def __call__(self, x):
    341         for part in self.parts:
--> 342             x = part(x)
    343         return x
    344 

~/miniconda3/envs/torch/lib/python3.8/site-packages/torch/distributions/transforms.py in __call__(self, x)
    247     def __call__(self, x):
    248         assert self._inv is not None
--> 249         return self._inv._inv_call(x)
    250 
    251     def log_abs_det_jacobian(self, x, y):

~/miniconda3/envs/torch/lib/python3.8/site-packages/torch/distributions/transforms.py in _inv_call(self, y)
    159         if y is y_old:
    160             return x_old
--> 161         x = self._inverse(y)
    162         self._cached_x_y = x, y
    163         return x

~/miniconda3/envs/torch/lib/python3.8/site-packages/pyro/distributions/transforms/permute.py in _inverse(self, y)
     91         Inverts y => x.
     92         """
---> 93         return y.index_select(self.dim, self.inv_permutation)
     94 
     95     def log_abs_det_jacobian(self, x, y):

giovp, Jun 23 '21 20:06

Hi @gshartnett, it's difficult to diagnose where the stray CPU tensor is coming from without actually diving into a debugger. I'd recommend running under pdb and inspecting each tensor's .device attribute. You might also try updating PyTorch, since it seems to be getting more consistent about device handling over time.
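Along the lines of that suggestion, a small helper can list the device of every tensor it can reach on each piece of a flow, which makes a stray CPU tensor easy to spot. report_devices is hypothetical, not a Pyro or PyTorch API: it only inspects nn.Module parameters and buffers plus tensors stored directly as attributes, so tensors nested inside lists, or lazily computed ones that have not been accessed yet, are not reported. The torch-only objects in the usage part stand in for a real flow's base distribution and transforms.

```python
import torch
from torch.distributions.transforms import AffineTransform


def report_devices(obj, prefix="obj"):
    """Hypothetical debugging helper: map attribute names to tensor devices."""
    found = {}
    if isinstance(obj, torch.nn.Module):
        # Registered parameters and buffers are what .to() moves.
        for name, t in obj.named_parameters():
            found[f"{prefix}.{name}"] = t.device
        for name, t in obj.named_buffers():
            found[f"{prefix}.{name}"] = t.device
    # Plain Transform/Distribution objects keep tensors in __dict__; on an
    # nn.Module, a tensor here is an unregistered attribute .to() will skip.
    for name, value in vars(obj).items():
        if isinstance(value, torch.Tensor):
            found[f"{prefix}.{name}"] = value.device
    return found


# Usage: stand-ins for the pieces of a flow.
flow_parts = {
    "base": torch.distributions.Normal(torch.zeros(2), torch.ones(2)),
    "affine": AffineTransform(loc=torch.zeros(2), scale=torch.ones(2)),
    "linear": torch.nn.Linear(2, 2),
}
devices = {}
for name, part in flow_parts.items():
    devices.update(report_devices(part, prefix=name))
for key, dev in sorted(devices.items()):
    print(key, dev)
```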

One thing you could try is to use torch.save and torch.load, but you'd only be able to do that once before training:

import torch

x = my_complex_data_structure()  # placeholder for your flow or model object
torch.save(x, "temp.pt")
x = torch.load("temp.pt", map_location="cuda:0")

fritzo, Jun 23 '21 21:06

Thanks for the very prompt reply, @fritzo! I was able to send some flows to CUDA with .to(device).

For example:

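from pyro.distributions.transforms import AffineAutoregressive
from pyro.nn import AutoRegressiveNN

# latent_dim, hidden_units, n_hidden and device come from the surrounding model code
flow = AffineAutoregressive(
    AutoRegressiveNN(
        latent_dim,
        [hidden_units for _ in range(n_hidden)],
        skip_connections=True,
    )
).to(device)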

Among the ones I tried, this also works for BatchNorm and AffineCoupling, but it does not work for Permute (which raises an error along the lines of "Permute has no to() method"). From a quick, high-level inspection, Permute is the only one of these that inherits from torch.distributions.transforms.Transform instead of pyro.distributions.torch_transform.TransformModule (the latter inherits from nn.Module). I might be off track, but I wanted to report this anyway.

EDIT: I was wrong; subclassing TransformModule doesn't solve the problem. The failure seems to come from this line in Permute._inverse:

return y.index_select(self.dim, self.inv_permutation)

SOLUTION: simply move the permutation index tensor to the device when constructing Permute:

Permute(torch.LongTensor(perm).to(device))

Thanks again for the help!

giovp, Jun 27 '21 17:06