ViTPose

KeyError: 'dataset_idx'

Open MaxTeselkin opened this issue 2 years ago • 2 comments

Hi! I am trying to run inference with the ViTPose+ base model using the following code:

# importing necessary libraries
import warnings
warnings.filterwarnings('ignore')
import torch
import torchvision
import cv2
from mmpose.apis import (inference_top_down_pose_model,
                         init_pose_model,
                         vis_pose_result,
                         process_mmdet_results)
from mmdet.apis import inference_detector, init_detector
from IPython.display import Image, display
import tempfile
import os

# define model configs and checkpoints
pose_config = '/kaggle/working/ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/vitPose+_base_coco+aic+mpii+ap10k+apt36k+wholebody_256x192_udp.py'
pose_checkpoint = '/kaggle/input/vitpose-plus-basic/vitpose-plus-b.pth'
det_config = '/kaggle/working/ViTPose/demo/mmdetection_cfg/faster_rcnn_r50_fpn_coco.py'
det_checkpoint = 'https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'

# initialize pose model
pose_model = init_pose_model(pose_config, pose_checkpoint, device='cpu')
# initialize detector
det_model = init_detector(det_config, det_checkpoint, device='cpu')

img = '/kaggle/working/ViTPose/tests/data/coco/000000196141.jpg'

# inference detection
mmdet_results = inference_detector(det_model, img)

# extract person (COCO_ID=1) bounding boxes from the detection results
person_results = process_mmdet_results(mmdet_results, cat_id=1)

# inference pose
pose_results, returned_outputs = inference_top_down_pose_model(pose_model,
                                                               img,
                                                               person_results,
                                                               bbox_thr=0.3,
                                                               format='xyxy',
                                                               dataset=pose_model.cfg.data.test.type)

# show pose estimation results
vis_result = vis_pose_result(pose_model,
                             img,
                             pose_results,
                             dataset=pose_model.cfg.data.test.type,
                             show=False)

with tempfile.TemporaryDirectory() as tmpdir:
    file_name = os.path.join(tmpdir, 'pose_results.png')
    cv2.imwrite(file_name, vis_result)
    display(Image(file_name))

But I get the following error:

KeyError                                  Traceback (most recent call last)
/tmp/ipykernel_27/4115304742.py in <module>
     24                                                                bbox_thr=0.3,
     25                                                                format='xyxy',
---> 26                                                                dataset=pose_model.cfg.data.test.type)
     27 
     28 # show pose estimation results

/kaggle/working/ViTPose/mmpose/apis/inference.py in inference_top_down_pose_model(model, img_or_path, person_results, bbox_thr, format, dataset, dataset_info, return_heatmap, outputs)
    404             dataset=dataset,
    405             dataset_info=dataset_info,
--> 406             return_heatmap=return_heatmap)
    407 
    408         if return_heatmap:

/kaggle/working/ViTPose/mmpose/apis/inference.py in _inference_single_pose_model(model, img_or_path, bboxes, dataset, dataset_info, return_heatmap)
    276             data['image_file'] = img_or_path
    277 
--> 278         data = test_pipeline(data)
    279         batch_data.append(data)
    280 

/kaggle/working/ViTPose/mmpose/datasets/pipelines/shared_transform.py in __call__(self, data)
     97         """
     98         for t in self.transforms:
---> 99             data = t(data)
    100             if data is None:
    101                 return None

/kaggle/working/ViTPose/mmpose/datasets/pipelines/shared_transform.py in __call__(self, results)
    166                 else:
    167                     key_src = key_tgt = key
--> 168                 meta[key_tgt] = results[key_src]
    169         if 'bbox_id' in results:
    170             meta['bbox_id'] = results['bbox_id']

KeyError: 'dataset_idx'

MaxTeselkin avatar Jan 22 '23 21:01 MaxTeselkin

Ran into the same issue. I'm guessing it has to do with how `dataset_idx` is populated from the configs. I don't know what the proper fix is, but I found a temporary workaround: replacing

img_sources = torch.from_numpy(np.array([ele['dataset_idx'] for ele in img_metas])).to(img.device)

with

img_sources = torch.from_numpy(np.array([0 for ele in img_metas])).to(img.device)

in the two places where mmpose uses it. Note that this hard-codes index 0 for every image. I had also modified the config to train only on COCO, so I'm not sure my workaround would work otherwise.

Serdnad avatar Jan 24 '23 09:01 Serdnad

> Ran into the same issue, I'm guessing it has to do with dataset_idx being populated in the configs. I don't know what the proper fix is, but I found a temporary workaround replacing
>
> img_sources = torch.from_numpy(np.array([ele['dataset_idx'] for ele in img_metas])).to(img.device)
>
> with
>
> img_sources = torch.from_numpy(np.array([0 for ele in img_metas])).to(img.device)
>
> in the 2 places it's used by mmpose. That said, I modified the config to only train on COCO, and I'm not sure my workaround would work otherwise.

Hello, which file did you modify? I ran into the same problem.

XiongFenghhh avatar Jan 08 '24 11:01 XiongFenghhh