Is there a way to change the backbone of DeepLabV3+ from ResNet to Swin Transformer?

Open francishcheng opened this issue 2 years ago • 4 comments

I wanted to do this but failed.

francishcheng, Jul 21 '22 05:07

I guess it is caused by the different shape of the feature maps output by the backbone. If you want to use a Swin Transformer backbone with the DeepLabV3+ decoder head, you need to modify the parameters in the config file so that the feature map shapes match what each module in the model expects.

MengzhangLI, Jul 21 '22 06:07

I don't think it's possible; the two differ too much.

My config file:

norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained='pretrain/swin_base_patch4_window7_224_mmseg.pth',
    backbone=dict(
        type='SwinTransformer',
        pretrain_img_size=224,
        embed_dims=128,
        patch_size=4,
        window_size=7,
        mlp_ratio=4,
        depths=[2, 2, 18, 2],
        num_heads=[4, 8, 16, 32],
        strides=(4, 2, 2, 2),
        out_indices=(0, 1, 2, 3),
        qkv_bias=True,
        qk_scale=None,
        patch_norm=True,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.3,
        use_abs_pos_embed=False,
        act_cfg=dict(type='GELU'),
        norm_cfg=dict(type='LN', requires_grad=True)),
    decode_head=dict(
        type='DepthwiseSeparableASPPHead',
        in_channels=2048,
        in_index=3,
        channels=512,
        dilations=(1, 12, 24, 36),
        c1_in_channels=256,
        c1_channels=48,
        dropout_ratio=0.1,
        num_classes=2,
        norm_cfg=dict(type='SyncBN', requires_grad=True),
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=1024,
        in_index=2,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=2,
        norm_cfg=dict(type='SyncBN', requires_grad=True),
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))

error:

  File "/home/u/Desktop/mmsegmentation/mmseg/models/decode_heads/decode_head.py", line 203, in forward_train
    seg_logits = self.forward(inputs)
  File "/home/u/Desktop/mmsegmentation/mmseg/models/decode_heads/sep_aspp_head.py", line 84, in forward
    self.image_pool(x),
  File "/home/u/miniconda3/envs/mmseg/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/u/miniconda3/envs/mmseg/lib/python3.8/site-packages/torch/nn/modules/container.py", line 141, in forward
    input = module(input)
  File "/home/u/miniconda3/envs/mmseg/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/u/miniconda3/envs/mmseg/lib/python3.8/site-packages/mmcv/cnn/bricks/conv_module.py", line 201, in forward
    x = self.conv(x)
  File "/home/u/miniconda3/envs/mmseg/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/u/miniconda3/envs/mmseg/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 447, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/u/miniconda3/envs/mmseg/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 443, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Given groups=1, weight of size [512, 2048, 1, 1], expected input[2, 1024, 1, 1] to have 2048 channels, but got 1024 channels instead

francishcheng, Jul 21 '22 06:07

It needs many modifications in the config file (and perhaps also in DepthwiseSeparableASPPHead itself). Since you insist on a Swin + DeepLabV3+ implementation, my advice is this: first check the shape of the backbone's output feature maps, and make sure you understand how DepthwiseSeparableASPPHead handles its input feature maps.
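As a rough illustration of that first check, the sketch below estimates the output shape of each Swin stage from the config values alone. It assumes that stage i outputs `embed_dims * 2**i` channels at a cumulative stride equal to the product of `strides[:i + 1]`, and that the input size is divisible by the total stride. The helper name `swin_stage_shapes` and the 512x512 input size are hypothetical, chosen only for this example.

```python
from math import prod


def swin_stage_shapes(embed_dims, strides, input_hw):
    """Estimate (channels, height, width) of each Swin stage output.

    Assumption: stage i outputs embed_dims * 2**i channels at a cumulative
    stride of prod(strides[:i + 1]). Padding is not modelled, so the result
    is only exact when the input size is divisible by the total stride.
    """
    height, width = input_hw
    shapes = []
    for i in range(len(strides)):
        total_stride = prod(strides[:i + 1])
        channels = embed_dims * 2 ** i
        shapes.append((channels, height // total_stride, width // total_stride))
    return shapes


# Values from the failing config above (embed_dims=128); the input size is assumed.
for index, shape in enumerate(swin_stage_shapes(128, (4, 2, 2, 2), (512, 512))):
    print(index, shape)
```

Under these assumptions the last stage (index 3) has 1024 channels, which matches the `got 1024 channels` part of the error message, while the head was configured with `in_channels=2048`. By the same estimate, stage 2 has 512 channels, so the auxiliary `FCNHead` with `in_index=2` would probably need `in_channels=512` rather than 1024.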

The output of ResNet50-D8, which is widely used for DeepLabV3+: [image]

The output of Swin-Tiny: [image]

From the screenshots, the number of channels, the width, and the height of the two outputs are either identical or differ by a factor of 2, so I think it may be possible. The key point is to ensure that the shape of the backbone's output feature maps can be handled by the customized decoder head.
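One way to check that condition before training is to compare the head's channel settings with the backbone's per-stage channels. The sketch below assumes, based on a reading of `sep_aspp_head.py`, that the head feeds `inputs[in_index]` to the ASPP branch (so `in_channels` must match that stage) and `inputs[0]` to the `c1` branch when `c1_in_channels > 0`. The function `check_sep_aspp_head` is a hypothetical helper, not part of mmsegmentation.

```python
def check_sep_aspp_head(stage_channels, head_cfg):
    """Return a list of channel mismatches between backbone and head.

    stage_channels: channels of each backbone output, in out_indices order.
    head_cfg: the decode_head dict (only the channel-related keys are read).
    """
    problems = []
    expected_in = stage_channels[head_cfg['in_index']]
    if head_cfg['in_channels'] != expected_in:
        problems.append(
            f"in_channels={head_cfg['in_channels']} but stage "
            f"{head_cfg['in_index']} outputs {expected_in} channels")
    # Assumption: the c1 branch always reads the first backbone output.
    c1_in = head_cfg.get('c1_in_channels', 0)
    if c1_in > 0 and c1_in != stage_channels[0]:
        problems.append(
            f"c1_in_channels={c1_in} but stage 0 outputs "
            f"{stage_channels[0]} channels")
    return problems


# Assumed Swin-B stage channels for embed_dims=128: [128, 256, 512, 1024].
swin_b_channels = [128 * 2 ** i for i in range(4)]
failing_head = dict(in_channels=2048, in_index=3, c1_in_channels=256)
adjusted_head = dict(in_channels=1024, in_index=3, c1_in_channels=128)

for problem in check_sep_aspp_head(swin_b_channels, failing_head):
    print(problem)
print(check_sep_aspp_head(swin_b_channels, adjusted_head))
```

With the values from the failing config, this reports both the `in_channels` and the `c1_in_channels` mismatch. Matching channels only removes the shape error; it says nothing about whether the resulting model trains well.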

MengzhangLI, Jul 21 '22 07:07

Hi @francishcheng, were you able to get it to work? Were you able to use any other backbone that gets higher accuracy than ResNet?

fschvart, Jul 28 '22 12:07

An example config:

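# Assumed definitions (not part of the original comment); the values
# mirror the norm settings used in the config posted earlier in this thread.
norm_cfg = dict(type='SyncBN', requires_grad=True)
backbone_norm_cfg = dict(type='LN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    data_preprocessor=dict(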
        type='SegDataPreProcessor',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True,
        pad_val=0,
        seg_pad_val=255,
        size=(512, 512)),
    backbone=dict(
        type='SwinTransformer',
        pretrain_img_size=224,
        embed_dims=96,
        patch_size=4,
        window_size=7,
        mlp_ratio=4,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        strides=(4, 2, 2, 2),
        out_indices=(0, 1, 2, 3),
        qkv_bias=True,
        qk_scale=None,
        patch_norm=True,
        drop_rate=0.,
        attn_drop_rate=0.,
        drop_path_rate=0.3,
        use_abs_pos_embed=False,
        act_cfg=dict(type='GELU'),
        norm_cfg=backbone_norm_cfg),
    decode_head=dict(
        type='DepthwiseSeparableASPPHead',
        in_channels=768,
        in_index=3,
        channels=512,
        dilations=(1, 12, 24, 36),
        c1_in_channels=96,
        c1_channels=48,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
    )
)

xiexinch, Mar 16 '23 11:03

Closing the issue, as there has been no activity for a while. We hope your issue has been resolved. If not, please feel free to open a new one.

xiexinch, Mar 16 '23 11:03