mmpretrain

[Feature] Switch aug hook

Open okotaku opened this issue 2 years ago • 9 comments

Motivation

Add a hook that stops augmentations partway through training. This work is based on the YOLOX mode switch hook in mmdet.

Data Augmentation Revisited: Rethinking the Distribution Gap between Clean and Augmented Data

Use cases (Optional)

_base_ = 'resnet18_8xb16_cifar10.py'

_deprecation_ = dict(
    expected='resnet18_8xb16_cifar10.py',
    reference='https://github.com/open-mmlab/mmclassification/pull/508',
)
model = dict(
    head=dict(loss=dict(type='CrossEntropyLoss', loss_weight=1.0, use_soft=True)),
    train_cfg=dict(
        augments=dict(type='BatchMixup', alpha=1., num_classes=10, prob=1.)))

custom_hooks = [
    dict(
        type='SwitchTrainAugHook',
        action_epoch=2,
        augments_cfg=None,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0)),
    dict(
        type='SwitchDataAugHook',
        action_epoch=2,
        pipeline=None,
        skip_type_keys=('AutoAugment', 'RandAugment'))
]
runner = dict(type='EpochBasedRunner', max_epochs=3)
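
As a rough illustration of the intended behaviour (not the code in this PR), the sketch below shows an epoch-triggered hook in plain Python. `ToyModel`, `ToyRunner` and `ToySwitchHook` are hypothetical stand-ins so that the example runs without mmcv; only the `before_train_epoch` callback name and the `action_epoch`, `augments_cfg` and `loss` options mirror the config above. Epochs are assumed to be counted from 0.

```python
class ToyModel:
    """Stand-in for a classifier that holds batch augments and a loss."""

    def __init__(self, augments, loss):
        self.augments = augments
        self.loss = loss


class ToyRunner:
    """Stand-in for an epoch-based runner: calls hooks before each epoch."""

    def __init__(self, model, hooks, max_epochs):
        self.model = model
        self.hooks = hooks
        self.max_epochs = max_epochs
        self.epoch = 0
        self.history = []  # (epoch, augments, loss) in effect for each epoch

    def run(self):
        for epoch in range(self.max_epochs):
            self.epoch = epoch
            for hook in self.hooks:
                hook.before_train_epoch(self)
            self.history.append(
                (self.epoch, self.model.augments, self.model.loss))


class ToySwitchHook:
    """Hypothetical hook: swap augments and loss once `action_epoch` is reached."""

    def __init__(self, action_epoch, augments_cfg, loss):
        self.action_epoch = action_epoch
        self.augments_cfg = augments_cfg
        self.loss = loss

    def before_train_epoch(self, runner):
        # `>=` rather than `==` so that a run resumed after `action_epoch`
        # still gets the switched settings (an assumption of this sketch).
        if runner.epoch >= self.action_epoch:
            runner.model.augments = self.augments_cfg
            runner.model.loss = self.loss


model = ToyModel(augments='BatchMixup', loss='soft CrossEntropyLoss')
hook = ToySwitchHook(action_epoch=2, augments_cfg=None, loss='CrossEntropyLoss')
runner = ToyRunner(model, [hook], max_epochs=3)
runner.run()
for record in runner.history:
    print(record)
```

With `action_epoch=2` and `max_epochs=3`, the first two epochs keep the mixup augment and the soft-label loss, and only the last epoch trains without them.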

Checklist

Before PR:

  • [x] Pre-commit or other linting tools are used to fix the potential lint issues.
  • [x] Bug fixes are fully covered by unit tests, the case that causes the bug should be added in the unit tests.
  • [x] The modification is covered by complete unit tests. If not, please add more unit test to ensure the correctness.
  • [x] The documentation has been modified accordingly, like docstring or example tutorials.

After PR:

  • [x] If the modification has potential influence on downstream or other related projects, this PR should be tested with those projects, like MMDet or MMSeg.
  • [x] CLA has been signed and all committers have signed the CLA in this PR.

okotaku, Mar 10 '22 02:03

Codecov Report

Merging #729 (9cd7d29) into dev (c1534f9) will decrease coverage by 0.08%. The diff coverage is 82.85%.

@@            Coverage Diff             @@
##              dev     #729      +/-   ##
==========================================
- Coverage   84.98%   84.89%   -0.09%     
==========================================
  Files         122      123       +1     
  Lines        7573     7642      +69     
  Branches     1304     1324      +20     
==========================================
+ Hits         6436     6488      +52     
- Misses        945      953       +8     
- Partials      192      201       +9     
Flag        Coverage Δ
unittests   84.85% <82.85%> (-0.05%)

Impacted Files                             Coverage Δ
mmcls/datasets/pipelines/compose.py        83.33% <80.00%> (-1.86%)
mmcls/core/hook/switch_augments_hook.py    83.05% <83.05%> (ø)
mmcls/core/hook/__init__.py                100.00% <100.00%> (ø)
mmcls/models/utils/helpers.py              87.50% <0.00%> (-12.50%)
mmcls/utils/setup_env.py                   95.45% <0.00%> (-4.55%)
mmcls/datasets/builder.py                  90.90% <0.00%> (-1.52%)

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Last update c1534f9...9cd7d29.

codecov[bot], Mar 10 '22 02:03

I notice that the paper stops data augmentations such as AutoAugment, and the hook in mmdet implements exactly this function. Maybe we can implement a 'stop_data_aug_hook'.

Ezra-Yu, Mar 11 '22 09:03

You are correct. However, the yolox hook is used together with a multi-image dataset wrapper. Specifically, the following part is required: https://github.com/open-mmlab/mmdetection/blob/master/mmdet/datasets/dataset_wrappers.py#L397-L415

It might be nice to have this kind of functionality in Compose. I will try to implement it, so please let me know if you have any suggestions.
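
One way this could look in Compose is sketched below. It is only a minimal, dependency-free illustration: `SkippableCompose`, its `update_skip_type_keys` method and the toy transforms are hypothetical names chosen for this example (modelled loosely on the skip mechanism in the linked mmdet wrapper), not the implementation in this PR.

```python
class SkippableCompose:
    """Hypothetical Compose variant that can skip transforms by class name."""

    def __init__(self, transforms):
        self.transforms = list(transforms)
        self._skip_type_keys = None

    def update_skip_type_keys(self, skip_type_keys):
        # Remember the class names of the transforms to bypass from now on.
        self._skip_type_keys = set(skip_type_keys) if skip_type_keys else None

    def __call__(self, data):
        for transform in self.transforms:
            if (self._skip_type_keys is not None
                    and type(transform).__name__ in self._skip_type_keys):
                continue  # this transform has been switched off
            data = transform(data)
            if data is None:
                return None
        return data


class AddOne:
    """Toy transform standing in for a regular pipeline step."""

    def __call__(self, x):
        return x + 1


class ToyRandAugment:
    """Toy transform standing in for a strong augmentation."""

    def __call__(self, x):
        return x * 10


pipeline = SkippableCompose([AddOne(), ToyRandAugment(), AddOne()])
before = pipeline(1)  # all three transforms run

# A hook could call this at `action_epoch` to stop the strong augmentation.
pipeline.update_skip_type_keys(('ToyRandAugment',))
after = pipeline(1)  # only the two AddOne steps run

print(before, after)
```

Skipping by class name leaves the pipeline object itself untouched, so the same keys can be cleared later by passing `None`.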

okotaku, Mar 12 '22 13:03

In my opinion, we can add this feature as two hooks, 'stop_train_aug_hook' and 'stop_data_aug_hook', since train_aug and data_aug are different and kept separate in mmcls. 'stop_train_aug_hook' is what you have implemented and 'stop_data_aug_hook' is what mmdet has implemented.

What about using 'switch' instead of 'stop'? With 'switch' we can change pipeline=[A, B, C, D] to pipeline=[A, None, C_, D], which also has the effect of 'stop'.

Ezra-Yu, Mar 14 '22 04:03

> What about using 'switch' instead of 'stop'? With 'switch' we can change pipeline=[A, B, C, D] to pipeline=[A, None, C_, D], which also has the effect of 'stop'.

That is certainly true. I was wondering whether I could have both switch and stop, or would that be redundant?

okotaku, Mar 14 '22 05:03

> That is certainly true. I was wondering whether I could have both switch and stop, or would that be redundant?

Yes, we only need one of switch and stop. I prefer 'switch' since it is more powerful and covers 'stop'.

class SwitchTrainAugmentsHook(Hook):
    """Switch the train augments at the `action_epoch`-th epoch.

    Args:
        action_epoch (int): Epoch at which the train augments are switched.
        augs (List[dict] | None): The new train augments.
    """

    def __init__(self, action_epoch, augs):
        ...

What do you think about it?

The stop version would be:

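class StopTrainAugmentsHook(Hook):
    """Stop the named train augments at a given epoch.

    Args:
        start_or_stop_epoch (int): Epoch at which the augments start or stop.
        stop_aug_names (List[str]): The augments to be stopped.
    """

    def __init__(self, start_or_stop_epoch, stop_aug_names):
        ...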
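
To make the "switch covers stop" argument concrete, here is a small sketch under the assumption that augments are config dicts with a `type` key and that a `None` entry means "nothing at this position". The helper names `switch_augs` and `stop_augs` are hypothetical and exist only for this example.

```python
def switch_augs(new_augs):
    """'Switch': install a new list; `None` entries are dropped."""
    return [aug for aug in new_augs if aug is not None]


def stop_augs(current_augs, stop_aug_names):
    """'Stop' written as a switch: blank out the named augments, keep the rest."""
    return switch_augs(
        [None if aug['type'] in stop_aug_names else aug
         for aug in current_augs])


pipeline = [dict(type='A'), dict(type='B'), dict(type='C'), dict(type='D')]

# Switch: drop B and replace C with a modified C_.
switched = switch_augs(
    [dict(type='A'), None, dict(type='C_'), dict(type='D')])
print([aug['type'] for aug in switched])  # ['A', 'C_', 'D']

# Stop: the same mechanism with B blanked out and everything else kept.
stopped = stop_augs(pipeline, stop_aug_names=('B',))
print([aug['type'] for aug in stopped])  # ['A', 'C', 'D']
```

Since `stop_augs` is a thin wrapper around `switch_augs`, a single switch hook is enough; a separate stop hook would only add a convenience interface.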

Ezra-Yu, Mar 14 '22 05:03

I think train_aug is fine with that.

I pushed a mock-up in the latest commit. For data_aug in particular, I have kept the stop function in the implementation. What do you think?

okotaku, Mar 14 '22 06:03

> I think train_aug is fine with that.
>
> I pushed a mock-up in the latest commit. For data_aug in particular, I have kept the stop function in the implementation.

Looks good!

Ezra-Yu, Mar 14 '22 07:03

The way SwitchDataAugHook handles dataset wrappers is a little dirty. I wish we had a better solution, but...

okotaku, Mar 14 '22 08:03