pixel2style2pixel

Question about the encoder model when using a different style_dim for the Generator

aravind598 opened this issue.

Hi, thank you for this excellent repo. Unfortunately, I have run into an error while training a pSp StyleGAN inversion encoder for a custom StyleGAN model, with ResNetBackboneEncoder as my encoder type.

This is the error I get:

Traceback (most recent call last):
  File "scripts/train_restyle_psp.py", line 30, in <module>
    main()
  File "scripts/train_restyle_psp.py", line 26, in main
    coach.train()
  File ".\training\coach_restyle_psp.py", line 123, in train
    y_hats, loss_dict, id_logs = self.perform_train_iteration_on_batch(x, y)
  File ".\training\coach_restyle_psp.py", line 97, in perform_train_iteration_on_batch
    y_hat, latent = self.net.forward(x_input, latent=None, return_latents=True)
  File ".\models\psp.py", line 70, in forward
    codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)
RuntimeError: The size of tensor a (512) must match the size of tensor b (1024) at non-singleton dimension 2

I have modified rosinality's PyTorch StyleGAN2 code to allow inference with my model, and adjusted the model code in the stylegan2 module accordingly. I have also changed the decoder definition on line 22 of psp.py from:

self.decoder = Generator(self.opts.output_size, 512, 8, channel_multiplier=2)

to:

self.decoder = Generator(self.opts.output_size, 1024, 4, channel_multiplier=2)
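For reference, the positional arguments of rosinality's Generator are (size, style_dim, n_mlp, channel_multiplier), so this change sets the style dimension to 1024 and the mapping-network depth to 4. As far as I know, that Generator also derives its number of W+ style vectors from the output resolution as 2 * log2(size) - 2. The plain-Python sketch below (n_latent and wplus_shape are hypothetical helpers) shows how this yields the shape the encoder must produce, assuming output_size = 512 and a batch of 8, which is consistent with the 16 style vectors in the traceback.

```python
import math


def n_latent(size):
    """Number of W+ style vectors for an output resolution `size` (a power of 2)."""
    log_size = int(math.log2(size))
    return log_size * 2 - 2


def wplus_shape(batch, size, style_dim):
    """Shape of the W+ codes the encoder must output: (batch, n_latent, style_dim)."""
    return (batch, n_latent(size), style_dim)


print(wplus_shape(8, 512, 512))   # original StyleGAN2 style_dim -> (8, 16, 512)
print(wplus_shape(8, 512, 1024))  # Generator(output_size, 1024, 4, ...) -> (8, 16, 1024)
```

Only the last dimension depends on style_dim, which is why the encoder's output width has to change along with the generator.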

Therefore, I believe the error above is caused by the mismatch in style_dim: the original, unmodified StyleGAN2 has a style_dim of 512, while my model has a style_dim of 1024. As a result, the encoder returns a tensor of shape [8, 16, 512] instead of the expected [8, 16, 1024].
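The RuntimeError follows from PyTorch's broadcasting rule: shapes are compared from the last dimension backwards, and each pair of sizes must be equal or one of them must be 1. The plain-Python sketch below re-implements that rule (broadcast_shape is a hypothetical helper, not part of pSp or PyTorch) to show why the addition in psp.py fails at dimension 2. It assumes latent_avg has already been repeated to shape [8, 16, 1024] while the codes have shape [8, 16, 512].

```python
def broadcast_shape(a, b):
    """Return the broadcast shape of two tensor shapes, or raise like PyTorch does.

    Shapes are aligned from the trailing dimension; each pair of sizes must be
    equal, or one of them must be 1.
    """
    ndim = max(len(a), len(b))
    # Left-pad the shorter shape with 1s so both have ndim entries.
    a = (1,) * (ndim - len(a)) + tuple(a)
    b = (1,) * (ndim - len(b)) + tuple(b)
    out = []
    for dim, (x, y) in enumerate(zip(a, b)):
        if x == y or y == 1:
            out.append(x)
        elif x == 1:
            out.append(y)
        else:
            raise RuntimeError(
                f"The size of tensor a ({x}) must match the size of tensor b ({y}) "
                f"at non-singleton dimension {dim}"
            )
    return tuple(out)


try:
    broadcast_shape((8, 16, 512), (8, 16, 1024))  # codes + repeated latent_avg
except RuntimeError as err:
    print(err)

print(broadcast_shape((8, 16, 1024), (8, 16, 1024)))  # (8, 16, 1024)
```

Once the encoder emits 1024-wide codes, the two shapes agree and the addition goes through.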

To fix this error, I changed line 81 in restyle_psp_encoders.py to:

style = GradualStyleBlock(512, 1024, 16)
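One way to see what this change costs is to count the parameters of a single GradualStyleBlock. The sketch below assumes the layout quoted later in this thread: log2(spatial) 3x3 stride-2 convolutions with bias (the first mapping in_c to out_c, the rest out_c to out_c), followed by an out_c-to-out_c linear layer with bias (EqualLinear in the repo). The helper names are hypothetical; the per-block totals match the torchinfo summary further down this thread.

```python
import math


def conv_params(in_c, out_c, k=3):
    # weight (out_c * in_c * k * k) + bias (out_c)
    return out_c * in_c * k * k + out_c


def gradual_style_block_params(in_c, out_c, spatial):
    num_convs = int(math.log2(spatial))  # each stride-2 conv halves the spatial size
    total = conv_params(in_c, out_c)  # first conv: in_c -> out_c
    total += (num_convs - 1) * conv_params(out_c, out_c)  # remaining convs: out_c -> out_c
    total += out_c * out_c + out_c  # final linear layer (EqualLinear in pSp)
    return total


n_styles = 16
for out_c in (512, 1024):
    per_block = gradual_style_block_params(512, out_c, 16)
    print(out_c, per_block, n_styles * per_block)
```

Under these assumptions, doubling out_c roughly quadruples the cost of every out_c-to-out_c conv, so the 16 style blocks grow from about 155M to about 545M parameters.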

However, I can't test this yet because CUDA runs out of memory. Is my fix correct?

aravind598, Aug 10 '22 17:08

It seems like your changes are correct. The entire pSp repo is built on StyleGAN2 (SG2), whose latent codes have size 512. If you change your generator to use latents of size 1024, you need to make the corresponding changes to pSp, as you did. However, by doing so you created a much larger encoder network, which is what causes the out-of-memory error. There are three options off the top of my head:

  1. Reduce the batch size.
  2. Simplify the encoder, for example by using smaller layers in GradualStyleBlock. If you simplify the encoder enough, you may be able to avoid running out of memory at the original batch size.
  3. Change architectures. One option is the lighter-weight architecture I used in ReStyle, which can be found here: https://github.com/yuval-alaluf/restyle-encoder

Hope this helps!

yuval-alaluf, Aug 12 '22 05:08

Hi, thanks a lot for your prompt and helpful reply. As you suggested, I am now using the ReStyle pSp encoder and have reduced the batch size to 2, since a batch size of 1 produces the traceback below. I think that with a batch size of 1, y_hat_feats[i] and y_feats[i] become scalars (0-D tensors). However, I still get CUDA OOM even with 16 GB of VRAM.
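The 0-D error is consistent with a .squeeze() call on the feature tensor: squeeze() drops every size-1 dimension, including the batch dimension when the batch size is 1. The sketch below assumes the MoCo features leave the backbone with shape [B, 2048, 1, 1] before being squeezed (an assumption about criteria/moco_loss.py, not confirmed in this thread). It shows that indexing the squeezed tensor then yields scalars, and that flatten(1), used here only as a hypothetical alternative, keeps the batch dimension.

```python
import torch


def squeezed_feats(batch_size, dim=2048):
    # Assumed backbone output shape: [B, dim, 1, 1]
    feats = torch.randn(batch_size, dim, 1, 1)
    return feats.squeeze()  # drops ALL size-1 dims, including B when B == 1


def flattened_feats(batch_size, dim=2048):
    feats = torch.randn(batch_size, dim, 1, 1)
    return feats.flatten(1)  # keeps the batch dim: always [B, dim]


for b in (1, 2):
    y_hat_feats, y_feats = squeezed_feats(b), squeezed_feats(b)
    print(b, tuple(y_hat_feats.shape), y_hat_feats[0].dim())
    try:
        y_hat_feats[0].dot(y_feats[0])  # same call as in the traceback
        print("dot OK")
    except RuntimeError as err:
        print("dot failed:", err)

print(tuple(flattened_feats(1).shape))
```

With B = 1 the indexed features are 0-D and torch.dot rejects them; with B = 2 they are 1-D vectors, which matches the observation that a batch size of at least 2 avoids the error.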

Traceback (most recent call last):
  File "scripts/train_restyle_psp.py", line 30, in <module>
    main()
  File "scripts/train_restyle_psp.py", line 26, in main
    coach.train()
  File ".\training\coach_restyle_psp.py", line 123, in train
    y_hats, loss_dict, id_logs = self.perform_train_iteration_on_batch(x, y)
  File ".\training\coach_restyle_psp.py", line 107, in perform_train_iteration_on_batch
    loss, loss_dict, id_logs = self.calc_loss(x, y, y_hat, latent)
  File ".\training\coach_restyle_psp.py", line 265, in calc_loss
    loss_moco, sim_improvement, id_logs = self.moco_loss(y_hat, y, x)
  File "C:\Users\Aravind\anaconda3\envs\lit\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File ".\criteria\moco_loss.py", line 58, in forward
    diff_target = y_hat_feats[i].dot(y_feats[i])
RuntimeError: 1D tensors expected, but got 0D and 0D tensors

Below is the encoder network as printed by torchinfo. The number of trainable parameters is very large (about 567M). Is there a way to reduce the number of parameters, for example by using fewer GradualStyleBlock layers, and so lower the CUDA memory usage? Thanks again for this awesome repo and your help.

Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
ResNetBackboneEncoder                    [8, 16, 1024]             545,341,440
├─Conv2d: 1-1                            [8, 64, 128, 128]         18,816
├─BatchNorm2d: 1-2                       [8, 64, 128, 128]         128
├─PReLU: 1-3                             [8, 64, 128, 128]         64
├─Sequential: 1-4                        [8, 512, 16, 16]          --
│    └─BasicBlock: 2-1                   [8, 64, 128, 128]         --
│    │    └─Conv2d: 3-1                  [8, 64, 128, 128]         36,864
│    │    └─BatchNorm2d: 3-2             [8, 64, 128, 128]         128
│    │    └─ReLU: 3-3                    [8, 64, 128, 128]         --
│    │    └─Conv2d: 3-4                  [8, 64, 128, 128]         36,864
│    │    └─BatchNorm2d: 3-5             [8, 64, 128, 128]         128
│    │    └─ReLU: 3-6                    [8, 64, 128, 128]         --
│    └─BasicBlock: 2-2                   [8, 64, 128, 128]         --
│    │    └─Conv2d: 3-7                  [8, 64, 128, 128]         36,864
│    │    └─BatchNorm2d: 3-8             [8, 64, 128, 128]         128
│    │    └─ReLU: 3-9                    [8, 64, 128, 128]         --
│    │    └─Conv2d: 3-10                 [8, 64, 128, 128]         36,864
│    │    └─BatchNorm2d: 3-11            [8, 64, 128, 128]         128
│    │    └─ReLU: 3-12                   [8, 64, 128, 128]         --
│    └─BasicBlock: 2-3                   [8, 64, 128, 128]         --
│    │    └─Conv2d: 3-13                 [8, 64, 128, 128]         36,864
│    │    └─BatchNorm2d: 3-14            [8, 64, 128, 128]         128
│    │    └─ReLU: 3-15                   [8, 64, 128, 128]         --
│    │    └─Conv2d: 3-16                 [8, 64, 128, 128]         36,864
│    │    └─BatchNorm2d: 3-17            [8, 64, 128, 128]         128
│    │    └─ReLU: 3-18                   [8, 64, 128, 128]         --
│    └─BasicBlock: 2-4                   [8, 128, 64, 64]          --
│    │    └─Conv2d: 3-19                 [8, 128, 64, 64]          73,728
│    │    └─BatchNorm2d: 3-20            [8, 128, 64, 64]          256
│    │    └─ReLU: 3-21                   [8, 128, 64, 64]          --
│    │    └─Conv2d: 3-22                 [8, 128, 64, 64]          147,456
│    │    └─BatchNorm2d: 3-23            [8, 128, 64, 64]          256
│    │    └─Sequential: 3-24             [8, 128, 64, 64]          8,448
│    │    └─ReLU: 3-25                   [8, 128, 64, 64]          --
│    └─BasicBlock: 2-5                   [8, 128, 64, 64]          --
│    │    └─Conv2d: 3-26                 [8, 128, 64, 64]          147,456
│    │    └─BatchNorm2d: 3-27            [8, 128, 64, 64]          256
│    │    └─ReLU: 3-28                   [8, 128, 64, 64]          --
│    │    └─Conv2d: 3-29                 [8, 128, 64, 64]          147,456
│    │    └─BatchNorm2d: 3-30            [8, 128, 64, 64]          256
│    │    └─ReLU: 3-31                   [8, 128, 64, 64]          --
│    └─BasicBlock: 2-6                   [8, 128, 64, 64]          --
│    │    └─Conv2d: 3-32                 [8, 128, 64, 64]          147,456
│    │    └─BatchNorm2d: 3-33            [8, 128, 64, 64]          256
│    │    └─ReLU: 3-34                   [8, 128, 64, 64]          --
│    │    └─Conv2d: 3-35                 [8, 128, 64, 64]          147,456
│    │    └─BatchNorm2d: 3-36            [8, 128, 64, 64]          256
│    │    └─ReLU: 3-37                   [8, 128, 64, 64]          --
│    └─BasicBlock: 2-7                   [8, 128, 64, 64]          --
│    │    └─Conv2d: 3-38                 [8, 128, 64, 64]          147,456
│    │    └─BatchNorm2d: 3-39            [8, 128, 64, 64]          256
│    │    └─ReLU: 3-40                   [8, 128, 64, 64]          --
│    │    └─Conv2d: 3-41                 [8, 128, 64, 64]          147,456
│    │    └─BatchNorm2d: 3-42            [8, 128, 64, 64]          256
│    │    └─ReLU: 3-43                   [8, 128, 64, 64]          --
│    └─BasicBlock: 2-8                   [8, 256, 32, 32]          --
│    │    └─Conv2d: 3-44                 [8, 256, 32, 32]          294,912
│    │    └─BatchNorm2d: 3-45            [8, 256, 32, 32]          512
│    │    └─ReLU: 3-46                   [8, 256, 32, 32]          --
│    │    └─Conv2d: 3-47                 [8, 256, 32, 32]          589,824
│    │    └─BatchNorm2d: 3-48            [8, 256, 32, 32]          512
│    │    └─Sequential: 3-49             [8, 256, 32, 32]          33,280
│    │    └─ReLU: 3-50                   [8, 256, 32, 32]          --
│    └─BasicBlock: 2-9                   [8, 256, 32, 32]          --
│    │    └─Conv2d: 3-51                 [8, 256, 32, 32]          589,824
│    │    └─BatchNorm2d: 3-52            [8, 256, 32, 32]          512
│    │    └─ReLU: 3-53                   [8, 256, 32, 32]          --
│    │    └─Conv2d: 3-54                 [8, 256, 32, 32]          589,824
│    │    └─BatchNorm2d: 3-55            [8, 256, 32, 32]          512
│    │    └─ReLU: 3-56                   [8, 256, 32, 32]          --
│    └─BasicBlock: 2-10                  [8, 256, 32, 32]          --
│    │    └─Conv2d: 3-57                 [8, 256, 32, 32]          589,824
│    │    └─BatchNorm2d: 3-58            [8, 256, 32, 32]          512
│    │    └─ReLU: 3-59                   [8, 256, 32, 32]          --
│    │    └─Conv2d: 3-60                 [8, 256, 32, 32]          589,824
│    │    └─BatchNorm2d: 3-61            [8, 256, 32, 32]          512
│    │    └─ReLU: 3-62                   [8, 256, 32, 32]          --
│    └─BasicBlock: 2-11                  [8, 256, 32, 32]          --
│    │    └─Conv2d: 3-63                 [8, 256, 32, 32]          589,824
│    │    └─BatchNorm2d: 3-64            [8, 256, 32, 32]          512
│    │    └─ReLU: 3-65                   [8, 256, 32, 32]          --
│    │    └─Conv2d: 3-66                 [8, 256, 32, 32]          589,824
│    │    └─BatchNorm2d: 3-67            [8, 256, 32, 32]          512
│    │    └─ReLU: 3-68                   [8, 256, 32, 32]          --
│    └─BasicBlock: 2-12                  [8, 256, 32, 32]          --
│    │    └─Conv2d: 3-69                 [8, 256, 32, 32]          589,824
│    │    └─BatchNorm2d: 3-70            [8, 256, 32, 32]          512
│    │    └─ReLU: 3-71                   [8, 256, 32, 32]          --
│    │    └─Conv2d: 3-72                 [8, 256, 32, 32]          589,824
│    │    └─BatchNorm2d: 3-73            [8, 256, 32, 32]          512
│    │    └─ReLU: 3-74                   [8, 256, 32, 32]          --
│    └─BasicBlock: 2-13                  [8, 256, 32, 32]          --
│    │    └─Conv2d: 3-75                 [8, 256, 32, 32]          589,824
│    │    └─BatchNorm2d: 3-76            [8, 256, 32, 32]          512
│    │    └─ReLU: 3-77                   [8, 256, 32, 32]          --
│    │    └─Conv2d: 3-78                 [8, 256, 32, 32]          589,824
│    │    └─BatchNorm2d: 3-79            [8, 256, 32, 32]          512
│    │    └─ReLU: 3-80                   [8, 256, 32, 32]          --
│    └─BasicBlock: 2-14                  [8, 512, 16, 16]          --
│    │    └─Conv2d: 3-81                 [8, 512, 16, 16]          1,179,648
│    │    └─BatchNorm2d: 3-82            [8, 512, 16, 16]          1,024
│    │    └─ReLU: 3-83                   [8, 512, 16, 16]          --
│    │    └─Conv2d: 3-84                 [8, 512, 16, 16]          2,359,296
│    │    └─BatchNorm2d: 3-85            [8, 512, 16, 16]          1,024
│    │    └─Sequential: 3-86             [8, 512, 16, 16]          132,096
│    │    └─ReLU: 3-87                   [8, 512, 16, 16]          --
│    └─BasicBlock: 2-15                  [8, 512, 16, 16]          --
│    │    └─Conv2d: 3-88                 [8, 512, 16, 16]          2,359,296
│    │    └─BatchNorm2d: 3-89            [8, 512, 16, 16]          1,024
│    │    └─ReLU: 3-90                   [8, 512, 16, 16]          --
│    │    └─Conv2d: 3-91                 [8, 512, 16, 16]          2,359,296
│    │    └─BatchNorm2d: 3-92            [8, 512, 16, 16]          1,024
│    │    └─ReLU: 3-93                   [8, 512, 16, 16]          --
│    └─BasicBlock: 2-16                  [8, 512, 16, 16]          --
│    │    └─Conv2d: 3-94                 [8, 512, 16, 16]          2,359,296
│    │    └─BatchNorm2d: 3-95            [8, 512, 16, 16]          1,024
│    │    └─ReLU: 3-96                   [8, 512, 16, 16]          --
│    │    └─Conv2d: 3-97                 [8, 512, 16, 16]          2,359,296
│    │    └─BatchNorm2d: 3-98            [8, 512, 16, 16]          1,024
│    │    └─ReLU: 3-99                   [8, 512, 16, 16]          --
├─ModuleList: 1                          --                        --
│    └─GradualStyleBlock: 2-17           [8, 1024]                 --
│    │    └─Sequential: 3-100            [8, 1024, 1, 1]           33,034,240
│    │    └─EqualLinear: 3-101           [8, 1024]                 1,049,600
│    └─GradualStyleBlock: 2-18           [8, 1024]                 --
│    │    └─Sequential: 3-102            [8, 1024, 1, 1]           33,034,240
│    │    └─EqualLinear: 3-103           [8, 1024]                 1,049,600
│    └─GradualStyleBlock: 2-19           [8, 1024]                 --
│    │    └─Sequential: 3-104            [8, 1024, 1, 1]           33,034,240
│    │    └─EqualLinear: 3-105           [8, 1024]                 1,049,600
│    └─GradualStyleBlock: 2-20           [8, 1024]                 --
│    │    └─Sequential: 3-106            [8, 1024, 1, 1]           33,034,240
│    │    └─EqualLinear: 3-107           [8, 1024]                 1,049,600
│    └─GradualStyleBlock: 2-21           [8, 1024]                 --
│    │    └─Sequential: 3-108            [8, 1024, 1, 1]           33,034,240
│    │    └─EqualLinear: 3-109           [8, 1024]                 1,049,600
│    └─GradualStyleBlock: 2-22           [8, 1024]                 --
│    │    └─Sequential: 3-110            [8, 1024, 1, 1]           33,034,240
│    │    └─EqualLinear: 3-111           [8, 1024]                 1,049,600
│    └─GradualStyleBlock: 2-23           [8, 1024]                 --
│    │    └─Sequential: 3-112            [8, 1024, 1, 1]           33,034,240
│    │    └─EqualLinear: 3-113           [8, 1024]                 1,049,600
│    └─GradualStyleBlock: 2-24           [8, 1024]                 --
│    │    └─Sequential: 3-114            [8, 1024, 1, 1]           33,034,240
│    │    └─EqualLinear: 3-115           [8, 1024]                 1,049,600
│    └─GradualStyleBlock: 2-25           [8, 1024]                 --
│    │    └─Sequential: 3-116            [8, 1024, 1, 1]           33,034,240
│    │    └─EqualLinear: 3-117           [8, 1024]                 1,049,600
│    └─GradualStyleBlock: 2-26           [8, 1024]                 --
│    │    └─Sequential: 3-118            [8, 1024, 1, 1]           33,034,240
│    │    └─EqualLinear: 3-119           [8, 1024]                 1,049,600
│    └─GradualStyleBlock: 2-27           [8, 1024]                 --
│    │    └─Sequential: 3-120            [8, 1024, 1, 1]           33,034,240
│    │    └─EqualLinear: 3-121           [8, 1024]                 1,049,600
│    └─GradualStyleBlock: 2-28           [8, 1024]                 --
│    │    └─Sequential: 3-122            [8, 1024, 1, 1]           33,034,240
│    │    └─EqualLinear: 3-123           [8, 1024]                 1,049,600
│    └─GradualStyleBlock: 2-29           [8, 1024]                 --
│    │    └─Sequential: 3-124            [8, 1024, 1, 1]           33,034,240
│    │    └─EqualLinear: 3-125           [8, 1024]                 1,049,600
│    └─GradualStyleBlock: 2-30           [8, 1024]                 --
│    │    └─Sequential: 3-126            [8, 1024, 1, 1]           33,034,240
│    │    └─EqualLinear: 3-127           [8, 1024]                 1,049,600
│    └─GradualStyleBlock: 2-31           [8, 1024]                 --
│    │    └─Sequential: 3-128            [8, 1024, 1, 1]           33,034,240
│    │    └─EqualLinear: 3-129           [8, 1024]                 1,049,600
│    └─GradualStyleBlock: 2-32           [8, 1024]                 --
│    │    └─Sequential: 3-130            [8, 1024, 1, 1]           33,034,240
│    │    └─EqualLinear: 3-131           [8, 1024]                 1,049,600
==========================================================================================
Total params: 566,635,584
Trainable params: 566,635,584
Non-trainable params: 0
Total mult-adds (G): 214.81
==========================================================================================
Input size (MB): 12.58
Forward/backward pass size (MB): 2254.44
Params size (MB): 2266.54
Estimated Total Size (MB): 4533.56
==========================================================================================
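The summary's numbers can be sanity-checked with a little arithmetic. The 16 GradualStyleBlocks account for 545,341,440 of the 566,635,584 parameters, and at 4 bytes per fp32 value the weights alone take about 2.27 GB, matching "Params size (MB): 2266.54". The sketch below extends this to training memory, assuming fp32 gradients plus an Adam-style optimizer that stores two extra fp32 values per parameter. The optimizer actually configured, the generator, the loss networks, and the activations all add to this, so treat the result as a rough lower bound for the encoder alone.

```python
BYTES_PER_FP32 = 4

total_params = 566_635_584  # "Total params" from the torchinfo summary
style_params = 545_341_440  # parameters of the 16 GradualStyleBlocks


def training_bytes(n_params, optimizer_states=2):
    # weights + gradients + optimizer state, all assumed to be fp32
    copies = 1 + 1 + optimizer_states
    return n_params * BYTES_PER_FP32 * copies


print(f"style blocks share: {style_params / total_params:.1%}")
print(f"weights only: {total_params * BYTES_PER_FP32 / 1e6:.2f} MB")
print(f"weights + grads + Adam-style state: {training_bytes(total_params) / 1e9:.2f} GB")
```

Under these assumptions the encoder needs roughly 9 GB before any activations, so running out of a 16 GB card even at batch size 2 is plausible.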

aravind598, Aug 12 '22 18:08

Yes, I believe you'll need a batch size of at least 2 for all the dimensions to work out. To reduce the number of parameters in GradualStyleBlock, you could try the following. Currently, each layer of the block is defined as:

[
    Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1), 
    nn.LeakyReLU()
]

You could try adding a max-pool layer after each conv-ReLU layer. Since each stage then downsamples by 4 instead of 2, you need only half as many layers. This may hurt the network's performance, but it will certainly reduce the number of parameters.
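Below is a minimal PyTorch sketch of that suggestion, under stated assumptions: PooledGradualStyleBlock is a hypothetical name, nn.Linear stands in for the repo's EqualLinear, and the input is the 16x16, 512-channel feature map from the summary above. Each stage is a stride-2 conv, LeakyReLU, then 2x2 max pooling, so a 16x16 input reaches 1x1 after two convolutions instead of four.

```python
import math

import torch
from torch import nn


class PooledGradualStyleBlock(nn.Module):
    """Hypothetical slimmer variant of GradualStyleBlock.

    Each stage is conv(stride 2) -> LeakyReLU -> MaxPool(2), shrinking the
    spatial size by 4 instead of 2 and halving the number of conv layers.
    nn.Linear stands in for pSp's EqualLinear.
    """

    def __init__(self, in_c, out_c, spatial):
        super().__init__()
        self.out_c = out_c
        num_halvings = int(math.log2(spatial))
        if 2 ** num_halvings != spatial or num_halvings % 2 != 0:
            raise ValueError("sketch assumes spatial is an even power of 2, e.g. 16")
        modules = []
        c = in_c
        for _ in range(num_halvings // 2):
            modules += [
                nn.Conv2d(c, out_c, kernel_size=3, stride=2, padding=1),
                nn.LeakyReLU(),
                nn.MaxPool2d(kernel_size=2),
            ]
            c = out_c
        self.convs = nn.Sequential(*modules)
        self.linear = nn.Linear(out_c, out_c)

    def forward(self, x):
        x = self.convs(x)           # [B, out_c, 1, 1]
        x = x.view(-1, self.out_c)  # [B, out_c]
        return self.linear(x)


block = PooledGradualStyleBlock(512, 1024, 16)
out = block(torch.randn(2, 512, 16, 16))
print(out.shape, sum(p.numel() for p in block.parameters()))
```

With in_c=512 and out_c=1024 this block has 15,207,424 parameters, compared with 34,083,840 per block in the torchinfo summary; whether the resulting quality is acceptable would have to be checked by training.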

yuval-alaluf, Aug 14 '22 06:08

OK, thanks a lot for your input. I will close this issue.

aravind598, Aug 15 '22 11:08