
Switch to grayscale

Open cyiheng opened this issue 4 years ago • 5 comments

Hi,

I'm trying to use StarGAN v2 with grayscale images, and I get this type of error:

Traceback (most recent call last):
  File "main.py", line 182, in <module>
    main(args)
  File "main.py", line 59, in main
    solver.train(loaders)
  File "/home/yiheng/bacasable/stargan-v2/core/solver.py", line 110, in train
    nets, args, x_real, y_org, y_trg, z_trg=z_trg, masks=masks)
  File "/home/yiheng/bacasable/stargan-v2/core/solver.py", line 205, in compute_d_loss
    out = nets.discriminator(x_real, y_org)
  File "/home/yiheng/anaconda3/envs/stargan-v2/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/yiheng/bacasable/stargan-v2/core/model.py", line 275, in forward
    out = self.main(x)
  File "/home/yiheng/anaconda3/envs/stargan-v2/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/yiheng/anaconda3/envs/stargan-v2/lib/python3.6/site-packages/torch/nn/modules/container.py", line 100, in forward
    input = module(input)
  File "/home/yiheng/anaconda3/envs/stargan-v2/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/yiheng/anaconda3/envs/stargan-v2/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 345, in forward
    return self.conv2d_forward(input, self.weight)
  File "/home/yiheng/anaconda3/envs/stargan-v2/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 342, in conv2d_forward
    self.padding, self.dilation, self.groups)
RuntimeError: Given groups=1, weight of size 64 3 3 3, expected input[8, 1, 256, 256] to have 3 channels, but got 1 channels instead

I think it's because of the number of input channels, since grayscale images have only 1 channel rather than the 3 of RGB images. But I don't know where I should adapt the parameters in your code.
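For reference, the mismatch can be reproduced in isolation with plain PyTorch (a minimal sketch, not the repo's actual modules): a conv layer's weight has shape (out_channels, in_channels, kH, kW), and its in_channels must match the input's channel count.

```python
import torch
import torch.nn as nn

# First conv of the original discriminator: built for 3-channel (RGB) input,
# so its weight has shape (out_channels=64, in_channels=3, 3, 3).
rgb_conv = nn.Conv2d(3, 64, 3, 1, 1)

gray_batch = torch.randn(8, 1, 256, 256)  # grayscale batch: 1 channel

try:
    rgb_conv(gray_batch)
except RuntimeError as e:
    err = str(e)
    print("mismatch:", err)  # same "expected input ... to have 3 channels" error

# Rebuilding the conv with in_channels=1 accepts the grayscale batch:
gray_conv = nn.Conv2d(1, 64, 3, 1, 1)
out = gray_conv(gray_batch)
print(out.shape)  # torch.Size([8, 64, 256, 256])
```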

If you could point me in the right direction, that would be great!

Thank you!

cyiheng avatar Nov 24 '20 16:11 cyiheng

Replace all nn.Conv2d(3, dim_in, 3, 1, 1) in model.py with nn.Conv2d(1, dim_in, 3, 1, 1)
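One way to sketch this without hunting for each call site (hypothetical helper and constant names, not the repo's actual API) is to route every hard-coded input conv through a single constant:

```python
import torch.nn as nn

IMG_CHANNELS = 1  # hypothetical constant: 1 for grayscale, 3 for RGB

def input_conv(dim_in: int) -> nn.Conv2d:
    # Stand-in for the hard-coded nn.Conv2d(3, dim_in, 3, 1, 1)
    # call sites in model.py.
    return nn.Conv2d(IMG_CHANNELS, dim_in, 3, 1, 1)
```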

Jonas1312 avatar Nov 24 '20 17:11 Jonas1312

Thank you @Jonas1312!
I still get the error, but it changed a little bit:

Traceback (most recent call last):
  File "main.py", line 182, in <module>
    main(args)
  File "main.py", line 59, in main
    solver.train(loaders)
  File "/home/yiheng/bacasable/stargan-v2/core/solver.py", line 110, in train
    nets, args, x_real, y_org, y_trg, z_trg=z_trg, masks=masks)
  File "/home/yiheng/bacasable/stargan-v2/core/solver.py", line 217, in compute_d_loss
    out = nets.discriminator(x_fake, y_trg)
  File "/home/yiheng/anaconda3/envs/stargan-v2/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/yiheng/bacasable/stargan-v2/core/model.py", line 275, in forward
    out = self.main(x)
  File "/home/yiheng/anaconda3/envs/stargan-v2/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/yiheng/anaconda3/envs/stargan-v2/lib/python3.6/site-packages/torch/nn/modules/container.py", line 100, in forward
    input = module(input)
  File "/home/yiheng/anaconda3/envs/stargan-v2/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/yiheng/anaconda3/envs/stargan-v2/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 345, in forward
    return self.conv2d_forward(input, self.weight)
  File "/home/yiheng/anaconda3/envs/stargan-v2/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 342, in conv2d_forward
    self.padding, self.dilation, self.groups)
RuntimeError: Given groups=1, weight of size 64 1 3 3, expected input[1, 3, 256, 256] to have 1 channels, but got 3 channels instead

Maybe there are some other spots to change to 1?

cyiheng avatar Nov 24 '20 17:11 cyiheng

RuntimeError: Given groups=1, weight of size 64 1 3 3, expected input[1, 3, 256, 256] to have 1 channels, but got 3 channels instead

Add print(x.size()) before out = self.main(x). It should print torch.Size([B, 1, 256, 256]), where B is the number of images in your batch.
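As a sketch of what that debug print shows (a toy stand-in reduced to one conv, not the actual Discriminator from core/model.py):

```python
import torch
import torch.nn as nn

class TinyDiscriminator(nn.Module):
    """Toy stand-in for the repo's Discriminator, reduced to one conv."""
    def __init__(self, img_channels=1):
        super().__init__()
        self.main = nn.Conv2d(img_channels, 64, 3, 1, 1)

    def forward(self, x):
        print(x.size())  # debug: expect torch.Size([B, 1, 256, 256]) for grayscale
        return self.main(x)

d = TinyDiscriminator()
out = d(torch.randn(8, 1, 256, 256))  # prints torch.Size([8, 1, 256, 256])
```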

Jonas1312 avatar Nov 24 '20 17:11 Jonas1312

Thank you again for your help! I printed the sizes as you suggested:

    with torch.no_grad():
        if z_trg is not None:
            s_trg = nets.mapping_network(z_trg, y_trg)
        else:  # x_ref is not None
            s_trg = nets.style_encoder(x_ref, y_trg)
        print("x_real:", x_real.size(),"s_trg:", s_trg.size())
        x_fake = nets.generator(x_real, s_trg, masks=masks)
    print("x_fake:", x_fake.size(),"y_trg:", y_trg.size())

The problem was with the generator:

x_real: torch.Size([1, 1, 256, 256]) s_trg: torch.Size([1, 64])
x_fake: torch.Size([1, 3, 256, 256]) y_trg: torch.Size([1])

So I modified the lines below in model.py, in the Generator class, to fix this:

self.to_rgb = nn.Sequential(
            nn.InstanceNorm2d(dim_in, affine=True),
            nn.LeakyReLU(0.2),
            nn.Conv2d(dim_in, 1, 1, 1, 0))  # Before: nn.Conv2d(dim_in, 3, 1, 1, 0)
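A quick standalone shape check (a sketch; dim_in=64 is an assumed value for the output head) confirms the modified head now emits 1-channel images, matching the discriminator's 1-channel input conv:

```python
import torch
import torch.nn as nn

dim_in = 64  # assumed channel count at the generator's output head

# Modified output head: the final 1x1 conv emits 1 channel instead of 3,
# so x_fake matches a discriminator built for grayscale input.
to_gray = nn.Sequential(
    nn.InstanceNorm2d(dim_in, affine=True),
    nn.LeakyReLU(0.2),
    nn.Conv2d(dim_in, 1, 1, 1, 0))

x_fake = to_gray(torch.randn(1, dim_in, 256, 256))
print(x_fake.shape)  # torch.Size([1, 1, 256, 256])
```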

Thank you again, @Jonas1312!

cyiheng avatar Nov 24 '20 20:11 cyiheng

Hello, I have the same problem as you. What needs to be changed in the code? I don't understand.

typeface-cn avatar Mar 14 '23 14:03 typeface-cn