revisiting-sepconv
Training procedure
Hi, I was trying to retrain your model as I would like to conduct some experiments on it; however, I have a few questions about the training process of SepConv++.
For training:
- You train using batch size 16, Adamax optimizer with lr=0.001, which is halved at epochs 60 and 80, and train on Vimeo without any validation set.
- In the paper you said you use relu1_2 of VGG, but you didn't specify which one: is it VGG16 or VGG19?
- I used VGG19 and defined it as follows:
# first four layers of VGG19 (conv1_1, relu, conv1_2, relu), i.e. up to relu1_2
vgg19 = torchvision.models.vgg19(pretrained=True)
self.vgg = torch.nn.Sequential(*list(vgg19.children())[0][:4])
# freeze the VGG so it only serves as a fixed feature extractor
for param in self.vgg.parameters():
    param.requires_grad = False
and to get the contextual output:
tenOut2 = sepconv.sepconv_func.apply(self.vgg(tenOne), tenVerone, tenHorone) + sepconv.sepconv_func.apply(self.vgg(tenTwo), tenVertwo, tenHortwo)
This gives a 64-channel output. Is this correct? When using the VGG, you can't add the additional channel to the input to make it 4 channels, so I proceed to apply the contextual loss to the 3-channel input images.
- Do you normalize this output as well? If you do, you get 63 channels unless you remove the [:, -1:, :, :] step. So here's what I did to normalize without losing that channel:
tenNormalize = tenOut2
tenNormalize[tenNormalize.abs() < 0.01] = 1.0
tenOut2 = tenOut2 / tenNormalize
- When you normalize the output, is this done only during testing or also during training?
- The contextual loss is L1(output,gt) + 0.1*L1(contextual_output, contextual_gt)?
Your help would be appreciated on this, as the model isn't converging, which leads me to believe I've made some sort of mistake somewhere.
You train using batch size 16, Adamax optimizer with lr=0.001
I don't believe we outlined this in our paper. Anyways, we used a batch size of 8, but whether you use 8 or 16 shouldn't matter too much. As for the learning rate, try a few and see what works best for you.
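As a concrete starting point, the setup you describe would look roughly like the following sketch. The Adamax optimizer with lr=0.001 halved at epochs 60 and 80 is your assumption rather than something we specified, and network, loader, criterion, and the epoch count are placeholders:

import torch

# assumed schedule: Adamax at lr=1e-3, halved at epochs 60 and 80; worth sweeping as suggested above
optimizer = torch.optim.Adamax(network.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 80], gamma=0.5)

for intEpoch in range(100):
    for tenOne, tenTwo, tenTruth in loader:
        optimizer.zero_grad()
        criterion(network(tenOne, tenTwo), tenTruth).backward()
        optimizer.step()

    scheduler.step()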
In the paper you said you use relu1_2 of VGG, but you didn't specify which one: is it VGG16 or VGG19?
We use torchvision.models.vgg16_bn(pretrained=True).features[0:6], but I doubt it makes much of a difference.
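Spelled out as a frozen loss network, that could look like the following; the eval mode and parameter freezing are typical usage for a fixed feature extractor, not necessarily exactly our training code:

import torchvision

# relu1_2 features of VGG16 with batch norm: conv1_1, bn, relu, conv1_2, bn, relu
netVgg = torchvision.models.vgg16_bn(pretrained=True).features[0:6].eval()

for param in netVgg.parameters():
    param.requires_grad = False  # the VGG is only used as a fixed loss network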
When using the VGG, you can't add the additional channel to the input to make it 4 channels, so I proceed to apply the contextual loss to the 3-channel input images.
You can still use normalization. Just like with appending an auxiliary channel when operating on a 3-channel RGB input, you can add an auxiliary channel when operating on a 64-channel VGG input. That is roughly what we do:
# extract relu1_2 features of both input frames
tenVggone = vgg16(tenOne)
tenVggtwo = vgg16(tenTwo)

# append an auxiliary channel of ones so the warped output can be normalized afterwards
tenVggone = torch.cat([tenVggone, tenVggone.new_ones([tenVggone.shape[0], 1, tenVggone.shape[2], tenVggone.shape[3]])], 1).detach()
tenVggtwo = torch.cat([tenVggtwo, tenVggtwo.new_ones([tenVggtwo.shape[0], 1, tenVggtwo.shape[2], tenVggtwo.shape[3]])], 1).detach()

# apply the estimated separable kernels to the (now 65-channel) feature tensors
tenVggestimate = adasep(tenVggone, tenVerone, tenHorone) + adasep(tenVggtwo, tenVertwo, tenHortwo)

# kernel normalization: divide by the warped auxiliary channel, guarding against division by zero
tenNormalize = tenVggestimate[:, -1:, :, :]
tenNormalize[tenNormalize.abs() < 0.01] = 1.0
tenVggestimate = tenVggestimate[:, :-1, :, :] / tenNormalize

# contextual loss against the VGG features of the ground truth (tenVggtruth, computed analogously)
torch.nn.functional.l1_loss(input=tenVggestimate, target=tenVggtruth, reduction='mean')
When you normalize the output, is this done only during testing or also during training?
Apologies, but I am afraid that I don't understand the question. If it is referring to "3.4 Kernel Normalization" then yeah, that is done during training and inference.
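For reference, the same pattern on the 3-channel RGB path would look roughly like the sketch below; it mirrors the VGG snippet above and reuses the same adasep call, but is not copied verbatim from our code:

# append a channel of ones to each 3-channel input so the output can be normalized
tenOne = torch.cat([tenOne, tenOne.new_ones([tenOne.shape[0], 1, tenOne.shape[2], tenOne.shape[3]])], 1)
tenTwo = torch.cat([tenTwo, tenTwo.new_ones([tenTwo.shape[0], 1, tenTwo.shape[2], tenTwo.shape[3]])], 1)

tenOut = adasep(tenOne, tenVerone, tenHorone) + adasep(tenTwo, tenVertwo, tenHortwo)

# kernel normalization, used both at training and at inference time
tenNormalize = tenOut[:, -1:, :, :]
tenNormalize[tenNormalize.abs() < 0.01] = 1.0
tenOut = tenOut[:, :-1, :, :] / tenNormalize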
The contextual loss is L1(output,gt) + 0.1*L1(contextual_output, contextual_gt)?
Almost, we use a Laplacian pyramid loss between output and gt, but I doubt it makes too much of a difference.
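If it helps, one common way to write a Laplacian pyramid L1 loss looks roughly like this; it is a generic avg-pool based sketch rather than our exact implementation, and the number of levels is arbitrary:

import torch

def laplacian_pyramid_loss(tenOut, tenTruth, intLevels=5):
    # build pyramids with average pooling and compare the band-pass residuals at each level
    tenLoss = 0.0

    for intLevel in range(intLevels):
        tenOutDown = torch.nn.functional.avg_pool2d(tenOut, kernel_size=2)
        tenTruthDown = torch.nn.functional.avg_pool2d(tenTruth, kernel_size=2)

        # Laplacian band = current level minus the upsampled coarser level
        tenOutLap = tenOut - torch.nn.functional.interpolate(tenOutDown, size=tenOut.shape[2:], mode='bilinear', align_corners=False)
        tenTruthLap = tenTruth - torch.nn.functional.interpolate(tenTruthDown, size=tenTruth.shape[2:], mode='bilinear', align_corners=False)

        tenLoss = tenLoss + torch.nn.functional.l1_loss(input=tenOutLap, target=tenTruthLap, reduction='mean')

        tenOut, tenTruth = tenOutDown, tenTruthDown

    # low-frequency residual at the coarsest level
    return tenLoss + torch.nn.functional.l1_loss(input=tenOut, target=tenTruth, reduction='mean')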
the model isn't converging
Try without the contextual loss, find the learning rate that works best for you, and then introduce the contextual loss to make sure there is no issue with that.
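In code, that just means training with the reconstruction term alone first and only adding the weighted contextual term once the baseline is stable; boolUseContextual and the 0.1 weight from your question are placeholders, and laplacian_pyramid_loss refers to the sketch above:

# reconstruction loss on the RGB output
tenLoss = laplacian_pyramid_loss(tenOut, tenTruth)

# enable only once the baseline converges
if boolUseContextual:
    tenLoss = tenLoss + 0.1 * torch.nn.functional.l1_loss(input=tenVggestimate, target=tenVggtruth, reduction='mean')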
Yes, you are right, I assumed you followed a similar training procedure to SepConv (hence the batch size of 16). Will try out your suggestions and see how the network performs! Thanks for your help!
Closing this due to inactivity, feel free to reopen in case you have any other questions - thanks!