[feature proposal] U-Nets with pretrained torchvision backbones
I was thinking about extending torchvision with a U-Net builder for segmentation: it would take pretrained torchvision classification models as backbone architectures for the encoder path of the U-Net, and build a decoder on top of them using features from specified layers of the backbone model.
I already implemented this for ResNet, DenseNet and VGG models in a separate module: https://github.com/mkisantal/backboned-unet
Now I'm thinking about integrating it directly with torchvision. Do you think it would be a useful new feature?
It's not an addition to the available torchvision models in the traditional sense, as it just transforms the existing models and does not work out of the box without training. But it could make torchvision easier to use for segmentation problems.
Hi,
I'm currently working on adding segmentation and detection models into torchvision, so this is a very timely feature proposal.
Here is an open PR for segmentation models, including FCN and DeepLabV3 https://github.com/pytorch/vision/pull/820
I still need to do quite a bit of cleanup work there, but those models work and reproduce the expected accuracies to within 2-3%.
I'll have a closer look at your implementation to see if we could incorporate a few of the ideas there into the current setup.
My original plan was to first get those 2 models merged, and then expand on U-Net.
Thoughts?
@fmassa As far as I can see, our implementations are fairly similar. I didn't change the backbone architecture (e.g. by adding dilation as you did), so I was able to avoid subclassing the backbone model; as a result it also works with DenseNet and VGG backbones, not just ResNet.
Compared to FCN and DeepLabV3, U-Net is slightly different in a few respects (more upsampling steps, and more backbone features are fed to the "head"), but I don't think there are any fundamental differences. Adding U-Net after these two should be straightforward.
As I understand it, these reference scripts will be for reproducing the built-in models. So are you planning to release the models trained on some dataset (like Pascal, COCO or KITTI)?
Initially I just proposed to have a builder that converts classification models for segmentation, but of course having pretrained segmentation models is even better.
Btw, is `outplanes` a new attribute of `nn.Module`? I had to rely on inferring the number of channels of the output features with a forward run...
What about HRNet? It has been applied to a variety of tasks and already has a PyTorch implementation. See https://jingdongwang2017.github.io/Projects/HRNet/
@mkisantal We will be adding native support for dilation in resnet, so that we don't need to subclass it.
And indeed the reference scripts are meant to be used so that one can retrain (and reproduce) the models in the model zoo, which will include segmentation models as well. They will be trained on COCO / Pascal for now.
About `outplanes`: that was a quick addition I made to `ResNet` to make a few things simpler. I'm not yet clear on the best way of addressing this, though.
@bhack we will be first looking into FCN and DeepLabV3, and we might expand on other models afterwards, thanks for the suggestion though!
@fmassa Hi Francisco, I am reproducing the DeepLabV3+ model in PyTorch these days.
Hi! Any updates on U-Net and/or HRNet?
Hello, any updates on adding U-Net to torchvision? I'm interested.