unidet3d

Train and Test on different datasets

Open: azztt opened this issue · 6 comments

Hi, thanks for sharing this work. Is it possible to train on a few datasets and test on different ones (e.g., train on scannet++ and arkitscenes, then test on all six)? I know a few labels might be missing, but are there any other issues?

Thanks

azztt, Jun 27 '25

Yes, I think it should be fine. The released model is trained on 6 datasets, but it can be tested on all 6, on any subset of them, or on a new one. For practical reasons, you may need to map the set of training labels to your set of test labels.
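One way to do that label mapping is a name-based lookup from training class indices to test class indices, dropping predictions for classes the test set does not have. This is a minimal sketch under stated assumptions: classes are matched by exact name, with an optional rename table for pairs such as `tv_monitor`/`tv`; `build_label_map`, `remap_predictions` and the rename entry are hypothetical and not part of unidet3d.

```python
from typing import Dict, List, Optional, Tuple


def build_label_map(
    train_classes: List[str],
    test_classes: List[str],
    synonyms: Optional[Dict[str, str]] = None,
) -> Dict[int, Optional[int]]:
    """Map each training class index to a test class index, or None."""
    synonyms = synonyms or {}  # hypothetical train-name -> test-name renames
    test_index = {name: i for i, name in enumerate(test_classes)}
    return {
        i: test_index.get(synonyms.get(name, name))
        for i, name in enumerate(train_classes)
    }


def remap_predictions(
    labels: List[int],
    scores: List[float],
    label_map: Dict[int, Optional[int]],
) -> Tuple[List[int], List[float]]:
    """Translate predicted training labels into test labels.

    Predictions whose training class has no test counterpart are dropped.
    """
    out_labels, out_scores = [], []
    for label, score in zip(labels, scores):
        mapped = label_map.get(label)
        if mapped is not None:
            out_labels.append(mapped)
            out_scores.append(score)
    return out_labels, out_scores


train = ["cabinet", "chair", "tv_monitor", "stove"]
test = ["chair", "cabinet", "tv"]
label_map = build_label_map(train, test, synonyms={"tv_monitor": "tv"})
print(label_map)  # {0: 1, 1: 0, 2: 2, 3: None}
print(remap_predictions([0, 3, 2], [0.9, 0.8, 0.7], label_map))  # ([1, 2], [0.9, 0.7])
```

Unmatched predictions are dropped rather than forced onto some other test class, so they cannot become false positives for unrelated classes; whether that suits a given evaluation is a judgment call.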

filaPro, Jun 27 '25

Thanks a lot for the response. I'm more interested in a scenario where I train on, say, scannetpp and arkitscenes, and test on scannet (which has 4 classes not present in the training data). For this scenario, I changed the config file as follows (original config on the right-hand side):

model: [screenshot]

train dataloader: [screenshot]

validation dataloader: [screenshot]

I added a test dataloader identical to the original validation dataloader (so it now differs from the current validation one) and separated the evaluators for train and test time: [screenshot]
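The 4 scannet classes missing from training can be identified by comparing class names. A minimal sketch, assuming exact-name matching; `classes_missing_from_training` is a hypothetical helper, and the class lists are copied from the config posted further down in this thread, with the ScanNet++ list trimmed to the names it shares with ScanNet (the only ones that can affect this comparison).

```python
from typing import Iterable, List


def classes_missing_from_training(
    test_classes: List[str], train_class_lists: Iterable[List[str]]
) -> List[str]:
    """Return test classes (in test order) that no training dataset labels.

    Matching is by exact class name, so a synonym spelled differently
    (e.g. 'tv' vs 'tv_monitor') would be reported as missing.
    """
    seen = set()
    for classes in train_class_lists:
        seen.update(classes)
    return [name for name in test_classes if name not in seen]


classes_scannet = ['cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
                   'bookshelf', 'picture', 'counter', 'desk', 'curtain',
                   'refrigerator', 'showercurtrain', 'toilet', 'sink', 'bathtub',
                   'otherfurniture']
# Trimmed: only the ScanNet++ names that also occur in classes_scannet;
# the remaining ScanNet++ classes cannot change the result.
classes_scannetpp_overlap = ['table', 'door', 'cabinet', 'curtain', 'chair',
                             'bookshelf', 'window', 'sofa', 'bed',
                             'refrigerator', 'sink', 'picture', 'toilet']
classes_arkitscenes = ['cabinet', 'refrigerator', 'shelf', 'stove', 'bed',
                       'sink', 'washer', 'toilet', 'bathtub', 'oven',
                       'dishwasher', 'fireplace', 'stool', 'chair', 'table',
                       'tv_monitor', 'sofa']

missing = classes_missing_from_training(
    classes_scannet, [classes_scannetpp_overlap, classes_arkitscenes])
print(missing)  # ['counter', 'desk', 'showercurtrain', 'otherfurniture']
```

For these lists the result agrees with the count of 4 given above.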

The training phase completed successfully, but when I try to test with the same config, I get the following error:

Traceback (most recent call last):
  File "/root/anweshan/mantis/unidet3d/tools/test.py", line 150, in <module>
    main()
  File "/root/anweshan/mantis/unidet3d/tools/test.py", line 146, in main
    runner.test()
  File "/opt/conda/lib/python3.10/site-packages/mmengine/runner/runner.py", line 1823, in test
    metrics = self.test_loop.run()  # type: ignore
  File "/opt/conda/lib/python3.10/site-packages/mmengine/runner/loops.py", line 435, in run
    self.run_iter(idx, data_batch)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/mmengine/runner/loops.py", line 454, in run_iter
    outputs = self.runner.model.test_step(data_batch)
  File "/opt/conda/lib/python3.10/site-packages/mmengine/model/base_model/base_model.py", line 145, in test_step
    return self._run_forward(data, mode='predict')  # type: ignore
  File "/opt/conda/lib/python3.10/site-packages/mmengine/model/base_model/base_model.py", line 346, in _run_forward
    results = self(**data, mode=mode)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/mmdet3d/models/detectors/base.py", line 86, in forward
    return self.predict(inputs, data_samples, **kwargs)
  File "/root/anweshan/mantis/unidet3d/unidet3d/unidet3d.py", line 462, in predict
    x = self.decoder(x, sp_centers, datasets_names)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/anweshan/mantis/unidet3d/unidet3d/encoder.py", line 220, in forward
    self._forward_head(feats, sp_centers, datasets_names)
  File "/root/anweshan/mantis/unidet3d/unidet3d/encoder.py", line 191, in _forward_head
    idx = self.datasets.index(datasets_names[i])
ValueError: None is not in list

After some debugging, I found that the model is loaded with a parameter called datasets_names that is set to the training datasets. When a sample from, say, scannet is loaded, its dataset name cannot be found by this method: [screenshot]. That method tries to match the test sample's dataset name against those in self.decoder.datasets, which in this case are scannetpp and arkitscenes.
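The error itself is Python's `list.index` raising `ValueError` when the value is not in the list; here the per-sample dataset name is `None`, and the decoder only knows the training datasets. The sketch below reproduces that lookup and wraps it in a hypothetical `find_head_index` helper (not a unidet3d function) that reports both the missing name and the datasets the decoder was built for. It only improves the message; it does not give the decoder an entry for a dataset it was not configured with.

```python
from typing import List, Optional


def find_head_index(decoder_datasets: List[str], name: Optional[str]) -> int:
    """Return the position of `name` in the decoder's dataset list.

    Mirrors `self.datasets.index(datasets_names[i])` from the traceback, but
    the error names both the missing value and the known datasets.
    """
    try:
        return decoder_datasets.index(name)
    except ValueError:
        raise ValueError(
            f"dataset name {name!r} has no decoder entry; "
            f"decoder was built for {decoder_datasets}") from None


decoder_datasets = ['scannetpp', 'arkitscenes']  # training-only setup

try:
    decoder_datasets.index(None)  # the same lookup as encoder.py line 191
except ValueError as err:
    print(err)  # None is not in list

print(find_head_index(['scannet', 'scannetpp', 'arkitscenes'], 'scannet'))  # 0
```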

I would appreciate help from the authors or from anyone who has faced a similar issue and solved it.

azztt, Jun 30 '25

Is it OK that your test_evaluator has datasets and datasets_classes of different lengths in line 766?

filaPro, Jun 30 '25

Actually, both have the same length. I assume you mention it because of the red highlight. If so, that doesn't mean the line was deleted; it is just the diff between the two files. Currently that part looks like this:

[screenshot]
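Rather than reading the lengths off a diff, they can be compared on the config itself. A minimal sketch, assuming the evaluator section is a dict with `datasets` and `datasets_classes` keys as in the configs in this thread; `check_evaluator` is hypothetical.

```python
from typing import Any, Dict


def check_evaluator(evaluator: Dict[str, Any]) -> None:
    """Raise ValueError if `datasets` and `datasets_classes` don't pair up."""
    datasets = evaluator['datasets']
    classes = evaluator['datasets_classes']
    if len(datasets) != len(classes):
        raise ValueError(
            f"{len(datasets)} dataset names but {len(classes)} class lists")
    # A repeated dataset name would make the name-to-classes pairing ambiguous.
    if len(set(datasets)) != len(datasets):
        raise ValueError(f"duplicate dataset names in {datasets}")


classes_scannet = ['cabinet', 'bed', 'chair']  # trimmed for brevity
test_evaluator = dict(type='IndoorMetric_',
                      datasets=['scannet'],
                      datasets_classes=[classes_scannet])
check_evaluator(test_evaluator)  # returns None: the lengths match
```

With mmengine installed (and mmdet3d and unidet3d importable for the custom imports), `Config.fromfile(path)` loads a real config, and `cfg.test_evaluator` can be passed in directly because `ConfigDict` supports dict-style access.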

azztt, Jun 30 '25

Hey! I believe you should use a config like this. It computes metrics on ScanNet across all 18 classes, but you can simply exclude the classes that are absent from the training data when calculating the average metrics. This seems like the easiest way. (This config hasn't been tested and may contain errors, sorry.)

_base_ = ['mmdet3d::_base_/default_runtime.py']
custom_imports = dict(imports=['unidet3d'])


classes_scannet = ['cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf',
                    'picture', 'counter', 'desk', 'curtain', 'refrigerator', 'showercurtrain',
                    'toilet', 'sink', 'bathtub', 'otherfurniture']

classes_scannetpp = ['table', 'door', 'ceiling lamp', 'cabinet', 'blinds', 'curtain', 'chair', 'storage cabinet', 'office chair', 'bookshelf', 'whiteboard', 'window', 'box', 
                     'monitor', 'shelf', 'heater', 'kitchen cabinet', 'sofa', 'bed', 'trash can', 'book', 'plant', 'blanket', 'tv', 'computer tower', 'refrigerator', 'jacket', 
                     'sink', 'bag', 'picture', 'pillow', 'towel', 'suitcase', 'backpack', 'crate', 'keyboard', 'rack', 'toilet', 'printer', 'poster', 'painting', 'microwave', 'shoes', 
                     'socket', 'bottle', 'bucket', 'cushion', 'basket', 'shoe rack', 'telephone', 'file folder', 'laptop', 'plant pot', 'exhaust fan', 'cup', 'coat hanger', 'light switch', 
                     'speaker', 'table lamp', 'kettle', 'smoke detector', 'container', 'power strip', 'slippers', 'paper bag', 'mouse', 'cutting board', 'toilet paper', 'paper towel', 
                     'pot', 'clock', 'pan', 'tap', 'jar', 'soap dispenser', 'binder', 'bowl', 'tissue box', 'whiteboard eraser', 'toilet brush', 'spray bottle', 'headphones', 'stapler', 'marker']

classes_arkitscenes = ['cabinet', 'refrigerator', 'shelf', 'stove', 'bed',
                        'sink', 'washer', 'toilet', 'bathtub', 'oven',
                        'dishwasher', 'fireplace', 'stool', 'chair', 'table',
                        'tv_monitor', 'sofa']

# model settings
num_channels=32
voxel_size=0.02

model = dict(
    type='UniDet3D',
    data_preprocessor=dict(type='Det3DDataPreprocessor_'),
    in_channels=6,
    num_channels=num_channels,
    voxel_size=voxel_size,
    min_spatial_shape=128,
    query_thr=3000,
    bbox_by_mask=[True, False, False],
    target_by_distance=[False, True, True],
    use_superpoints=[True, False, False],
    fast_nms=[True, True, None],
    backbone=dict(
        type='SpConvUNet',
        num_planes=[num_channels * (i + 1) for i in range(5)],
        return_blocks=True),
    decoder=dict(
        type='UniDet3DEncoder',
        num_layers=6,
        datasets_classes=[classes_scannet,
                          classes_scannetpp, classes_arkitscenes],
        in_channels=num_channels,
        d_model=256,
        num_heads=8,
        hidden_dim=1024,
        dropout=0.0,
        activation_fn='gelu',
        datasets=['scannet', 'scannetpp', 'arkitscenes'],
        angles=[False, False, True]),
    criterion=dict(
        type='UniDet3DCriterion',
            datasets=['scannet', 'scannetpp', 'arkitscenes'],
            datasets_weights=[1, 1, 1],
            bbox_loss_simple=dict(
                type='UniDet3DAxisAlignedIoULoss',
                mode='diou',
                reduction='none'),
            bbox_loss_rotated=dict(
                type='UniDet3DRotatedIoU3DLoss',
                mode='diou',
                reduction='none'),
            matcher=dict(
                type='UniMatcher',
                costs=[
                    dict(type='QueryClassificationCost', weight=0.5),
                    dict(type='BboxCostJointTraining', 
                            weight=2.0,
                            loss_simple=dict(
                                type='UniDet3DAxisAlignedIoULoss',
                                mode='diou',
                                reduction='none'),
                            loss_rotated=dict(
                                type='UniDet3DRotatedIoU3DLoss',
                                mode='diou',
                                reduction='none'))]),
            loss_weight=[0.5, 1.0],
            non_object_weight=0.1,
            topk=[6, 3, 3],
            iter_matcher=True),
    train_cfg=dict(topk=6),
    test_cfg=dict(
        low_sp_thr=0.18,
        up_sp_thr=0.81,
        topk_insts=1000,
        score_thr=0,
        iou_thr=[0.5, 0.55, 0.55]))

# scannet dataset settings

metainfo_scannet = dict(classes=classes_scannet)
data_root_scannet = 'data/scannet/'

max_class_scannet = 20
dataset_type_scannet = 'ScanNetDetDataset'
data_prefix_scannet = dict(
    pts='points',
    pts_instance_mask='instance_mask',
    pts_semantic_mask='semantic_mask',
    sp_pts_mask='super_points')

train_pipeline_scannet = [
    dict(
        type='LoadPointsFromFile',
        coord_type='DEPTH',
        shift_height=False,
        use_color=True,
        load_dim=6,
        use_dim=[0, 1, 2, 3, 4, 5]),
    dict(
        type='LoadAnnotations3D_',
        with_bbox_3d=False,
        with_label_3d=False,
        with_mask_3d=True,
        with_seg_3d=True,
        with_sp_mask_3d=True),
    dict(type='GlobalAlignment', rotation_axis=2),
    dict(type='PointSegClassMapping'),
    dict(
        type='RandomFlip3D',
        sync_2d=False,
        flip_ratio_bev_horizontal=0.5,
        flip_ratio_bev_vertical=0.5),
    dict(
        type='GlobalRotScaleTrans',
        rot_range=[-3.14, 3.14],
        scale_ratio_range=[0.8, 1.2],
        translation_std=[0.1, 0.1, 0.1],
        shift_height=False),
    dict(
        type='NormalizePointsColor_',
        color_mean=[127.5, 127.5, 127.5]),
    dict(
        type='PointDetClassMappingScanNet',
        num_classes=max_class_scannet,
        stuff_classes=[0, 1]),
    dict(
        type='ElasticTransfrom',
        gran=[6, 20],
        mag=[40, 160],
        voxel_size=voxel_size,
        p=0.5),
    dict(
        type='Pack3DDetInputs_',
        keys=[
            'points', 'gt_labels_3d', 'pts_semantic_mask', 'pts_instance_mask',
            'sp_pts_mask', 'gt_sp_masks', 'elastic_coords'
        ])
]
test_pipeline_scannet = [
    dict(
        type='LoadPointsFromFile',
        coord_type='DEPTH',
        shift_height=False,
        use_color=True,
        load_dim=6,
        use_dim=[0, 1, 2, 3, 4, 5]),
    dict(
        type='LoadAnnotations3D_',
        with_bbox_3d=False,
        with_label_3d=False,
        with_mask_3d=True,
        with_seg_3d=True,
        with_sp_mask_3d=True),
    dict(type='GlobalAlignment', rotation_axis=2),
    dict(
        type='MultiScaleFlipAug3D',
        img_scale=(1333, 800),
        pts_scale_ratio=1,
        flip=False,
        transforms=[
            dict(
                type='NormalizePointsColor_',
                color_mean=[127.5, 127.5, 127.5])]),
    dict(type='Pack3DDetInputs_', keys=['points', 'sp_pts_mask'])
]

# scannetpp dataset settings
data_root_scannetpp = 'data/scannetpp/bins'
dataset_type_scannetpp = 'Scannetpp_'
data_prefix_scannetpp = dict(
    pts='points',
    pts_instance_mask='instance_mask',
    pts_semantic_mask='semantic_mask',
    sp_pts_mask='super_points_spt')

train_pipeline_scannetpp = [
    dict(
        type='LoadPointsFromFile',
        coord_type='DEPTH',
        shift_height=False,
        use_color=True,
        load_dim=6,
        use_dim=[0, 1, 2, 3, 4, 5]),
    dict(type='LoadAnnotations3D_',
         with_label_3d=True,
         with_bbox_3d=True,
         with_sp_mask_3d=True),
    dict(type='PointSample_', 
         num_points=200000),
    dict(
        type='RandomFlip3D',
        sync_2d=False,
        flip_ratio_bev_horizontal=0.5,
        flip_ratio_bev_vertical=0.5),
    dict(
        type='GlobalRotScaleTrans',
        rot_range=[0, 0],
        scale_ratio_range=[0.9, 1.1],
        translation_std=[0.1, 0.1, 0.1],
        shift_height=False),
    dict(
        type='NormalizePointsColor_',
        color_mean=[127.5, 127.5, 127.5]),
    dict(
        type='ElasticTransfrom',
        gran=[6, 20],
        mag=[40, 160],
        voxel_size=voxel_size,
        p=-1),
    dict(
        type='Pack3DDetInputs_',
        keys=['points', 'elastic_coords', 'gt_bboxes_3d', 
              'gt_labels_3d', 'sp_pts_mask'])
]
test_pipeline_scannetpp = [
    dict(
        type='LoadPointsFromFile',
        coord_type='DEPTH',
        shift_height=False,
        use_color=True,
        load_dim=6,
        use_dim=[0, 1, 2, 3, 4, 5]),
    dict(type='LoadAnnotations3D_',
         with_label_3d=False,
         with_bbox_3d=False,
         with_sp_mask_3d=True),
    dict(
        type='MultiScaleFlipAug3D',
        img_scale=(1333, 800),
        pts_scale_ratio=1,
        flip=False,
        transforms=[
            dict(type='PointSample_', num_points=200000),
            dict(
                type='NormalizePointsColor_',
                color_mean=[127.5, 127.5, 127.5])
        ]),
    dict(type='Pack3DDetInputs_', keys=['points', 'sp_pts_mask'])
]

# arkitscenes dataset settings
dataset_type_arkitscenes = 'ARKitScenesOfflineDataset'
data_root_arkitscenes = 'data/arkitscenes'
data_prefix_arkitscenes = dict(
    pts='offline_prepared_data',
    sp_pts_mask='super_points')

train_pipeline_arkitscenes = [
    dict(
        type='LoadPointsFromFile',
        coord_type='DEPTH',
        shift_height=False,
        use_color=True,
        load_dim=6,
        use_dim=[0, 1, 2, 3, 4, 5]),
    dict(type='LoadAnnotations3D_',
         with_label_3d=True,
         with_bbox_3d=True,
         with_sp_mask_3d=True),
    dict(type='PointSample_', num_points=100000),
    dict(
        type='DenormalizePointsColor',
        color_mean=[0, 0, 0],
        color_std=[255, 255, 255]),
    dict(
        type='NormalizePointsColor_',
        color_mean=[127.5, 127.5, 127.5]),
    dict(
        type='RandomFlip3D',
        sync_2d=False,
        flip_ratio_bev_horizontal=0.5,
        flip_ratio_bev_vertical=0.5),
    dict(
        type='GlobalRotScaleTrans',
        rot_range=[-0.5, 0.5],
        scale_ratio_range=[0.9, 1.1],
        translation_std=[0.1, 0.1, 0.1],
        shift_height=False),
    dict(
        type='ElasticTransfrom',
        gran=[6, 20],
        mag=[40, 160],
        voxel_size=voxel_size,
        p=-1),
    dict(
        type='Pack3DDetInputs_',
        keys=['points', 'elastic_coords', 'gt_bboxes_3d', 
              'gt_labels_3d', 'sp_pts_mask'])
]
test_pipeline_arkitscenes = [
    dict(
        type='LoadPointsFromFile',
        coord_type='DEPTH',
        shift_height=False,
        use_color=True,
        load_dim=6,
        use_dim=[0, 1, 2, 3, 4, 5]),
    dict(type='LoadAnnotations3D_',
         with_label_3d=False,
         with_bbox_3d=False,
         with_sp_mask_3d=True),
    dict(
        type='MultiScaleFlipAug3D',
        img_scale=(1333, 800),
        pts_scale_ratio=1,
        flip=False,
        transforms=[
            dict(type='PointSample_', num_points=100000),
            dict(
                type='DenormalizePointsColor',
                color_mean=[0, 0, 0],
                color_std=[255, 255, 255]),
            dict(
                type='NormalizePointsColor_',
                color_mean=[127.5, 127.5, 127.5])
        ]),
    dict(type='Pack3DDetInputs_', keys=['points', 'sp_pts_mask'])
]


# run settings
train_dataloader = dict(
    batch_size=8,
    num_workers=8,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type='ConcatDataset_',
        datasets=[dict(
                    type=dataset_type_scannetpp,
                    ann_file='scannetpp_infos_train.pkl',
                    partition=0.33,
                    data_prefix=data_prefix_scannetpp,
                    data_root=data_root_scannetpp,
                    pipeline=train_pipeline_scannetpp,
                    test_mode=False)] + \
                [dict(
                    type=dataset_type_arkitscenes,
                    ann_file='arkitscenes_offline_infos_train.pkl',
                    partition=0.08,
                    data_prefix=data_prefix_arkitscenes,
                    data_root=data_root_arkitscenes,
                    pipeline=train_pipeline_arkitscenes,
                    test_mode=False)] 
                    ))

val_dataloader = dict(
    batch_size=1,
    num_workers=1,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type='ConcatDataset_',
        datasets= \
                [dict(
                    type=dataset_type_scannet,
                    ann_file='scannet_infos_val.pkl',
                    data_prefix=data_prefix_scannet,
                    data_root=data_root_scannet,
                    metainfo=metainfo_scannet,
                    pipeline=test_pipeline_scannet,
                    ignore_index=max_class_scannet,
                    test_mode=True)]  
                    ))

test_dataloader = val_dataloader

load_from = 'work_dirs/tmp/oneformer3d_1xb4_scannet.pth'

test_evaluator = dict(type='IndoorMetric_', 
                      datasets=['scannet'],
                      datasets_classes=[classes_scannet])

val_evaluator = test_evaluator

optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='AdamW', lr=0.0001 * 2, weight_decay=0.05),
    clip_grad=dict(max_norm=10, norm_type=2))

param_scheduler = dict(type='PolyLR', begin=0, end=1024, power=0.9)

custom_hooks = [dict(type='EmptyCacheHook', after_iter=True)]
default_hooks = dict(
    checkpoint=dict(interval=1, max_keep_ckpts=16))

train_cfg = dict(
    type='EpochBasedTrainLoop',
    max_epochs=1024,
    dynamic_intervals=[(1, 16), (1024 - 16, 1)])
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
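As suggested above, the evaluator in this config still reports all 18 ScanNet classes, and the classes absent from training can be excluded when averaging. A minimal sketch of that step, assuming the per-class results are available as a class-name-to-value dict (how `IndoorMetric_` actually exposes them is not shown here); `mean_over_classes` is hypothetical and the numbers are toy values, not results.

```python
from typing import Dict, Iterable


def mean_over_classes(per_class: Dict[str, float], keep: Iterable[str]) -> float:
    """Average a per-class metric over the classes listed in `keep` only.

    Names in `keep` that have no entry in `per_class` are skipped, so check
    both name sets if the result looks off.
    """
    keep_set = set(keep)
    values = [value for name, value in per_class.items() if name in keep_set]
    if not values:
        raise ValueError("no class in `keep` has a per-class result")
    return sum(values) / len(values)


# Toy per-class AP values, for illustration only.
per_class_ap = {'chair': 0.75, 'table': 0.5, 'desk': 0.0, 'counter': 0.0}
missing = {'desk', 'counter'}  # classes absent from the training data
keep = [name for name in per_class_ap if name not in missing]
print(mean_over_classes(per_class_ap, keep))  # 0.625
```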

col14m, Jul 01 '25

Thanks for the config. Could you please clarify:

  1. It looks like I need to include scannet in the model config even though it is not in the training dataloader. Is that right?
  2. What if I just want to run inference on an arbitrary point cloud whose classes I don't know?
  3. Does this mean I need to retrain my model with scannet included in the model config?
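Whatever the answers to these questions, it can help to check how a model config's per-dataset settings line up with the datasets being evaluated. In the three-dataset config above, several fields hold one entry per dataset and appear to pair positionally with `decoder.datasets`; the sketch checks that pairing and flags evaluated datasets with no decoder entry, which is the setup that failed with the `ValueError` earlier in the thread. Which keys are per-dataset is inferred from that config, not from unidet3d documentation, and `check_model_cfg` plus the trimmed two-dataset dict are hypothetical.

```python
from typing import Any, Dict, List


def check_model_cfg(model: Dict[str, Any], eval_datasets: List[str]) -> List[str]:
    """Return human-readable problems found in a UniDet3D-style model dict.

    Assumption: each list below holds one entry per dataset, in the order of
    model['decoder']['datasets'] (inferred from the config in this thread).
    """
    datasets = list(model['decoder']['datasets'])
    n = len(datasets)
    per_dataset = {
        'bbox_by_mask': model['bbox_by_mask'],
        'target_by_distance': model['target_by_distance'],
        'use_superpoints': model['use_superpoints'],
        'fast_nms': model['fast_nms'],
        'decoder.datasets_classes': model['decoder']['datasets_classes'],
        'decoder.angles': model['decoder']['angles'],
        'criterion.datasets_weights': model['criterion']['datasets_weights'],
        'criterion.topk': model['criterion']['topk'],
        'test_cfg.iou_thr': model['test_cfg']['iou_thr'],
    }
    problems = [f"{key} has {len(value)} entries, expected {n}"
                for key, value in per_dataset.items() if len(value) != n]
    if list(model['criterion']['datasets']) != datasets:
        problems.append("criterion.datasets differs from decoder.datasets")
    # Testing on a dataset missing from decoder.datasets is the setup that
    # failed in this thread.
    problems += [f"evaluated dataset {name!r} not in decoder.datasets"
                 for name in eval_datasets if name not in datasets]
    return problems


# Hypothetical, trimmed two-dataset model dict (scannetpp + arkitscenes only).
model = dict(
    bbox_by_mask=[False, False],
    target_by_distance=[True, True],
    use_superpoints=[False, False],
    fast_nms=[True, None],
    decoder=dict(datasets=['scannetpp', 'arkitscenes'],
                 datasets_classes=[['table', 'door'], ['cabinet', 'sofa']],
                 angles=[False, True]),
    criterion=dict(datasets=['scannetpp', 'arkitscenes'],
                   datasets_weights=[1, 1], topk=[3, 3]),
    test_cfg=dict(iou_thr=[0.55, 0.55]))

for problem in check_model_cfg(model, eval_datasets=['scannet']):
    print(problem)  # evaluated dataset 'scannet' not in decoder.datasets
```

Applied to the three-dataset config above (for example `check_model_cfg(cfg.model, cfg.test_evaluator.datasets)` after loading it with mmengine's `Config.fromfile`), it should return an empty list.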

azztt, Jul 02 '25