mmpretrain
[Feature] Add `MultiTaskDataset` to support multi-task training.
Motivation
To support using a single backbone to perform multiple classification tasks.
Modification
This PR is one part of the multi-task support plan, and it depends on #675 to build a network.
BC-breaking (Optional)
No
Use cases
Here is a detailed design of the multi-task support. "Multi-task" here means using one backbone and multiple heads to classify an image that carries multiple kinds of labels.
Dataset
The current multi-task support requires full labels on every image, which means you cannot use partially-labeled samples to train a multi-task model.
To create a multi-task dataset, use the `MultiTaskDataset` class and prepare an annotation file. Here is a brief example annotation JSON file:
```json
{
    "metainfo": {
        "tasks": [
            {
                "name": "gender",
                "type": "single-label",
                "categories": ["male", "female"]
            },
            {
                "name": "wear",
                "type": "multi-label",
                "categories": ["shirt", "coat", "jeans", "pants"]
            }
        ]
    },
    "data_list": [
        {
            "img_path": "a.jpg",
            "gender_img_label": 0,
            "wear_img_label": [1, 0, 1, 0]
        },
        {
            "img_path": "b.jpg",
            "gender_img_label": 1,
            "wear_img_label": [0, 1, 0, 1]
        },
        ...
    ]
}
```
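To make the annotation format concrete, here is a minimal sketch of parsing a file in this format with the standard library. `load_annotations` is a hypothetical helper written only for illustration; it is not part of mmcls, and the actual dataset class handles this parsing internally.

```python
import json

# The embedded JSON mirrors the annotation example above.
ANN = """
{
  "metainfo": {
    "tasks": [
      {"name": "gender", "type": "single-label",
       "categories": ["male", "female"]},
      {"name": "wear", "type": "multi-label",
       "categories": ["shirt", "coat", "jeans", "pants"]}
    ]
  },
  "data_list": [
    {"img_path": "a.jpg", "gender_img_label": 0, "wear_img_label": [1, 0, 1, 0]},
    {"img_path": "b.jpg", "gender_img_label": 1, "wear_img_label": [0, 1, 0, 1]}
  ]
}
"""

def load_annotations(text):
    """Parse the annotation JSON into (tasks, samples).

    Each sample's per-task labels are gathered from the
    `<task_name>_img_label` keys into one `gt_label` dict.
    """
    ann = json.loads(text)
    tasks = {t['name']: t for t in ann['metainfo']['tasks']}
    samples = []
    for item in ann['data_list']:
        labels = {name: item[f'{name}_img_label'] for name in tasks}
        samples.append({'img_path': item['img_path'], 'gt_label': labels})
    return tasks, samples

tasks, samples = load_annotations(ANN)
```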
The detailed usage and examples of `MultiTaskDataset` can be found here.
And here is a script that uses the CIFAR-10 dataset to generate an example multi-task dataset; just run it in the `data` folder. The resulting file structure:
data/
├── cifar10
│ ├── images
│ ├── multi-task-test.json
│ └── multi-task-train.json
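The generation script itself is not reproduced here; as a purely illustrative sketch, the conversion step could look like the following. `build_annotations` and the two 6-way label splits are assumptions made only to match the demo config's `num_classes=6` heads, not the actual script's logic.

```python
import json

def build_annotations(items):
    """Build the multi-task annotation dict from (img_path, label) pairs.

    `items` is an iterable of (relative image path, CIFAR-10 label) pairs.
    """
    # Placeholder category names; a real script would use meaningful ones.
    categories = [f'class_{i}' for i in range(6)]
    metainfo = {'tasks': [
        {'name': 'task1', 'type': 'single-label', 'categories': categories},
        {'name': 'task2', 'type': 'single-label', 'categories': categories},
    ]}
    data_list = [{
        'img_path': path,
        'task1_img_label': label % 6,        # illustrative 6-way split
        'task2_img_label': (label * 7) % 6,  # illustrative 6-way split
    } for path, label in items]
    return {'metainfo': metainfo, 'data_list': data_list}

ann = build_annotations([('images/0.png', 3), ('images/1.png', 9)])
text = json.dumps(ann, indent=2)  # ready to write to multi-task-train.json
```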
Here is an example config to train on the multi-task dataset:
```python
# Save as `configs/resnet/multi-task-demo.py`
_base_ = ['../_base_/schedules/cifar10_bs128.py', '../_base_/default_runtime.py']

# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(type='ResNet_CIFAR', depth=18),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='MultiTaskClsHead',  # <- Head config, depends on #675
        sub_heads={
            'task1': dict(type='LinearClsHead', num_classes=6),
            'task2': dict(type='LinearClsHead', num_classes=6),
        },
        common_cfg=dict(
            in_channels=512,
            loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        ),
    ),
)

# dataset settings
dataset_type = 'MultiTaskDataset'
img_norm_cfg = dict(
    mean=[125.307, 122.961, 113.8575],
    std=[51.5865, 50.847, 51.255],
    to_rgb=False)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomCrop', size=32, padding=4),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='FormatMultiTaskLabels'),  # <- Use this to replace `ToTensor`.
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    samples_per_gpu=16,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        data_root='data/cifar10',
        ann_file='multi-task-train.json',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_root='data/cifar10',
        ann_file='multi-task-test.json',
        pipeline=test_pipeline,
        test_mode=True),
    test=dict(
        type=dataset_type,
        data_root='data/cifar10',
        ann_file='multi-task-test.json',
        pipeline=test_pipeline,
        test_mode=True))
evaluation = dict(metric_options={
    'task1': dict(topk=(1, )),  # <- Specify different metric options for different tasks.
    'task2': dict(topk=(1, 3)),
})
```
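As a conceptual sketch only (not the actual mmcls implementation from this PR or #675), the `FormatMultiTaskLabels` step essentially packs each task's label from the sample into a single `gt_label` dict, so that the multi-task head can dispatch one label per sub-head. The class and task names below are illustrative; the real transform additionally converts labels to tensors.

```python
class FormatMultiTaskLabelsSketch:
    """Illustrative stand-in for a multi-task label formatting transform."""

    def __init__(self, task_names):
        self.task_names = task_names

    def __call__(self, results):
        # Gather `<task>_img_label` entries into one `gt_label` dict,
        # removing the per-task keys from the sample.
        results['gt_label'] = {
            name: results.pop(f'{name}_img_label') for name in self.task_names
        }
        return results

fmt = FormatMultiTaskLabelsSketch(['gender', 'wear'])
sample = {'img_path': 'a.jpg', 'gender_img_label': 0,
          'wear_img_label': [1, 0, 1, 0]}
out = fmt(sample)
```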
Then, we can train on the dataset with `python tools/train.py configs/resnet/multi-task-demo.py`:
2022-04-29 18:25:37,968 - mmcls - INFO - workflow: [('train', 1)], max: 200 epochs
2022-04-29 18:25:37,968 - mmcls - INFO - Checkpoints will be saved to /home/work_dirs/multi-task-demo by HardDiskBackend.
2022-04-29 18:25:42,280 - mmcls - INFO - Epoch [1][100/2813] lr: 1.000e-01, eta: 6:43:27, time: 0.043, data_time: 0.021, memory: 329, task1_loss: 1.7489, task2_loss: 1.6522, loss: 3.4011
...
2022-04-29 18:26:24,813 - mmcls - INFO - Saving checkpoint at 1 epochs
2022-04-29 18:26:26,951 - mmcls - INFO - Epoch(val) [1][313] task1_accuracy_top-1: 62.7000, task2_accuracy_top-1: 65.6800, task2_accuracy_top-3: 96.4800
Checklist
Before PR:
- [x] Pre-commit or other linting tools are used to fix the potential lint issues.
- [x] Bug fixes are fully covered by unit tests; the case that caused the bug should be added to the unit tests.
- [x] The modification is covered by complete unit tests. If not, please add more unit tests to ensure correctness.
- [x] The documentation has been modified accordingly, like docstring or example tutorials.
After PR:
- [x] If the modification has potential influence on downstream or other related projects, this PR should be tested with those projects, like MMDet or MMSeg.
- [x] CLA has been signed and all committers have signed the CLA in this PR.
Codecov Report
Merging #808 (21b0a38) into dev (59292b3) will increase coverage by 0.03%. The diff coverage is 86.72%.
@@ Coverage Diff @@
## dev #808 +/- ##
==========================================
+ Coverage 87.02% 87.05% +0.03%
==========================================
Files 130 131 +1
Lines 8538 8739 +201
Branches 1468 1512 +44
==========================================
+ Hits 7430 7608 +178
- Misses 888 895 +7
- Partials 220 236 +16
| Flag | Coverage Δ |
|---|---|
| unittests | 86.97% <86.72%> (+0.03%) ⬆️ |

Flags with carried forward coverage won't be shown.
| Impacted Files | Coverage Δ |
|---|---|
| mmcls/datasets/multi_label.py | 89.74% <71.42%> (+0.85%) ⬆️ |
| mmcls/datasets/multi_task.py | 85.71% <85.71%> (ø) |
| mmcls/datasets/pipelines/formatting.py | 54.95% <94.73%> (+11.47%) ⬆️ |
| mmcls/datasets/__init__.py | 100.00% <100.00%> (ø) |
| mmcls/datasets/base_dataset.py | 99.00% <100.00%> (+0.03%) ⬆️ |
| mmcls/datasets/pipelines/__init__.py | 100.00% <100.00%> (ø) |
Continue to review the full report at Codecov.
Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update bce95b9...21b0a38.
Any update on this?
Where can I find the `FormatMultiTaskLabels` class? Thank you.
@mzr1996 @Ezra-Yu This is great, thanks for this! I did only half of the job in #675.
I have merged both MultiClsHead and MultiTaskDataset in my repo: https://github.com/piercus/mmclassification/tree/multi-task
What are the next steps for this: (1) Have you decided whether this feature makes sense in core mmcls? (2) Is there some cleanup to do? (3) Do we need more testing/examples?
Thank you for your help.
Closed since #1229 is merged.