blogposts icon indicating copy to clipboard operation
blogposts copied to clipboard

How can I make your VAE work for my own custom dataset?

Open monajalal opened this issue 4 years ago • 1 comments

Here's the error I get:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-15-37f467c4f834> in <module>
      1 for epoch in range(1, epochs + 1):
----> 2     train(epoch)
      3     test(epoch)
      4     with torch.no_grad():
      5         sample = torch.randn(2, 2048).to(device)

<ipython-input-13-8f191bde6513> in train(epoch)
      6         optimizer.zero_grad()
      7         recon_batch, mu, logvar = model(data)
----> 8         loss = loss_mse(recon_batch, data, mu, logvar)
      9         loss.backward()
     10         train_loss += loss.item()

~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

<ipython-input-9-6c49edf3f96a> in forward(self, x_recon, x, mu, logvar)
      5 
      6     def forward(self, x_recon, x, mu, logvar):
----> 7         loss_MSE = self.mse_loss(x_recon, x)
      8         loss_KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
      9 

~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/loss.py in forward(self, input, target)
    443 
    444     def forward(self, input: Tensor, target: Tensor) -> Tensor:
--> 445         return F.mse_loss(input, target, reduction=self.reduction)
    446 
    447 

~/anaconda3/lib/python3.7/site-packages/torch/nn/functional.py in mse_loss(input, target, size_average, reduce, reduction)
   2645             ret = torch.mean(ret) if reduction == 'mean' else torch.sum(ret)
   2646     else:
-> 2647         expanded_input, expanded_target = torch.broadcast_tensors(input, target)
   2648         ret = torch._C._nn.mse_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))
   2649     return ret

~/anaconda3/lib/python3.7/site-packages/torch/functional.py in broadcast_tensors(*tensors)
     63         if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):
     64             return handle_torch_function(broadcast_tensors, tensors, *tensors)
---> 65     return _VF.broadcast_tensors(tensors)
     66 
     67 

RuntimeError: The size of tensor a (100) must match the size of tensor b (800) at non-singleton dimension 3

My images are of dimension 600x800.

monajalal avatar Nov 11 '20 22:11 monajalal

I used the following

my_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((100,100))
                      ])

train_loader_food = torch.utils.data.DataLoader(
    datasets.ImageFolder(train_root, transform = my_transform),
    batch_size = batch_size, shuffle=True, **kwargs)

val_loader_food = torch.utils.data.DataLoader(
    datasets.ImageFolder(val_root, transform = my_transform),
    batch_size = batch_size, shuffle=True, **kwargs)

Now I am getting this error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-15-37f467c4f834> in <module>
      1 for epoch in range(1, epochs + 1):
----> 2     train(epoch)
      3     test(epoch)
      4     with torch.no_grad():
      5         sample = torch.randn(2, 2048).to(device)

<ipython-input-13-8f191bde6513> in train(epoch)
      2     model.train()
      3     train_loss = 0
----> 4     for batch_idx, (data, _) in enumerate(train_loader_food):
      5         data = data.to(device)
      6         optimizer.zero_grad()

~/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __next__(self)
    361 
    362     def __next__(self):
--> 363         data = self._next_data()
    364         self._num_yielded += 1
    365         if self._dataset_kind == _DatasetKind.Iterable and \

~/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _next_data(self)
    987             else:
    988                 del self._task_info[idx]
--> 989                 return self._process_data(data)
    990 
    991     def _try_put_index(self):

~/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _process_data(self, data)
   1012         self._try_put_index()
   1013         if isinstance(data, ExceptionWrapper):
-> 1014             data.reraise()
   1015         return data
   1016 

~/anaconda3/lib/python3.7/site-packages/torch/_utils.py in reraise(self)
    393             # (https://bugs.python.org/issue2651), so we work around it.
    394             msg = KeyErrorMessage(msg)
--> 395         raise self.exc_type(msg)

TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/mona/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 185, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/mona/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/mona/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/mona/anaconda3/lib/python3.7/site-packages/torchvision/datasets/folder.py", line 139, in __getitem__
    sample = self.transform(sample)
  File "/home/mona/anaconda3/lib/python3.7/site-packages/torchvision/transforms/transforms.py", line 61, in __call__
    img = t(img)
  File "/home/mona/anaconda3/lib/python3.7/site-packages/torchvision/transforms/transforms.py", line 244, in __call__
    return F.resize(img, self.size, self.interpolation)
  File "/home/mona/anaconda3/lib/python3.7/site-packages/torchvision/transforms/functional.py", line 319, in resize
    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
TypeError: img should be PIL Image. Got <class 'torch.Tensor'>

monajalal avatar Nov 11 '20 23:11 monajalal