
Training procedure

Open issakh opened this issue 2 years ago • 2 comments

Hi, I was trying to retrain your model as I would like to conduct some experiments on it. However, I have a few questions about the training process of SepConv++.

For training:

  • You train using batch size 16 and the Adamax optimizer with lr=0.001, which is halved at epochs 60 and 80, and you train on Vimeo-90k without any validation set.
  • In the paper you say you use relu1_2 of VGG, but you don't specify which one; is it VGG16 or VGG19?
  • I used VGG19 and defined it as follows:

vgg19 = torchvision.models.vgg19(pretrained=True)
self.vgg = torch.nn.Sequential(*list(vgg19.children())[0][:4])
for param in self.vgg.parameters():
    param.requires_grad = False

and to get contextual output:

tenOut2 = sepconv.sepconv_func.apply(self.vgg(tenOne), tenVerone, tenHorone) + sepconv.sepconv_func.apply(self.vgg(tenTwo), tenVertwo, tenHortwo)

This gives a 64-channel output. Is this correct? When using the VGG, you can't add the additional channel to the input to make it 4 channels, so I apply the context on the 3-channel input images.

  • Do you normalize this output as well? If you do, you end up with 63 channels once you remove [:, -1:, :, :]. So here's what I did to normalize without losing that channel:

tenNormalize = tenOut2.clone()  # clone, otherwise the in-place edit below also changes tenOut2
tenNormalize[tenNormalize.abs() < 0.01] = 1.0
tenOut2 = tenOut2 / tenNormalize

  • When you normalize the output, is this done only during testing or also during training?

  • The contextual loss is L1(output,gt) + 0.1*L1(contextual_output, contextual_gt)?

Your help would be appreciated; the model isn't converging, which leads me to believe I've made a mistake somewhere.

issakh avatar May 17 '22 13:05 issakh

You train using batch size 16, Adamax optimizer with lr=0.001

I don't believe we outlined this in our paper. Anyway, we used a batch size of 8, but whether you use 8 or 16 shouldn't matter too much. As for the learning rate, try a few and see what works best for you.

In the paper you said you use relu1_2 of VGG, but you didn't specify which one, is it VGG16 or 19?

We use torchvision.models.vgg16_bn(pretrained=True).features[0:6], but I doubt it makes much of a difference.

When using the VGG, you can’t add the additional channel to the input to make it 4 channels, so I proceed to apply context on the 3 channel input images.

You can still use normalization. Just like with appending an auxiliary channel when operating on a 3-channel RGB input, you can add an auxiliary channel when operating on a 64-channel VGG input. That is roughly what we do:

tenVggone = vgg16(tenOne)
tenVggtwo = vgg16(tenTwo)

tenVggone = torch.cat([tenVggone, tenVggone.new_ones([tenVggone.shape[0], 1, tenVggone.shape[2], tenVggone.shape[3]])], 1).detach()
tenVggtwo = torch.cat([tenVggtwo, tenVggtwo.new_ones([tenVggtwo.shape[0], 1, tenVggtwo.shape[2], tenVggtwo.shape[3]])], 1).detach()

tenVggestimate = adasep(tenVggone, tenVerone, tenHorone) + adasep(tenVggtwo, tenVertwo, tenHortwo)

tenNormalize = tenVggestimate[:, -1:, :, :]
tenNormalize[tenNormalize.abs() < 0.01] = 1.0
tenVggestimate = tenVggestimate[:, :-1, :, :] / tenNormalize

torch.nn.functional.l1_loss(input=tenVggestimate, target=tenVggtruth, reduction='mean')

When you normalize the output, is this done only during testing or also during training?

Apologies, but I am afraid that I don't understand the question. If it is referring to "3.4 Kernel Normalization" then yeah, that is done during training and inference.

The contextual loss is L1(output,gt) + 0.1*L1(contextual_output, contextual_gt)?

Almost; we use a Laplacian pyramid loss between the output and the ground truth, but I doubt it makes too much of a difference.
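For anyone unfamiliar with it, a Laplacian pyramid L1 loss can be sketched as below. This is a generic illustration, not the authors' implementation; in particular, a proper version downsamples with a Gaussian filter, while average pooling is used here for brevity.

```python
import torch
import torch.nn.functional as F

def laplacian_pyramid_l1(tenOut, tenGt, intLevels=5):
    # sums an L1 loss over the band-pass levels of a Laplacian pyramid,
    # plus the low-frequency residual at the coarsest level
    tenLoss = torch.zeros([])
    for intLevel in range(intLevels):
        tenOutDown = F.avg_pool2d(tenOut, kernel_size=2)
        tenGtDown = F.avg_pool2d(tenGt, kernel_size=2)
        # Laplacian band = current level minus the upsampled coarser level
        tenOutLap = tenOut - F.interpolate(tenOutDown, size=tenOut.shape[2:], mode='bilinear', align_corners=False)
        tenGtLap = tenGt - F.interpolate(tenGtDown, size=tenGt.shape[2:], mode='bilinear', align_corners=False)
        tenLoss = tenLoss + F.l1_loss(tenOutLap, tenGtLap)
        tenOut, tenGt = tenOutDown, tenGtDown
    return tenLoss + F.l1_loss(tenOut, tenGt)
```

The band-pass decomposition weights fine detail more evenly than a plain L1 on the pixels, which is why such losses are popular for frame interpolation.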

the model isn't converging

Try without the contextual loss, find the learning rate that works best for you, and then introduce the contextual loss to make sure there is no issue with that.

sniklaus avatar May 19 '22 23:05 sniklaus

Yes, you are right; I assumed you followed a similar training procedure to SepConv (hence the batch size of 16). I will try out your suggestions and see how the network performs. Thanks for your help!

issakh avatar May 22 '22 15:05 issakh

Closing this due to inactivity, feel free to reopen in case you have any other questions - thanks!

sniklaus avatar Nov 07 '22 19:11 sniklaus