stargan-v2
Switch to grayscale
Hi,
I'm trying to use StarGAN v2 with gray-scale images, and I get this type of error:
Traceback (most recent call last):
File "main.py", line 182, in <module>
main(args)
File "main.py", line 59, in main
solver.train(loaders)
File "/home/yiheng/bacasable/stargan-v2/core/solver.py", line 110, in train
nets, args, x_real, y_org, y_trg, z_trg=z_trg, masks=masks)
File "/home/yiheng/bacasable/stargan-v2/core/solver.py", line 205, in compute_d_loss
out = nets.discriminator(x_real, y_org)
File "/home/yiheng/anaconda3/envs/stargan-v2/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
result = self.forward(*input, **kwargs)
File "/home/yiheng/bacasable/stargan-v2/core/model.py", line 275, in forward
out = self.main(x)
File "/home/yiheng/anaconda3/envs/stargan-v2/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
result = self.forward(*input, **kwargs)
File "/home/yiheng/anaconda3/envs/stargan-v2/lib/python3.6/site-packages/torch/nn/modules/container.py", line 100, in forward
input = module(input)
File "/home/yiheng/anaconda3/envs/stargan-v2/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
result = self.forward(*input, **kwargs)
File "/home/yiheng/anaconda3/envs/stargan-v2/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 345, in forward
return self.conv2d_forward(input, self.weight)
File "/home/yiheng/anaconda3/envs/stargan-v2/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 342, in conv2d_forward
self.padding, self.dilation, self.groups)
RuntimeError: Given groups=1, weight of size 64 3 3 3, expected input[8, 1, 256, 256] to have 3 channels, but got 1 channels instead
I think it is because of the number of input channels: gray-scale images have only 1 channel rather than 3 for RGB ones. But I don't know where I should adapt the parameters in your code.
If you could give me a pointer, it would be much appreciated!
Thank you!
Replace all nn.Conv2d(3, dim_in, 3, 1, 1) in model.py with nn.Conv2d(1, dim_in, 3, 1, 1).
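As an aside, here is a minimal stand-alone sketch (stand-in layers, not the repo's actual modules) of why the 3 to 1 change matches a grayscale batch:

import torch
import torch.nn as nn

# Stand-alone illustration of why the first conv's in_channels must match the image channels.
x_gray = torch.randn(8, 1, 256, 256)   # grayscale batch, as in the traceback

conv_rgb = nn.Conv2d(3, 64, 3, 1, 1)   # original first layer: expects 3 channels
# conv_rgb(x_gray)                     # would raise the RuntimeError shown above

conv_gray = nn.Conv2d(1, 64, 3, 1, 1)  # patched first layer: expects 1 channel
print(conv_gray(x_gray).size())        # torch.Size([8, 64, 256, 256])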
Thank you @Jonas1312!
But I still get an error, though it changed a little bit...
Traceback (most recent call last):
File "main.py", line 182, in <module>
main(args)
File "main.py", line 59, in main
solver.train(loaders)
File "/home/yiheng/bacasable/stargan-v2/core/solver.py", line 110, in train
nets, args, x_real, y_org, y_trg, z_trg=z_trg, masks=masks)
File "/home/yiheng/bacasable/stargan-v2/core/solver.py", line 217, in compute_d_loss
out = nets.discriminator(x_fake, y_trg)
File "/home/yiheng/anaconda3/envs/stargan-v2/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
result = self.forward(*input, **kwargs)
File "/home/yiheng/bacasable/stargan-v2/core/model.py", line 275, in forward
out = self.main(x)
File "/home/yiheng/anaconda3/envs/stargan-v2/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
result = self.forward(*input, **kwargs)
File "/home/yiheng/anaconda3/envs/stargan-v2/lib/python3.6/site-packages/torch/nn/modules/container.py", line 100, in forward
input = module(input)
File "/home/yiheng/anaconda3/envs/stargan-v2/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
result = self.forward(*input, **kwargs)
File "/home/yiheng/anaconda3/envs/stargan-v2/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 345, in forward
return self.conv2d_forward(input, self.weight)
File "/home/yiheng/anaconda3/envs/stargan-v2/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 342, in conv2d_forward
self.padding, self.dilation, self.groups)
RuntimeError: Given groups=1, weight of size 64 1 3 3, expected input[1, 3, 256, 256] to have 1 channels, but got 3 channels instead
Maybe there are some other spots to change to 1?
RuntimeError: Given groups=1, weight of size 64 1 3 3, expected input[1, 3, 256, 256] to have 1 channels, but got 3 channels instead
Add print(x.size()) before out = self.main(x). It should print (B, 1, 256, 256), where B is the number of images in your batch.
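If editing forward() is inconvenient, a forward pre-hook can print the same information. A minimal sketch on a stand-in layer (not the repo's Discriminator; in practice you would attach the hook to the first module of nets.discriminator):

import torch
import torch.nn as nn

# A forward pre-hook prints the input size of a module without touching its code.
layer = nn.Conv2d(1, 64, 3, 1, 1)
layer.register_forward_pre_hook(lambda module, inputs: print(inputs[0].size()))

x = torch.randn(8, 1, 256, 256)
_ = layer(x)  # prints torch.Size([8, 1, 256, 256])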
Thank you again for your help! I printed the sizes as you indicated:
with torch.no_grad():
    if z_trg is not None:
        s_trg = nets.mapping_network(z_trg, y_trg)
    else:  # x_ref is not None
        s_trg = nets.style_encoder(x_ref, y_trg)
    print("x_real:", x_real.size(), "s_trg:", s_trg.size())
    x_fake = nets.generator(x_real, s_trg, masks=masks)
    print("x_fake:", x_fake.size(), "y_trg:", y_trg.size())
The problem was with the generator:
x_real: torch.Size([1, 1, 256, 256]) s_trg: torch.Size([1, 64])
x_fake: torch.Size([1, 3, 256, 256]) y_trg: torch.Size([1])
So I modified the lines below, in the Generator class in model.py, to fix this:
self.to_rgb = nn.Sequential(
    nn.InstanceNorm2d(dim_in, affine=True),
    nn.LeakyReLU(0.2),
    nn.Conv2d(dim_in, 1, 1, 1, 0))  # before: nn.Conv2d(dim_in, 3, 1, 1, 0)
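Putting the two edits together, here is a small stand-alone sanity check of the channel changes (stand-in layers with the same conv shapes as the edited lines, not the actual StarGAN v2 Generator):

import torch
import torch.nn as nn

# The first conv ("from_rgb") must take 1 input channel and the last conv ("to_rgb")
# must produce 1 output channel for grayscale data.
dim_in = 64
from_rgb = nn.Conv2d(1, dim_in, 3, 1, 1)    # was nn.Conv2d(3, dim_in, 3, 1, 1)
to_rgb = nn.Sequential(
    nn.InstanceNorm2d(dim_in, affine=True),
    nn.LeakyReLU(0.2),
    nn.Conv2d(dim_in, 1, 1, 1, 0))          # was nn.Conv2d(dim_in, 3, 1, 1, 0)

x = torch.randn(1, 1, 256, 256)             # grayscale input
h = from_rgb(x)
print(h.size())                             # torch.Size([1, 64, 256, 256])
print(to_rgb(h).size())                     # torch.Size([1, 1, 256, 256]), matches x_real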
Thank you again @Jonas1312!
Hello, I have the same problem as you. What needs to be changed in the code? I don't understand.