pytorch-semseg
is_deconv=False seems to break UNet model?
I appended this snippet to the UNet model file (unet.py) to try bilinear upsampling (is_deconv=False) instead of transposed convolutions:
if __name__ == "__main__":
    # Dummy batch: 5 RGB images of size 300x300
    x = torch.randn([5, 3, 300, 300])
    # UNet with bilinear upsampling in the decoder
    model = unet(is_deconv=False)
    y = model(x)
The forward pass fails with this error:
  File "/home/sreenivas/sandbox/UNet/pytorch-semseg/ptsemseg/models/unet.py", line 70, in forward
    up4 = self.up_concat4(conv4, center)
  File "/home/sreenivas/.envs/thesis/local/lib/python2.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "utils.py", line 219, in forward
    return self.conv(torch.cat([outputs1, outputs2], 1))
  File "/home/sreenivas/.envs/thesis/local/lib/python2.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "utils.py", line 200, in forward
    outputs = self.conv1(inputs)
  File "/home/sreenivas/.envs/thesis/local/lib/python2.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/sreenivas/.envs/thesis/local/lib/python2.7/site-packages/torch/nn/modules/container.py", line 91, in forward
    input = module(input)
  File "/home/sreenivas/.envs/thesis/local/lib/python2.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/sreenivas/.envs/thesis/local/lib/python2.7/site-packages/torch/nn/modules/conv.py", line 301, in forward
    self.padding, self.dilation, self.groups)
RuntimeError: Given groups=1, weight[128, 256, 3, 3], so expected input[5, 384, 22, 22] to have 256 channels, but got 384 channels instead
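The channel counts in this error can be reproduced with plain arithmetic. This is a minimal sketch that assumes unet's default feature_scale=4 and a base filter list of [64, 128, 256, 512, 1024] (consistent with the weight[128, 256, 3, 3] in the message); the helper name up_concat4_channels is hypothetical and not part of the repo.

```python
# Assumed base filter widths; unet divides them by feature_scale.
BASE_FILTERS = [64, 128, 256, 512, 1024]


def up_concat4_channels(is_deconv, feature_scale=4):
    """Hypothetical helper: return (channels fed to self.conv,
    channels self.conv expects) for the up_concat4 block."""
    filters = [f // feature_scale for f in BASE_FILTERS]  # [16, 32, 64, 128, 256]
    in_size, out_size = filters[4], filters[3]  # center -> 256, conv4 -> 128
    skip_channels = filters[3]                  # skip connection from conv4
    # A transposed conv maps in_size -> out_size channels;
    # bilinear upsampling leaves the channel count at in_size.
    upsampled = out_size if is_deconv else in_size
    return upsampled + skip_channels, in_size


print(up_concat4_channels(is_deconv=True))   # (256, 256)
print(up_concat4_channels(is_deconv=False))  # (384, 256)
```

With is_deconv=False the concatenation carries 256 + 128 = 384 channels into a convolution built for 256, which matches the RuntimeError.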
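Before concatenation, the upsampled decoder features (center) have 256 channels and the skip connection (conv4) has 128, so torch.cat produces 384 channels while self.conv expects 256. Bilinear upsampling keeps all 256 channels of center, so it seems a 1x1 convolution should precede or follow the bilinear interpolation to bring the decoder features down to 128 channels. I can submit a patch, but which order is preferred?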
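One possible shape for such a patch, as a minimal sketch rather than the repo's code: the class name UnetUpBilinear, the helper upsample2x, and the single 3x3 conv standing in for unetConv2 (with padding=1) are all assumptions. It also assumes the skip connection has already been cropped to the upsampled size (the repo's unetUp pads/crops it first) and a PyTorch version that provides F.interpolate.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def upsample2x(t):
    """Bilinear upsampling by a factor of 2 (channel count unchanged)."""
    return F.interpolate(t, scale_factor=2, mode="bilinear", align_corners=False)


class UnetUpBilinear(nn.Module):
    """Hypothetical is_deconv=False variant of unetUp.

    A 1x1 conv reduces the decoder features from in_size to out_size
    channels before bilinear upsampling, so concatenating with the
    out_size-channel skip connection gives 2 * out_size channels.
    """

    def __init__(self, in_size, out_size):
        super(UnetUpBilinear, self).__init__()  # Python 2 and 3 compatible
        self.reduce = nn.Conv2d(in_size, out_size, kernel_size=1)
        # Stand-in for unetConv2 (assumption: one 3x3 conv, padding=1).
        self.conv = nn.Sequential(
            nn.Conv2d(2 * out_size, out_size, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, skip, decoder):
        # Reduce channels first, then upsample: 1x1 conv runs on 4x fewer pixels.
        up = upsample2x(self.reduce(decoder))
        # Skip connection first, matching torch.cat([outputs1, outputs2], 1).
        return self.conv(torch.cat([skip, up], dim=1))


if __name__ == "__main__":
    block = UnetUpBilinear(256, 128)
    skip = torch.randn(5, 128, 22, 22)    # conv4, already cropped to 22x22
    center = torch.randn(5, 256, 11, 11)  # decoder input before upsampling
    print(block(skip, center).shape)      # torch.Size([5, 128, 22, 22])

    # Both orders of 1x1 conv and bilinear upsampling agree numerically.
    print(torch.allclose(upsample2x(block.reduce(center)),
                         block.reduce(upsample2x(center)), atol=1e-5))  # True
```

On the ordering question: a 1x1 convolution mixes channels at each pixel, and bilinear interpolation takes per-channel weighted averages whose weights sum to one, so the two operations commute (bias included) and either order gives the same output up to floating-point error. Putting the 1x1 conv before the upsampling is simply cheaper.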
Thanks for this awesome repo!