mmrazor icon indicating copy to clipboard operation
mmrazor copied to clipboard

How to migrate mgd from mmdet to mmseg

Open kaizhong2021 opened this issue 1 year ago • 1 comments

Describe the question you meet

[I want to migrate MGD from mmdet to mmseg. I imitated the CWD config, but distillation has almost no effect after the migration.]

Post related information

  1. The output of `pip list | grep "mmcv\|mmrazor\|^torch"`:

         mmcv          2.0.0
         mmcv-full     1.7.1
         mmrazor       1.0.0     /Users/chaikaizhong/work/mmrazor
         torch         1.13.1
         torchaudio    0.13.1
         torchvision   0.14.1
  2. Your config file if you modified it or created a new one.
# MGD distillation config built on the mmseg Cityscapes dataset and the
# 80k-iteration schedule. Teacher and student are referenced by config path.
_base_ = [
    'mmseg::_base_/datasets/cityscapes.py',
    'mmseg::_base_/schedules/schedule_80k.py',
    'mmseg::_base_/default_runtime.py',
]

# Teacher weights. This is a single URL, split only to keep lines short.
teacher_ckpt = (
    'https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/'
    'deeplabv3plus_r101-d8_512x1024_80k_cityscapes/'
    'deeplabv3plus_r101-d8_512x1024_80k_cityscapes_20200606_114143-068fcfe9.pth'
)
teacher_cfg_path = (
    'mmseg::deeplabv3plus/'
    'deeplabv3plus_r101-d8_4xb2-80k_cityscapes-512x1024.py'
)
student_cfg_path = 'mmseg::stdc/stdc2-fastscnn.py'

model = dict(
    _scope_='mmrazor',
    type='SingleTeacherDistill',
    architecture=dict(
        cfg_path=student_cfg_path,
        pretrained=False,
    ),
    teacher=dict(
        cfg_path=teacher_cfg_path,
        pretrained=False,
    ),
    teacher_ckpt=teacher_ckpt,
    distiller=dict(
        type='ConfigurableDistiller',
        # One loss, keyed `loss_mgd`; the same key selects its inputs in
        # `loss_forward_mappings` below.
        distill_losses=dict(
            loss_mgd=dict(type='MGDLoss', alpha_mgd=0.00002),
        ),
        # Student and teacher are both tapped at the same module path.
        student_recorders=dict(
            logits=dict(
                type='ModuleOutputs',
                source='decode_head.conv_seg',
            ),
        ),
        teacher_recorders=dict(
            logits=dict(
                type='ModuleOutputs',
                source='decode_head.conv_seg',
            ),
        ),
        # NOTE(review): no connector is attached to `preds_S` here, so the
        # recorded student output reaches the loss unchanged. Confirm whether
        # MGD's masked-generation step is expected to come from a connector.
        loss_forward_mappings=dict(
            loss_mgd=dict(
                preds_S=dict(from_student=True, recorder='logits'),
                preds_T=dict(from_student=False, recorder='logits'),
            ),
        ),
    ),
)

find_unused_parameters = True

# `_delete_=True` asks the config loader to replace the inherited `val_cfg`
# rather than merge into it.
val_cfg = dict(
    _delete_=True,
    type='mmrazor.SingleTeacherDistillValLoop',
)


# Feature-level distillation between two configs under `mmseg::stdc/`:
# five MGDLoss terms whose student inputs pass through MGDConnectors, plus
# one ChannelWiseDivergence entry.
_base_ = [
    'mmseg::_base_/datasets/cityscapes.py',
    'mmseg::_base_/schedules/schedule_80k.py',
    'mmseg::_base_/default_runtime.py',
]

teacher_ckpt = '/root/mmsegmentation/iter_160000.pth'
teacher_cfg_path = (
    'mmseg::stdc/'
    'fast_scnn_8xb4-160k_cityscapes-512x1024_stdc2pre_stdc2_in1k-pre.py'
)
student_cfg_path = (
    'mmseg::stdc/fast_scnn_8xb4-160k_cityscapes-512x1024_stdc.py'
)

model = dict(
    _scope_='mmrazor',
    type='SingleTeacherDistill',
    architecture=dict(
        cfg_path=student_cfg_path,
        pretrained=False,
    ),
    teacher=dict(
        cfg_path=teacher_cfg_path,
        pretrained=False,
    ),
    teacher_ckpt=teacher_ckpt,
    distiller=dict(
        type='ConfigurableDistiller',
        # NOTE(review): `loss_mgd_fpn5` has no entry in
        # `loss_forward_mappings` below, and nothing references recorder
        # `fpn5`. Confirm whether this loss was meant to be wired up.
        distill_losses=dict(
            loss_mgd_fpn0=dict(type='MGDLoss', alpha_mgd=0.00002),
            loss_mgd_fpn1=dict(type='MGDLoss', alpha_mgd=0.00002),
            loss_mgd_fpn2=dict(type='MGDLoss', alpha_mgd=0.00002),
            loss_mgd_fpn3=dict(type='MGDLoss', alpha_mgd=0.00002),
            loss_mgd_fpn4=dict(type='MGDLoss', alpha_mgd=0.00002),
            loss_mgd_fpn5=dict(
                type='ChannelWiseDivergence',
                tau=1,
                loss_weight=5,
            ),
        ),
        # Student and teacher use identical module paths for every recorder.
        # `fpn1` and `fpn5` point at the same module.
        student_recorders=dict(
            fpn0=dict(
                type='ModuleOutputs',
                source='backbone.arms.1.atten_conv_layer.1.conv',
            ),
            fpn1=dict(
                type='ModuleOutputs',
                source='backbone.conv_avg.conv',
            ),
            fpn2=dict(
                type='ModuleOutputs',
                source='backbone.ffm.attention.2.conv',
            ),
            fpn3=dict(
                type='ModuleOutputs',
                source='decode_head.conv_seg',
            ),
            fpn4=dict(
                type='ModuleOutputs',
                source='auxiliary_head.1.convs.0.conv',
            ),
            fpn5=dict(
                type='ModuleOutputs',
                source='backbone.conv_avg.conv',
            ),
        ),
        teacher_recorders=dict(
            fpn0=dict(
                type='ModuleOutputs',
                source='backbone.arms.1.atten_conv_layer.1.conv',
            ),
            fpn1=dict(
                type='ModuleOutputs',
                source='backbone.conv_avg.conv',
            ),
            fpn2=dict(
                type='ModuleOutputs',
                source='backbone.ffm.attention.2.conv',
            ),
            fpn3=dict(
                type='ModuleOutputs',
                source='decode_head.conv_seg',
            ),
            fpn4=dict(
                type='ModuleOutputs',
                source='auxiliary_head.1.convs.0.conv',
            ),
            fpn5=dict(
                type='ModuleOutputs',
                source='backbone.conv_avg.conv',
            ),
        ),
        # One connector per MGD loss, applied to the student side only.
        # NOTE(review): every connector declares teacher_channels=256,
        # including `s_fpn3_connector`, whose student side (the same
        # `decode_head.conv_seg` path) is declared as 19 channels. Verify
        # these against the real teacher layer widths.
        connectors=dict(
            s_fpn0_connector=dict(
                type='MGDConnector',
                student_channels=128,
                teacher_channels=256,
                lambda_mgd=0.65,
            ),
            s_fpn1_connector=dict(
                type='MGDConnector',
                student_channels=128,
                teacher_channels=256,
                lambda_mgd=0.65,
            ),
            s_fpn2_connector=dict(
                type='MGDConnector',
                student_channels=256,
                teacher_channels=256,
                lambda_mgd=0.65,
            ),
            s_fpn3_connector=dict(
                type='MGDConnector',
                student_channels=19,
                teacher_channels=256,
                lambda_mgd=0.65,
            ),
            s_fpn4_connector=dict(
                type='MGDConnector',
                student_channels=64,
                teacher_channels=256,
                lambda_mgd=0.65,
            ),
        ),
        # Inputs for each loss: the student record goes through its
        # connector, the teacher record is used as recorded.
        loss_forward_mappings=dict(
            loss_mgd_fpn0=dict(
                preds_S=dict(
                    from_student=True,
                    recorder='fpn0',
                    connector='s_fpn0_connector',
                ),
                preds_T=dict(from_student=False, recorder='fpn0'),
            ),
            loss_mgd_fpn1=dict(
                preds_S=dict(
                    from_student=True,
                    recorder='fpn1',
                    connector='s_fpn1_connector',
                ),
                preds_T=dict(from_student=False, recorder='fpn1'),
            ),
            loss_mgd_fpn2=dict(
                preds_S=dict(
                    from_student=True,
                    recorder='fpn2',
                    connector='s_fpn2_connector',
                ),
                preds_T=dict(from_student=False, recorder='fpn2'),
            ),
            loss_mgd_fpn3=dict(
                preds_S=dict(
                    from_student=True,
                    recorder='fpn3',
                    connector='s_fpn3_connector',
                ),
                preds_T=dict(from_student=False, recorder='fpn3'),
            ),
            loss_mgd_fpn4=dict(
                preds_S=dict(
                    from_student=True,
                    recorder='fpn4',
                    connector='s_fpn4_connector',
                ),
                preds_T=dict(from_student=False, recorder='fpn4'),
            ),
        ),
    ),
)

# Runtime and optimisation settings for the distillation run.
find_unused_parameters = True

# `_delete_=True` asks the config loader to replace the inherited `val_cfg`
# rather than merge into it.
val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')

# Legacy MMCV 1.x optimizer-hook key. It is kept so anything that still
# looks this name up keeps working, but MMEngine-based runners take gradient
# clipping from `optim_wrapper.clip_grad` (set below), not from here.
optimizer_config = dict(
    _delete_=True, grad_clip=dict(max_norm=35, norm_type=2))

# 500-iteration linear warm-up followed by a step decay.
# NOTE(review): the MultiStepLR entry is epoch-based (by_epoch=True, end=24,
# milestones in epochs), while this file does not set an epoch-based
# training loop. Confirm what loop the inherited schedule uses; if it is
# iteration-based, this scheduler would need iteration milestones to take
# effect.
param_scheduler = [
    dict(
        type='LinearLR',
        start_factor=0.001,
        by_epoch=False,
        begin=0,
        end=500),
    dict(
        type='MultiStepLR',
        begin=0,
        end=24,
        by_epoch=True,
        milestones=[16, 22],
        gamma=0.1),
]

# Gradient clipping lives on the optimizer wrapper so that it is actually
# applied. `_delete_=True` plus an explicit `type` replaces the inherited
# wrapper wholesale, so a base value such as `clip_grad=None` cannot collide
# with the dict given here.
optim_wrapper = dict(
    _delete_=True,
    type='OptimWrapper',
    optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001),
    clip_grad=dict(max_norm=35, norm_type=2))


kaizhong2021 avatar Aug 28 '23 08:08 kaizhong2021

image

kaizhong2021 avatar Aug 30 '23 12:08 kaizhong2021