
[Feature] Add `MultiTaskDataset` to support multi-task training.

mzr1996 opened this pull request 2 years ago • 3 comments

Motivation

To support using a single backbone to perform multiple classification tasks.

Modification

This PR is one part of the multi-task support plan, and it depends on #675 to build a network.

BC-breaking (Optional)

No

Use cases

Here is the detailed design for multi-task support. In short, multi-task classification means using one backbone and multiple heads to classify an image that carries multiple kinds of labels.

Dataset

The current multi-task design requires full labels on every image; partially labeled samples cannot be used to train the multi-task model.

To create a multi-task dataset, you can use the MultiTaskDataset class and prepare an annotation file. Here is a brief example:

An example annotation JSON file:

{
  "metainfo": {
    "tasks":
      [
        {"name": "gender",
         "type": "single-label",
         "categories": ["male", "female"]},
        {"name": "wear",
         "type": "multi-label",
         "categories": ["shirt", "coat", "jeans", "pants"]}
      ]
  },
  "data_list": [
    {
      "img_path": "a.jpg",
      "gender_img_label": 0,
      "wear_img_label": [1, 0, 1, 0]
    },
    {
      "img_path": "b.jpg",
      "gender_img_label": 1,
      "wear_img_label": [0, 1, 0, 1]
    },
    ...
  ]
}
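As a minimal sketch of how such an annotation file can be consumed (this is illustrative only, not the mmcls implementation — the function and variable names are invented, while the `metainfo`/`data_list`/`*_img_label` keys follow the JSON example above):

```python
import json

def parse_multi_task_ann(ann):
    """Group each sample's per-task labels under one `gt_label` dict.

    `ann` is the already-parsed annotation dict (e.g. from `json.load`).
    """
    task_names = [task['name'] for task in ann['metainfo']['tasks']]
    samples = []
    for item in ann['data_list']:
        # collect only the labels present for this sample
        gt_label = {name: item[f'{name}_img_label']
                    for name in task_names if f'{name}_img_label' in item}
        samples.append({'img_path': item['img_path'], 'gt_label': gt_label})
    return samples

ann = json.loads("""
{"metainfo": {"tasks": [
    {"name": "gender", "type": "single-label", "categories": ["male", "female"]},
    {"name": "wear", "type": "multi-label",
     "categories": ["shirt", "coat", "jeans", "pants"]}]},
 "data_list": [
    {"img_path": "a.jpg", "gender_img_label": 0, "wear_img_label": [1, 0, 1, 0]}]}
""")
samples = parse_multi_task_ann(ann)
# samples[0] -> {'img_path': 'a.jpg', 'gt_label': {'gender': 0, 'wear': [1, 0, 1, 0]}}
```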

The detailed usage and example of the MultiTaskDataset can be found here

Here is a script that uses the CIFAR10 dataset to generate an example multi-task dataset; just run it in the data folder. The resulting file structure is:

data/
├── cifar10
│   ├── images
│   ├── multi-task-test.json
│   └── multi-task-train.json
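For a rough idea of what such a generator script might look like, here is a hypothetical sketch (the two 6-way pseudo-tasks derived from the CIFAR class index are invented here to match the demo config's `num_classes=6`; they are not the actual script):

```python
import json
import os
import tempfile

def make_multi_task_ann(samples, out_file):
    """samples: iterable of (img_path, cifar_class) with cifar_class in 0..9."""
    ann = {
        'metainfo': {'tasks': [
            {'name': 'task1', 'type': 'single-label',
             'categories': [str(i) for i in range(6)]},
            {'name': 'task2', 'type': 'single-label',
             'categories': [str(i) for i in range(6)]},
        ]},
        'data_list': [
            {'img_path': path,
             # invent two labels per image by remapping the CIFAR class index
             'task1_img_label': cls % 6,
             'task2_img_label': cls // 2}
            for path, cls in samples
        ],
    }
    with open(out_file, 'w') as f:
        json.dump(ann, f, indent=2)
    return ann

out = os.path.join(tempfile.mkdtemp(), 'multi-task-train.json')
ann = make_multi_task_ann([('images/a.png', 7), ('images/b.png', 2)], out)
```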

Here is an example config to train on the multi-task dataset:

# Save as `configs/resnet/multi-task-demo.py`
_base_ = ['../_base_/schedules/cifar10_bs128.py', '../_base_/default_runtime.py']

# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(type='ResNet_CIFAR', depth=18),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='MultiTaskClsHead',                                    # <- Head config, depends on #675
        sub_heads={
            'task1': dict(type='LinearClsHead', num_classes=6),
            'task2': dict(type='LinearClsHead', num_classes=6),
        },
        common_cfg=dict(
            in_channels=512,
            loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        ),
    ),
)

# dataset settings
dataset_type = 'MultiTaskDataset'
img_norm_cfg = dict(
    mean=[125.307, 122.961, 113.8575],
    std=[51.5865, 50.847, 51.255],
    to_rgb=False)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomCrop', size=32, padding=4),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='FormatMultiTaskLabels'),                             # <- Use this to replace `ToTensor`.
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    samples_per_gpu=16,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        data_root='data/cifar10',
        ann_file='multi-task-train.json',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_root='data/cifar10',
        ann_file='multi-task-test.json',
        pipeline=test_pipeline,
        test_mode=True),
    test=dict(
        type=dataset_type,
        data_root='data/cifar10',
        ann_file='multi-task-test.json',
        pipeline=test_pipeline,
        test_mode=True))

evaluation = dict(metric_options={
    'task1': dict(topk=(1, )),                # <- Specify different metric options for different tasks.
    'task2': dict(topk=(1, 3)),
})
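To illustrate what the per-task `metric_options` above control, here is a hedged plain-Python sketch of top-k accuracy computed separately for each task with its own `topk` setting (the helper and the toy data are invented for illustration, not the mmcls evaluator):

```python
def topk_accuracy(scores, labels, topk=(1,)):
    """scores: per-sample class scores; labels: ground-truth class indices."""
    results = {}
    for k in topk:
        correct = 0
        for sample_scores, label in zip(scores, labels):
            # indices of the k highest-scoring classes
            top = sorted(range(len(sample_scores)),
                         key=lambda i: sample_scores[i], reverse=True)[:k]
            correct += label in top
        results[f'top-{k}'] = 100.0 * correct / len(labels)
    return results

metric_options = {'task1': dict(topk=(1,)), 'task2': dict(topk=(1, 3))}
per_task_scores = {
    'task1': [[0.9, 0.1], [0.2, 0.8]],
    'task2': [[0.1, 0.3, 0.6], [0.5, 0.3, 0.2]],
}
per_task_labels = {'task1': [0, 1], 'task2': [2, 1]}
results = {task: topk_accuracy(per_task_scores[task], per_task_labels[task],
                               **options)
           for task, options in metric_options.items()}
# results -> {'task1': {'top-1': 100.0}, 'task2': {'top-1': 50.0, 'top-3': 100.0}}
```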

Then we can train on the multi-task dataset with python tools/train.py configs/resnet/multi-task-demo.py:

2022-04-29 18:25:37,968 - mmcls - INFO - workflow: [('train', 1)], max: 200 epochs
2022-04-29 18:25:37,968 - mmcls - INFO - Checkpoints will be saved to /home/work_dirs/multi-task-demo by HardDiskBackend.
2022-04-29 18:25:42,280 - mmcls - INFO - Epoch [1][100/2813]    lr: 1.000e-01, eta: 6:43:27, time: 0.043, data_time: 0.021, memory: 329, task1_loss: 1.7489, task2_loss: 1.6522, loss: 3.4011
...
2022-04-29 18:26:24,813 - mmcls - INFO - Saving checkpoint at 1 epochs
2022-04-29 18:26:26,951 - mmcls - INFO - Epoch(val) [1][313]    task1_accuracy_top-1: 62.7000, task2_accuracy_top-1: 65.6800, task2_accuracy_top-3: 96.4800

Checklist

Before PR:

  • [x] Pre-commit or other linting tools are used to fix the potential lint issues.
  • [x] Bug fixes are fully covered by unit tests, the case that causes the bug should be added in the unit tests.
  • [x] The modification is covered by complete unit tests. If not, please add more unit test to ensure the correctness.
  • [x] The documentation has been modified accordingly, like docstring or example tutorials.

After PR:

  • [x] If the modification has potential influence on downstream or other related projects, this PR should be tested with those projects, like MMDet or MMSeg.
  • [x] CLA has been signed and all committers have signed the CLA in this PR.

mzr1996 avatar Apr 29 '22 10:04 mzr1996

Codecov Report

Merging #808 (21b0a38) into dev (59292b3) will increase coverage by 0.03%. The diff coverage is 86.72%.

@@            Coverage Diff             @@
##              dev     #808      +/-   ##
==========================================
+ Coverage   87.02%   87.05%   +0.03%     
==========================================
  Files         130      131       +1     
  Lines        8538     8739     +201     
  Branches     1468     1512      +44     
==========================================
+ Hits         7430     7608     +178     
- Misses        888      895       +7     
- Partials      220      236      +16     
Flag       Coverage Δ
unittests  86.97% <86.72%> (+0.03%) ↑

Flags with carried forward coverage won't be shown.

Impacted Files                            Coverage Δ
mmcls/datasets/multi_label.py             89.74% <71.42%> (+0.85%) ↑
mmcls/datasets/multi_task.py              85.71% <85.71%> (ø)
mmcls/datasets/pipelines/formatting.py    54.95% <94.73%> (+11.47%) ↑
mmcls/datasets/__init__.py                100.00% <100.00%> (ø)
mmcls/datasets/base_dataset.py            99.00% <100.00%> (+0.03%) ↑
mmcls/datasets/pipelines/__init__.py      100.00% <100.00%> (ø)

Continue to review full report at Codecov.

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Last update bce95b9...21b0a38.

codecov[bot] avatar Apr 29 '22 10:04 codecov[bot]

Any update on this?

JihwanEom avatar Jun 06 '22 13:06 JihwanEom

Where can I find the FormatMultiTaskLabels class? Thank you.

iamweiweishi avatar Aug 08 '22 10:08 iamweiweishi

@mzr1996 @Ezra-Yu This is great, thanks for this! I did only half of the job in #675.

I have merged both MultiClsHead and MultiTaskDataset in my repo https://github.com/piercus/mmclassification/tree/multi-task

What are the next steps for this? (1) Have you decided whether this feature makes sense in core mmcls? (2) Is there some cleanup to do? (3) Do we need more tests or examples?

Thank you for your help

piercus avatar Sep 22 '22 10:09 piercus

Closed since #1229 has been merged.

mzr1996 avatar Jan 12 '23 00:01 mzr1996