
Error while trying to train custom dataset - __init__() error

cspearl opened this issue 2 years ago • 6 comments

Hi, any help would be appreciated. I am trying to train a custom dataset from scratch. I cloned the repo 2 days ago and did the setup for custom datasets as mentioned in the docs. The first model I am trying to train is CCNet, but I am facing this error and am not able to train at all -

Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/mmcv/utils/registry.py", line 69, in build_from_cfg
    return obj_cls(**args)
  File "/content/drive/MyDrive/openmm/mmsegmentation/mmseg/datasets/amazon.py", line 21, in __init__
    **kwargs)
TypeError: __init__() got an unexpected keyword argument 'times'

cspearl avatar Aug 03 '22 07:08 cspearl

Would you mind providing the implementation of your dataset and the documentation link you referred to?

MeowZheng avatar Aug 03 '22 07:08 MeowZheng

Would you mind providing the implementation of your dataset and the documentation link you referred to?

I used these instructions while implementing my dataset - https://github.com/MengzhangLI/mmsegmentation/blob/add_doc_customization_dataset/docs/en/tutorials/customize_datasets.md

This is the code of the file where the custom dataset is defined -

from .builder import DATASETS
from .custom import CustomDataset


@DATASETS.register_module()
class AmazonDataset(CustomDataset):

    CLASSES = ('background', 'label1', 'label2', 'label3', 'label4')

    PALETTE = [[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]]

    def __init__(self, **kwargs):
        super(AmazonDataset, self).__init__(
            img_suffix='.jpg',
            seg_map_suffix='.png',
            # reduce_zero_label is False because label 0 is background (first one in CLASSES above)
            reduce_zero_label=False,
            **kwargs)

cspearl avatar Aug 03 '22 13:08 cspearl

I think you might print kwargs inside def __init__(self, **kwargs), before the super() call, to check whether times is in kwargs.
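
A minimal sketch of that suggestion, assuming the AmazonDataset class shown above (the printed dict contents are only illustrative):

def __init__(self, **kwargs):
    # Print the incoming config kwargs before forwarding them to CustomDataset,
    # to see which unexpected keys (e.g. 'times') are being passed in.
    print(kwargs)
    super(AmazonDataset, self).__init__(
        img_suffix='.jpg',
        seg_map_suffix='.png',
        reduce_zero_label=False,
        **kwargs)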

MeowZheng avatar Aug 03 '22 16:08 MeowZheng

I think you might print kwargs inside def __init__(self, **kwargs), before the super() call, to check whether times is in kwargs.

{'times': 40000, ...} gets printed as the first key-value pair of kwargs, which I believe denotes the number of iterations. What should I change?

cspearl avatar Aug 04 '22 06:08 cspearl

Do you use RepeatDataset in your dataset config? https://github.com/open-mmlab/mmsegmentation/blob/4eaa8e69191cc293b64dafe47f1f88a7d468c93c/mmseg/datasets/dataset_wrappers.py#L176
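
For reference, a train entry that uses the RepeatDataset wrapper normally looks roughly like this (a sketch only; dataset_type, data_root and the other fields are placeholders for your own values):

data = dict(
    train=dict(
        type='RepeatDataset',  # the wrapper is what accepts `times`
        times=40000,
        dataset=dict(
            type=dataset_type,
            data_root=data_root,
            img_dir='img/train',
            ann_dir='label/train',
            pipeline=train_pipeline)))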

MeowZheng avatar Aug 05 '22 06:08 MeowZheng

Do you use RepeatDataset in your dataset config?

https://github.com/open-mmlab/mmsegmentation/blob/4eaa8e69191cc293b64dafe47f1f88a7d468c93c/mmseg/datasets/dataset_wrappers.py#L176

No, I have not used RepeatDataset in my config. My config is -

dataset_type = 'AmazonDataset'
data_root = 'data/amazon'

img_norm_cfg = dict(
    mean=[101.951, 117.161, 62.96],
    std=[19.019, 11.592, 12.754],
    to_rgb=True)

img_scale = (512, 512)
crop_size = (512, 512)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=img_scale,
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img'])
        ])
]

data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        times=40000,
        dataset=dict(
            type=dataset_type,
            data_root=data_root,
            img_dir='img/train',
            ann_dir='label/train',
            pipeline=train_pipeline)),
    val=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='img/val',
        ann_dir='label/val',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='img/val',
        ann_dir='label/val',
        pipeline=test_pipeline))

cspearl avatar Aug 05 '22 08:08 cspearl

just remove times=40000 from your config

data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        times=40000,
        ...
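
For completeness, a sketch of one possible corrected train entry. It assumes the nested dataset fields are lifted directly into train once times is dropped; the alternative, if the 40000-times repetition is actually wanted, is to keep the nesting but set the outer type to 'RepeatDataset' as sketched earlier.

train=dict(
    type=dataset_type,
    data_root=data_root,
    img_dir='img/train',
    ann_dir='label/train',
    pipeline=train_pipeline),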

MeowZheng avatar Aug 08 '22 10:08 MeowZheng

Yes, the training part is now working. Thanks for the help!

cspearl avatar Aug 08 '22 17:08 cspearl