[Bug] Checkpoint fails with "with_cp=True" in mmdet swin-transformer distillation.
Describe the bug
Hi developers, thanks for your efforts. I was trying to distill Swin Transformers and set with_cp=True to train with a larger batch size, but it raises the errors attached below.
To Reproduce
python ./tools/mmdet/train_mmdet.py ./configs/distill/cwd/my_config.py
Post related information
1. The output of pip list | grep "mmcv\|mmrazor\|^torch"
[here]
2. Your config file if you modified it or created a new one.
_base_ = [
'../../_base_/datasets/mmdet/coco_detection.py',
'../../_base_/schedules/mmdet/schedule_1x.py',
'../../_base_/mmdet_runtime.py'
]
pretrained = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth' # noqa
student = dict(
type='mmdet.MaskRCNN',
backbone=dict(
type='SwinTransformer',
embed_dims=96,
depths=[2, 2, 18, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
mlp_ratio=4,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.2,
patch_norm=True,
out_indices=(0, 1, 2, 3),
with_cp=True,
convert_weights=True,
init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
neck=dict(
type='FPN',
in_channels=[96, 192, 384, 768],
out_channels=256,
num_outs=5),
rpn_head=dict(
type='RPNHead',
in_channels=256,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
scales=[8],
ratios=[0.5, 1.0, 2.0],
strides=[4, 8, 16, 32, 64]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0]),
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
roi_head=dict(
type='StandardRoIHead',
bbox_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
out_channels=256,
featmap_strides=[4, 8, 16, 32]),
bbox_head=dict(
type='Shared2FCBBoxHead',
in_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=80,
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[0., 0., 0., 0.],
target_stds=[0.1, 0.1, 0.2, 0.2]),
reg_class_agnostic=False,
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
),
# model training and testing settings
train_cfg=dict(
rpn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.7,
neg_iou_thr=0.3,
min_pos_iou=0.3,
match_low_quality=True,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False),
allowed_border=-1,
pos_weight=-1,
debug=False),
rpn_proposal=dict(
nms_pre=2000,
max_per_img=1000,
nms=dict(type='nms', iou_threshold=0.7),
min_bbox_size=0),
rcnn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=0.5,
match_low_quality=True,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True),
mask_size=28,
pos_weight=-1,
debug=False)),
test_cfg=dict(
rpn=dict(
nms_pre=1000,
max_per_img=1000,
nms=dict(type='nms', iou_threshold=0.7),
min_bbox_size=0),
rcnn=dict(
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.5),
max_per_img=100,
mask_thr_binary=0.5)))
teacher_ckpt = 'https://github.com/Mrxiaoyuer/coco_checkpoints/releases/download/v0.0.2/swinl_epoch_24.pth' # noqa
teacher = dict(
type='mmdet.MaskRCNN',
init_cfg=dict(type='Pretrained', checkpoint=teacher_ckpt),
backbone=dict(
type='SwinTransformer',
embed_dims=192,
depths=[2, 2, 18, 2],
num_heads=[6, 12, 24, 48],
window_size=7,
mlp_ratio=4,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.2,
patch_norm=True,
out_indices=(0, 1, 2, 3),
with_cp=False,
convert_weights=True,
init_cfg=None),
neck=dict(
type='FPN',
in_channels=[192, 384, 768, 1536],
out_channels=256,
start_level=0,
add_extra_convs='on_output',
num_outs=5),
rpn_head=dict(
type='RPNHead',
in_channels=256,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
scales=[8],
ratios=[0.5, 1.0, 2.0],
strides=[4, 8, 16, 32, 64]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0]),
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
roi_head=dict(
type='StandardRoIHead',
bbox_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
out_channels=256,
featmap_strides=[4, 8, 16, 32]),
bbox_head=dict(
type='Shared2FCBBoxHead',
in_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=80,
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[0., 0., 0., 0.],
target_stds=[0.1, 0.1, 0.2, 0.2]),
reg_class_agnostic=False,
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
),
# model training and testing settings
train_cfg=dict(
rpn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.7,
neg_iou_thr=0.3,
min_pos_iou=0.3,
match_low_quality=True,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False),
allowed_border=-1,
pos_weight=-1,
debug=False),
rpn_proposal=dict(
nms_pre=2000,
max_per_img=1000,
nms=dict(type='nms', iou_threshold=0.7),
min_bbox_size=0),
rcnn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=0.5,
match_low_quality=True,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True),
mask_size=28,
pos_weight=-1,
debug=False)),
test_cfg=dict(
rpn=dict(
nms_pre=1000,
max_per_img=1000,
nms=dict(type='nms', iou_threshold=0.7),
min_bbox_size=0),
rcnn=dict(
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.5),
max_per_img=100,
mask_thr_binary=0.5)))
# algorithm setting
algorithm = dict(
type='GeneralDistill',
architecture=dict(
type='MMDetArchitecture',
model=student,
),
distiller=dict(
type='SingleTeacherDistiller',
teacher=teacher,
teacher_trainable=False,
components=[
dict(
student_module='neck.fpn_convs.3.conv',
teacher_module='neck.fpn_convs.3.conv',
losses=[
dict(
type='ChannelWiseDivergence',
name='loss_cwd_cls_head',
tau=1,
loss_weight=5,
)
])
]),
)
find_unused_parameters = True
# optimizer
optimizer = dict(
_delete_=True,
type='AdamW',
lr=0.0001,
betas=(0.9, 0.999),
weight_decay=0.05,
paramwise_cfg=dict(
custom_keys={
'absolute_pos_embed': dict(decay_mult=0.),
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)
}))
lr_config = dict(warmup_iters=1000, step=[8, 11])
runner = dict(max_epochs=12)
3. Your train log file if you meet the problem during training.
loading annotations into memory...
Done (t=51.46s)
creating index...
Done (t=51.98s)
creating index...
Done (t=51.93s)
creating index...
Done (t=51.97s)
creating index...
index created!
index created!
index created!
index created!
loading annotations into memory...
loading annotations into memory...
loading annotations into memory...
loading annotations into memory...
Done (t=1.01s)
creating index...
Done (t=1.24s)
creating index...
Done (t=1.25s)
creating index...
Done (t=1.02s)
creating index...
index created!
index created!
index created!
index created!
2022-08-08 20:44:14,703 - mmdet - INFO - Start running, host: root@da8f74b40bf146bf925ac3a4349ccef80000Z0, work_dir: /mnt/azureml/cr/j/89a4752b1bff40f9951561de40f847a6/cap/data-capability/wd/INPUT_3324fea54867408dbec17da2dfe673f2_data/fuxun/output/fuxun-mmrazor-cwd-swin-l-s_1x_adamw_2xbs/cwd_cls_head_swin-l-s_1x_coco_adamw_2xbs
2022-08-08 20:44:14,704 - mmdet - INFO - Hooks will be executed in the following order:
before_run:
(VERY_HIGH ) StepLrUpdaterHook
(ABOVE_NORMAL) Fp16OptimizerHook
(NORMAL ) CheckpointHook
(LOW ) DistEvalHook
(VERY_LOW ) TextLoggerHook
before_train_epoch:
(VERY_HIGH ) StepLrUpdaterHook
(NORMAL ) NumClassCheckHook
(NORMAL ) DistSamplerSeedHook
(LOW ) IterTimerHook
(LOW ) DistEvalHook
(VERY_LOW ) TextLoggerHook
before_train_iter:
(VERY_HIGH ) StepLrUpdaterHook
(LOW ) IterTimerHook
(LOW ) DistEvalHook
after_train_iter:
(ABOVE_NORMAL) Fp16OptimizerHook
(NORMAL ) CheckpointHook
(LOW ) IterTimerHook
(LOW ) DistEvalHook
(VERY_LOW ) TextLoggerHook
after_train_epoch:
(NORMAL ) CheckpointHook
(LOW ) DistEvalHook
(VERY_LOW ) TextLoggerHook
before_val_epoch:
(NORMAL ) NumClassCheckHook
(NORMAL ) DistSamplerSeedHook
(LOW ) IterTimerHook
(VERY_LOW ) TextLoggerHook
before_val_iter: (LOW ) IterTimerHook
after_val_iter: (LOW ) IterTimerHook
after_val_epoch: (VERY_LOW ) TextLoggerHook
after_run: (VERY_LOW ) TextLoggerHook
2022-08-08 20:44:14,704 - mmdet - INFO - workflow: [('train', 1)], max: 12 epochs
2022-08-08 20:44:14,710 - mmdet - INFO - Checkpoints will be saved to /mnt/azureml/cr/j/89a4752b1bff40f9951561de40f847a6/cap/data-capability/wd/INPUT_3324fea54867408dbec17da2dfe673f2_data/fuxun/output/fuxun-mmrazor-cwd-swin-l-s_1x_adamw_2xbs/cwd_cls_head_swin-l-s_1x_coco_adamw_2xbs by HardDiskBackend.
/opt/conda/lib/python3.7/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /opt/conda/conda-bld/pytorch_1623448265233/work/c10/core/TensorImpl.h:1156.)
return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
(the same warning is printed once by each of the other worker processes)
Traceback (most recent call last):
File "tools/mmdet/train_mmdet.py", line 230, in forward
function. Please make sure model parameters are not shared across multiple concurrent forward-backward passes. or try to use _set_static_graph() as a workaround if this module graph does not change during training loop.2) Reused parameters in multiple reentrant backward passes. For example, if you use multiple checkpoint
functions to wrap the same part of your model, it would result in the same set of parameters been used by different reentrant backward passes multiple times, and hence marking a variable ready multiple times. DDP does not support such use cases in default. You can try to use _set_static_graph() as a workaround if your module graph does not change over iterations.
Parameter at index 323 has been marked as ready twice. This means that multiple autograd engine hooks have fired for this particular parameter during this iteration. You can set the environment variable TORCH_DISTRIBUTED_DEBUG to either INFO or DETAIL to print parameter names for further debugging.
4. Other code you modified in the mmrazor folder.
[here]
Additional context
(1) I tried adding with_cp=True in the r101_r50 CWD distillation and it works fine, so this is probably related to the Swin Transformer structure. (2) I also tried with_cp=True when training the Swin Transformer alone in the mmdetection repo, and it works fine as well.
Hi, Mrxiaoyuer. Thanks for your attention.
This is probably because torch.utils.checkpoint cannot be used when find_unused_parameters is set to True.
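For reference, the workaround that the DDP error message itself points at is declaring the graph static. Below is a minimal toy sketch (my own example, not mmrazor code) of a checkpointed block under DistributedDataParallel with _set_static_graph() applied. It assumes PyTorch >= 1.9 running as a single process; newer releases expose the same switch as DDP(..., static_graph=True). Whether this is safe for the full distillation setup would still need to be verified.

# Toy sketch (not mmrazor code): a checkpointed block under DDP with the
# static-graph workaround suggested by the error message.
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.checkpoint import checkpoint


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.stem = nn.Linear(8, 8)   # plain layer, not checkpointed
        self.block = nn.Linear(8, 8)  # stands in for a Swin block

    def forward(self, x):
        x = self.stem(x)
        # with_cp=True in mmcv roughly corresponds to wrapping the block
        # forward in torch.utils.checkpoint like this.
        return checkpoint(self.block, x)


def main():
    # Single-process "distributed" setup, for illustration only.
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('gloo', rank=0, world_size=1)

    model = DDP(Net(), find_unused_parameters=True)
    # Workaround from the error message: declare the autograd graph static
    # so DDP tolerates the re-entrant backward pass of checkpoint.
    # (_set_static_graph is a private API in 1.9/1.10; newer versions
    # accept DDP(..., static_graph=True) instead.)
    model._set_static_graph()

    model(torch.randn(4, 8)).sum().backward()
    dist.destroy_process_group()


if __name__ == '__main__':
    main()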
Hi @wutongshenqiu, thanks!
I also found this explanation and tried turning the flag off by setting find_unused_parameters = False. However, that does not work for the distillation case: the teacher is frozen, so some of its parameters never receive gradients, and with find_unused_parameters = False DDP reports errors about those unused parameters.
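One possible direction (a toy sketch, not mmrazor's API): if the teacher is frozen only in the sense of not being optimized, explicitly setting requires_grad=False on its parameters before the model is wrapped in DistributedDataParallel keeps DDP from registering them for gradient reduction at all, so the frozen teacher no longer counts as "unused" and find_unused_parameters=False may become viable. ToyDistill below is a stand-in for the real algorithm, and it assumes a process group is already initialized as in the previous sketch.

# Toy sketch: a frozen "teacher" submodule excluded from DDP's reducer.
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


class ToyDistill(nn.Module):
    def __init__(self):
        super().__init__()
        self.student = nn.Linear(8, 8)
        self.teacher = nn.Linear(8, 8)
        # Freeze the teacher *before* DDP wrapping, so DDP never registers
        # its parameters for gradient reduction.
        for p in self.teacher.parameters():
            p.requires_grad = False

    def forward(self, x):
        with torch.no_grad():
            t = self.teacher(x)
        s = self.student(x)
        # Stand-in for a distillation loss.
        return ((s - t) ** 2).mean()


model = DDP(ToyDistill(), find_unused_parameters=False)
model(torch.randn(4, 8)).backward()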
Indeed, it is a tricky problem. Sorry, we don't have a good solution right now either. We will add this requirement to the backlog, but we are not sure when we will be able to solve it.
If you make any progress on it, any help or advice is very welcome!