
[Bug] Checkpoint fails with "with_cp=True" in mmdet swin-transformer distillation.

Open Mrxiaoyuer opened this issue 2 years ago • 3 comments

Describe the bug


[Hi developers, thanks for your efforts. I was trying to distill Swin Transformers and set with_cp=True to train with a larger batch size, but it raises the errors attached below.]
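For context: in the MMCV/MMDetection Swin implementation, with_cp=True wraps each block's forward in torch.utils.checkpoint, so intermediate activations are recomputed during backward instead of stored, which is what frees memory for the larger batch. A minimal, self-contained sketch of that mechanism (illustrative only, not mmrazor code):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    """A stand-in for one transformer block; dim is arbitrary here."""

    def __init__(self, dim=96, with_cp=True):
        super().__init__()
        self.with_cp = with_cp
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x):
        if self.with_cp and x.requires_grad:
            # Activations inside self.mlp are discarded in the forward
            # pass and recomputed during backward, trading extra compute
            # for lower peak memory.
            return checkpoint(self.mlp, x)
        return self.mlp(x)


x = torch.randn(8, 96, requires_grad=True)
Block()(x).sum().backward()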

To Reproduce

[python ./tools/mmdet/train_mmdet.py ./configs/distill/cwd/my_config.py]

Post related information

  1. The output of pip list | grep "mmcv\|mmrazor\|^torch" [here]
  2. Your config file if you modified it or created a new one.
[_base_ = [
    '../../_base_/datasets/mmdet/coco_detection.py',
    '../../_base_/schedules/mmdet/schedule_1x.py',
    '../../_base_/mmdet_runtime.py'
]

pretrained = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth'  # noqa

student = dict(
    type='mmdet.MaskRCNN',
    backbone=dict(
        type='SwinTransformer',
        embed_dims=96,
        depths=[2, 2, 18, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        mlp_ratio=4,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.,
        attn_drop_rate=0.,
        drop_path_rate=0.2,
        patch_norm=True,
        out_indices=(0, 1, 2, 3),
        with_cp=True,
        convert_weights=True,
        init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
    neck=dict(
        type='FPN',
        in_channels=[96, 192, 384, 768],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[8],
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
    roi_head=dict(
        type='StandardRoIHead',
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        bbox_head=dict(
            type='Shared2FCBBoxHead',
            in_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=80,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
        ),
    # model training and testing settings
    train_cfg=dict(
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=-1,
            pos_weight=-1,
            debug=False),
        rpn_proposal=dict(
            nms_pre=2000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.5,
                neg_iou_thr=0.5,
                min_pos_iou=0.5,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=512,
                pos_fraction=0.25,
                neg_pos_ub=-1,
                add_gt_as_proposals=True),
            mask_size=28,
            pos_weight=-1,
            debug=False)),
    test_cfg=dict(
        rpn=dict(
            nms_pre=1000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100,
            mask_thr_binary=0.5)))

teacher_ckpt = 'https://github.com/Mrxiaoyuer/coco_checkpoints/releases/download/v0.0.2/swinl_epoch_24.pth'  # noqa

teacher = dict(
    type='mmdet.MaskRCNN',
    init_cfg=dict(type='Pretrained', checkpoint=teacher_ckpt),
    backbone=dict(
        type='SwinTransformer',
        embed_dims=192,
        depths=[2, 2, 18, 2],
        num_heads=[6, 12, 24, 48],
        window_size=7,
        mlp_ratio=4,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.,
        attn_drop_rate=0.,
        drop_path_rate=0.2,
        patch_norm=True,
        out_indices=(0, 1, 2, 3),
        with_cp=False,
        convert_weights=True,
        init_cfg=None),
    neck=dict(
        type='FPN',
        in_channels=[192, 384, 768, 1536],
        out_channels=256,
        start_level=0,
        add_extra_convs='on_output',
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[8],
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
    roi_head=dict(
        type='StandardRoIHead',
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        bbox_head=dict(
            type='Shared2FCBBoxHead',
            in_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=80,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
        ),
    # model training and testing settings
    train_cfg=dict(
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=-1,
            pos_weight=-1,
            debug=False),
        rpn_proposal=dict(
            nms_pre=2000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.5,
                neg_iou_thr=0.5,
                min_pos_iou=0.5,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=512,
                pos_fraction=0.25,
                neg_pos_ub=-1,
                add_gt_as_proposals=True),
            mask_size=28,
            pos_weight=-1,
            debug=False)),
    test_cfg=dict(
        rpn=dict(
            nms_pre=1000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100,
            mask_thr_binary=0.5)))


# algorithm setting
algorithm = dict(
    type='GeneralDistill',
    architecture=dict(
        type='MMDetArchitecture',
        model=student,
    ),
    distiller=dict(
        type='SingleTeacherDistiller',
        teacher=teacher,
        teacher_trainable=False,
        components=[
            dict(
                student_module='neck.fpn_convs.3.conv',
                teacher_module='neck.fpn_convs.3.conv',
                losses=[
                    dict(
                        type='ChannelWiseDivergence',
                        name='loss_cwd_cls_head',
                        tau=1,
                        loss_weight=5,
                    )
                ])
        ]),
)

find_unused_parameters = True

# optimizer
optimizer = dict(
    _delete_=True,
    type='AdamW',
    lr=0.0001,
    betas=(0.9, 0.999),
    weight_decay=0.05,
    paramwise_cfg=dict(
        custom_keys={
            'absolute_pos_embed': dict(decay_mult=0.),
            'relative_position_bias_table': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.)
        }))
lr_config = dict(warmup_iters=1000, step=[8, 11])
runner = dict(max_epochs=12)

  3. Your train log file if you meet the problem during training.

[loading annotations into memory...
Done (t=51.46s)
creating index...
Done (t=51.98s)
creating index...
Done (t=51.93s)
creating index...
Done (t=51.97s)
creating index...
index created!
index created!
index created!
index created!
loading annotations into memory...
loading annotations into memory...
loading annotations into memory...
loading annotations into memory...
Done (t=1.01s)
creating index...
Done (t=1.24s)
creating index...
Done (t=1.25s)
creating index...
Done (t=1.02s)
creating index...
index created!
index created!
index created!
index created!
2022-08-08 20:44:14,703 - mmdet - INFO - Start running, host: root@da8f74b40bf146bf925ac3a4349ccef80000Z0, work_dir: /mnt/azureml/cr/j/89a4752b1bff40f9951561de40f847a6/cap/data-capability/wd/INPUT_3324fea54867408dbec17da2dfe673f2_data/fuxun/output/fuxun-mmrazor-cwd-swin-l-s_1x_adamw_2xbs/cwd_cls_head_swin-l-s_1x_coco_adamw_2xbs
2022-08-08 20:44:14,704 - mmdet - INFO - Hooks will be executed in the following order:

before_run: (VERY_HIGH ) StepLrUpdaterHook
    (ABOVE_NORMAL) Fp16OptimizerHook
    (NORMAL ) CheckpointHook
    (LOW ) DistEvalHook
    (VERY_LOW ) TextLoggerHook

before_train_epoch: (VERY_HIGH ) StepLrUpdaterHook
(NORMAL ) NumClassCheckHook
(NORMAL ) DistSamplerSeedHook
(LOW ) IterTimerHook
(LOW ) DistEvalHook
(VERY_LOW ) TextLoggerHook

before_train_iter: (VERY_HIGH ) StepLrUpdaterHook
(LOW ) IterTimerHook
(LOW ) DistEvalHook

after_train_iter: (ABOVE_NORMAL) Fp16OptimizerHook
(NORMAL ) CheckpointHook
(LOW ) IterTimerHook
(LOW ) DistEvalHook
(VERY_LOW ) TextLoggerHook

after_train_epoch: (NORMAL ) CheckpointHook
(LOW ) DistEvalHook
(VERY_LOW ) TextLoggerHook

before_val_epoch: (NORMAL ) NumClassCheckHook
(NORMAL ) DistSamplerSeedHook
(LOW ) IterTimerHook
(VERY_LOW ) TextLoggerHook

before_val_iter: (LOW ) IterTimerHook

after_val_iter: (LOW ) IterTimerHook

after_val_epoch: (VERY_LOW ) TextLoggerHook

after_run: (VERY_LOW ) TextLoggerHook

2022-08-08 20:44:14,704 - mmdet - INFO - workflow: [('train', 1)], max: 12 epochs
2022-08-08 20:44:14,710 - mmdet - INFO - Checkpoints will be saved to /mnt/azureml/cr/j/89a4752b1bff40f9951561de40f847a6/cap/data-capability/wd/INPUT_3324fea54867408dbec17da2dfe673f2_data/fuxun/output/fuxun-mmrazor-cwd-swin-l-s_1x_adamw_2xbs/cwd_cls_head_swin-l-s_1x_coco_adamw_2xbs by HardDiskBackend.
/opt/conda/lib/python3.7/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /opt/conda/conda-bld/pytorch_1623448265233/work/c10/core/TensorImpl.h:1156.)
  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
(the UserWarning above is repeated once per worker; three duplicates omitted)
Traceback (most recent call last):
  File "tools/mmdet/train_mmdet.py", line 230, in <module>
    main()
  File "tools/mmdet/train_mmdet.py", line 226, in main
    meta=meta)
  File "/tmp/code/mmrazor/apis/mmdet/train.py", line 206, in train_mmdet_model
    runner.run(data_loader, cfg.workflow)
  File "/opt/conda/lib/python3.7/site-packages/mmcv/runner/epoch_based_runner.py", line 127, in run
    epoch_runner(data_loaders[i], **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/mmcv/runner/epoch_based_runner.py", line 51, in train
    self.call_hook('after_train_iter')
  File "/opt/conda/lib/python3.7/site-packages/mmcv/runner/base_runner.py", line 309, in call_hook
    getattr(hook, fn_name)(self)
  File "/opt/conda/lib/python3.7/site-packages/mmcv/runner/hooks/optimizer.py", line 272, in after_train_iter
    self.loss_scaler.scale(runner.outputs['loss']).backward()
  File "/opt/conda/lib/python3.7/site-packages/torch/_tensor.py", line 255, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/opt/conda/lib/python3.7/site-packages/torch/autograd/__init__.py", line 149, in backward
    allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
  File "/opt/conda/lib/python3.7/site-packages/torch/autograd/function.py", line 87, in apply
    return self._forward_cls.backward(self, *args)  # type: ignore[attr-defined]
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/checkpoint.py", line 138, in backward
    torch.autograd.backward(outputs_with_grad, args_with_grad)
  File "/opt/conda/lib/python3.7/site-packages/torch/autograd/__init__.py", line 149, in backward
    allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons: 1) Use of a module parameter outside the forward function. Please make sure model parameters are not shared across multiple concurrent forward-backward passes, or try to use _set_static_graph() as a workaround if this module graph does not change during training loop. 2) Reused parameters in multiple reentrant backward passes. For example, if you use multiple checkpoint functions to wrap the same part of your model, it would result in the same set of parameters been used by different reentrant backward passes multiple times, and hence marking a variable ready multiple times. DDP does not support such use cases in default. You can try to use _set_static_graph() as a workaround if your module graph does not change over iterations.
Parameter at index 323 has been marked as ready twice. This means that multiple autograd engine hooks have fired for this particular parameter during this iteration. You can set the environment variable TORCH_DISTRIBUTED_DEBUG to either INFO or DETAIL to print parameter names for further debugging.]

  4. Other code you modified in the mmrazor folder. [here]

Additional context


[(1) I tried adding with_cp=True in the r101_r50 CWD distillation and it works fine, so this is probably related to the Swin Transformer structure. (2) I also tried with_cp=True when training a Swin Transformer alone in the mmdetection repo, and it works fine there too.]

Mrxiaoyuer avatar Aug 08 '22 21:08 Mrxiaoyuer

Hi, Mrxiaoyuer. Thanks for your attention.

This is probably because torch.utils.checkpoint cannot be used when find_unused_parameters is set to True.
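For reference, this incompatibility is reproducible outside mmrazor: with find_unused_parameters=True, DDP's reducer hooks fire again inside the reentrant backward that torch.utils.checkpoint creates. A minimal sketch of the failing pattern (launched with a distributed launcher; not tested in your exact environment):

# repro.py -- minimal sketch; launch with e.g.
#   python -m torch.distributed.launch --nproc_per_node=1 repro.py
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.checkpoint import checkpoint


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 4)

    def forward(self, x):
        # Same pattern as with_cp=True: a reentrant checkpoint inside
        # the DDP forward pass.
        return checkpoint(self.layer, x)


dist.init_process_group('gloo')
model = DDP(Net(), find_unused_parameters=True)
out = model(torch.randn(2, 4, requires_grad=True))
# The recomputed backward fires DDP's autograd hooks a second time, so a
# parameter is "marked ready" twice:
#   RuntimeError: Expected to mark a variable ready only once.
out.sum().backward()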

wutongshenqiu avatar Aug 09 '22 05:08 wutongshenqiu

Hi @wutongshenqiu, Thanks!

I also found this suggestion and tried turning it off by setting find_unused_parameters = False. However, that does not work for distillation, because the teacher is frozen and its parameters never receive gradients; with find_unused_parameters = False, DDP then reports errors on those unused parameters.
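One direction I am considering (just a sketch, not an existing mmrazor option): DDP only registers gradient hooks for parameters with requires_grad=True, so explicitly freezing every teacher parameter before the DDP wrap should keep DDP from expecting teacher gradients at all, which might make find_unused_parameters=False workable. The algorithm.distiller.teacher path below is assumed from the config above, not a verified mmrazor attribute:

# Hypothetical: run before the model is wrapped in DistributedDataParallel.
for p in algorithm.distiller.teacher.parameters():
    p.requires_grad_(False)      # DDP skips params that can never get grads
algorithm.distiller.teacher.eval()  # also freeze BN stats / disable dropout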

Mrxiaoyuer avatar Aug 09 '22 17:08 Mrxiaoyuer

Indeed, it is a tricky problem. Sorry, we don't have a good solution right now either. We will add this requirement to the backlog, but we are not sure when we will be able to solve it.
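One thing worth noting: the traceback itself proposes _set_static_graph() for module graphs that do not change across iterations. It is a private DDP API (PyTorch >= 1.9), so the following is only a sketch of where the call would go (model and local_rank are placeholder names), not a verified fix for the distillation case:

from torch.nn.parallel import DistributedDataParallel as DDP

ddp_model = DDP(model, device_ids=[local_rank], find_unused_parameters=True)
# Private API: declares that the set of used/unused parameters is the
# same every iteration, letting the reducer tolerate the reentrant
# backward produced by torch.utils.checkpoint.
ddp_model._set_static_graph()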

If you have any progress on it, any help or advice is very welcome!

wutongshenqiu avatar Aug 10 '22 01:08 wutongshenqiu