cp-vton-plus
Error while training GMM
Hi! I was trying to run the following command:
python3 train.py --name GMM --stage GMM --workers 4 --save_count 5000 --shuffle
and the interpreter gave me this error:
amnesy@efe:~/Desktop/cp-vton-plus$ python3 train.py --name GMM --stage GMM --workers 4 --save_count 5000 --shuffle
Namespace(batch_size=4, checkpoint='', checkpoint_dir='checkpoints', data_list='train_pairs.txt', datamode='train', dataroot='data', decay_step=100000, display_count=20, fine_height=256, fine_width=192, gpu_ids='', grid_size=5, keep_step=100000, lr=0.0001, name='GMM', radius=5, save_count=5000, shuffle=True, stage='GMM', tensorboard_dir='tensorboard', workers=4)
Start to train stage: GMM, named: GMM!
initialization method [normal]
initialization method [normal]
Traceback (most recent call last):
File "train.py", line 229, in <module>
main()
File "train.py", line 210, in main
train_gmm(opt, train_loader, model, board)
File "train.py", line 72, in train_gmm
inputs = train_loader.next_batch()
File "/home/amnesy/Desktop/cp-vton-plus/cp_dataset.py", line 228, in next_batch
batch = self.data_iter.__next__()
File "/home/amnesy/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 517, in __next__
data = self._next_data()
File "/home/amnesy/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1199, in _next_data
return self._process_data(data)
File "/home/amnesy/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1225, in _process_data
data.reraise()
File "/home/amnesy/.local/lib/python3.8/site-packages/torch/_utils.py", line 429, in reraise
raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/home/amnesy/.local/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 202, in _worker_loop
data = fetcher.fetch(index)
File "/home/amnesy/.local/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/amnesy/.local/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/amnesy/Desktop/cp-vton-plus/cp_dataset.py", line 138, in __getitem__
shape_ori = self.transform(parse_shape_ori) # [-1,1]
File "/home/amnesy/.local/lib/python3.8/site-packages/torchvision/transforms/transforms.py", line 60, in __call__
img = t(img)
File "/home/amnesy/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/amnesy/.local/lib/python3.8/site-packages/torchvision/transforms/transforms.py", line 221, in forward
return F.normalize(tensor, self.mean, self.std, self.inplace)
File "/home/amnesy/.local/lib/python3.8/site-packages/torchvision/transforms/functional.py", line 336, in normalize
tensor.sub_(mean).div_(std)
RuntimeError: output with shape [1, 256, 192] doesn't match the broadcast shape [3, 256, 192]
What should I do? In issue #61, TidamCo wrote:
You probably have a greyscale image as an input, and the model is expecting RGB.
(RGB has a shape of 3, while greyscale has a shape of 1)
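That matches the traceback: F.normalize fails when a 3-channel mean/std is applied in-place to a 1-channel tensor. A minimal sketch (made-up tensors, not the repo's own code) reproduces it:

# Minimal reproduction of the broadcast error, independent of the repo:
# a Normalize built for RGB cannot be applied in-place to a 1-channel tensor.
import torch
from torchvision import transforms

normalize_rgb = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

rgb  = torch.rand(3, 256, 192)   # regular RGB image tensor
grey = torch.rand(1, 256, 192)   # single-channel mask, like parse_shape_ori

print(normalize_rgb(rgb).shape)  # torch.Size([3, 256, 192])
normalize_rgb(grey)              # RuntimeError: output with shape [1, 256, 192] ...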
But I have no idea how to specify the correct image. Note that I uncommented line 93 and commented line 94 as described in issue #60, and I have unzipped the images into the data folder. Any ideas?
Hi @amnesytoolkit, this is most probably due to a wrong folder path for the cloth mask or image-parse. Can you please check which folders are set in your cp_dataset.py file? Also, a bit of debugging by switching to the other folder might solve the issue. For example, if there is an image-parse-vis folder, try setting the path to that one and see which one works. Hope that helps. Thanks.
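A quick way to check is to print the mode of one file from each candidate folder and see which ones are greyscale ('L' or 'P') and which are 'RGB'. The folder names below are just the usual data/train layout, so adjust them to whatever your cp_dataset.py points at:

# Debugging sketch (assumed folder names, adjust to your data layout):
# inspect one image per folder to see its channel mode and size.
import os
from PIL import Image

for folder in ["data/train/image", "data/train/image-parse", "data/train/cloth-mask"]:
    name = sorted(os.listdir(folder))[0]
    with Image.open(os.path.join(folder, name)) as img:
        print(folder, name, img.mode, img.size)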
I have cloned the repo as it is and I am sure the paths are right, but I still have the same error.
I ran into https://github.com/minar09/cp-vton-plus/issues/6 too
You'll need to update cp_dataset.py: change
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
to
transforms.Normalize((0.5,), (0.5,))
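With a 1-element mean/std, the same transform broadcasts over both the single-channel masks and the RGB images, so one Compose can serve both. A small standalone sketch (placeholder images, not the dataset's real files):

# Standalone sketch: a single-channel Normalize handles 1- and 3-channel inputs.
import torchvision.transforms as transforms
from PIL import Image

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))])

grey = Image.new("L", (192, 256))    # stands in for parse_shape_ori
rgb  = Image.new("RGB", (192, 256))  # stands in for a cloth/person image

print(transform(grey).shape)  # torch.Size([1, 256, 192])
print(transform(rgb).shape)   # torch.Size([3, 256, 192])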