Tensor concatenation error in sr_prior_effect notebook
For me, the very last line of the sr_prior_effect notebook fails, but I'm not sure why. It raises a TypeError even though the arguments appear to all be of the same type (I think...):
plot_image_grid([imgs['HR_np'],
result_no_prior,
result_tv_prior,
result_deep_prior], factor=8, nrow=2, interpolation='lanczos')
-----------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-42-8fd807e1bf2e> in <module>()
2 result_no_prior,
3 result_tv_prior,
----> 4 result_deep_prior], factor=8, nrow=2, interpolation='lanczos')
~/DeepImagePrior/deep-image-prior/utils/common_utils.py in plot_image_grid(images_np, nrow, factor, interpolation)
75 images_np = [x if (x.shape[0] == n_channels) else np.concatenate([x, x, x], axis=0) for x in images_np]
76
---> 77 grid = get_image_grid(images_np, nrow)
78
79 plt.figure(figsize=(len(images_np)+factor,12+factor))
~/DeepImagePrior/deep-image-prior/utils/common_utils.py in get_image_grid(images_np, nrow)
57 '''Creates a grid from a list of images by concatenating them.'''
58 images_torch = [torch.from_numpy(x) for x in images_np]
---> 59 torch_grid = torchvision.utils.make_grid(images_torch, nrow)
60
61 return torch_grid.numpy()
~/anaconda2/envs/py36/lib/python3.6/site-packages/torchvision-0.2.0-py3.6.egg/torchvision/utils.py in make_grid(tensor, nrow, padding, normalize, range, scale_each, pad_value)
33 # if list of tensors, convert to a 4D mini-batch Tensor
34 if isinstance(tensor, list):
---> 35 tensor = torch.stack(tensor, dim=0)
36
37 if tensor.dim() == 2: # single image H x W
~/anaconda2/envs/py36/lib/python3.6/site-packages/torch/functional.py in stack(sequence, dim, out)
62 inputs = [t.unsqueeze(dim) for t in sequence]
63 if out is None:
---> 64 return torch.cat(inputs, dim)
65 else:
66 return torch.cat(inputs, dim, out=out)
TypeError: cat received an invalid combination of arguments - got (list, int), but expected one of:
* (sequence[torch.FloatTensor] seq)
* (sequence[torch.FloatTensor] seq, int dim)
didn't match because some of the arguments have invalid types: (list, int)
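Note that torch.from_numpy preserves dtype, so a float64 array becomes a DoubleTensor while a float32 array becomes a FloatTensor. As a minimal sketch (my guess at the cause, not code from the notebook), mixing the two reproduces this exact failure on the 0.3-era torch in the traceback; newer torch versions promote dtypes in cat instead of raising:

import numpy as np
import torch

# dtype carries through torch.from_numpy unchanged
f32 = torch.from_numpy(np.zeros((3, 8, 8), dtype=np.float32))  # FloatTensor
f64 = torch.from_numpy(np.zeros((3, 8, 8), dtype=np.float64))  # DoubleTensor

# Raises the same TypeError on torch 0.3.x: cat expects a
# homogeneous sequence[torch.FloatTensor]
torch.stack([f32, f64], dim=0)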
So it seems the list handed to plot_image_grid has to contain float32 numpy arrays throughout, not the raw network outputs. Correct usage is:
plot_image_grid([imgs['HR_np'],
out_HR_noprior_np,
out_HR_TV_np,
out_HR_deep_np], factor=8, nrow=2, interpolation='lanczos');
where out_X_np = np.clip(var_to_np(net(net_input)), 0, 1) is computed after each loss experiment in the notebook.
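A more defensive alternative (a sketch based on the get_image_grid shown in the traceback, not the repo's actual code) would be to coerce inputs to float32 inside utils/common_utils.py itself, so mixed-precision arrays can never reach make_grid:

import numpy as np
import torch
import torchvision

def get_image_grid(images_np, nrow=8):
    '''Creates a grid from a list of images by concatenating them.'''
    # Cast to float32 so torch.from_numpy always yields FloatTensors,
    # which make_grid can stack/concatenate safely
    images_torch = [torch.from_numpy(x.astype(np.float32)) for x in images_np]
    torch_grid = torchvision.utils.make_grid(images_torch, nrow)
    return torch_grid.numpy()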