mmpose
When I train with my own datasets, AP = 0
DONE (t=0.00s).
Average Precision (AP) @[ IoU=0.50:0.95 | area=   all | maxDets= 20 ] =  0.000
Average Precision (AP) @[ IoU=0.50      | area=   all | maxDets= 20 ] =  0.000
Average Precision (AP) @[ IoU=0.75      | area=   all | maxDets= 20 ] =  0.000
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets= 20 ] = -1.000
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets= 20 ] = -1.000
Average Recall    (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 20 ] =  0.000
Average Recall    (AR) @[ IoU=0.50      | area=   all | maxDets= 20 ] =  0.000
Average Recall    (AR) @[ IoU=0.75      | area=   all | maxDets= 20 ] =  0.000
Average Recall    (AR) @[ IoU=0.50:0.95 | area=medium | maxDets= 20 ] = -1.000
Average Recall    (AR) @[ IoU=0.50:0.95 | area= large | maxDets= 20 ] = -1.000
@GXYao7 Hi. Thanks for your feedback. The result seems abnormal. Could you please provide more information? For example, the config file you used, the command you ran, and any modifications you have made, so that we can locate the problem more quickly.
This is my dataset file: animal_lambpose_dataset.py. My dataset has 18 keypoints.
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
import warnings
from collections import OrderedDict, defaultdict

import json_tricks as json
import numpy as np
from mmcv import Config, deprecated_api_warning
from xtcocotools.cocoeval import COCOeval

from ....core.post_processing import oks_nms, soft_oks_nms
from ...builder import DATASETS
from ..base import Kpt2dSviewRgbImgTopDownDataset


@DATASETS.register_module()
class AnimalLambPoseDataset(Kpt2dSviewRgbImgTopDownDataset):
    """Animal-Pose dataset for animal pose estimation.
"Cross-domain Adaptation For Animal Pose Estimation" ICCV'2019
More details can be found in the `paper
<https://arxiv.org/abs/1908.05806>`__ .
    The dataset loads raw features and applies specified transforms
    to return a dict containing the image tensors and other information.
Animal-Pose keypoint indexes::
0: 'L_Eye',
1: 'R_Eye',
2: 'L_EarBase',
3: 'R_EarBase',
4: 'Nose',
5: 'Throat',
6: 'TailBase',
7: 'Withers',
8: 'L_F_Elbow',
9: 'R_F_Elbow',
10: 'L_B_Elbow',
11: 'R_B_Elbow',
12: 'L_F_Knee',
13: 'R_F_Knee',
14: 'L_B_Knee',
15: 'R_B_Knee',
16: 'L_F_Paw',
17: 'R_F_Paw',
18: 'L_B_Paw',
19: 'R_B_Paw'
Args:
ann_file (str): Path to the annotation file.
img_prefix (str): Path to a directory where images are held.
Default: None.
data_cfg (dict): config
pipeline (list[dict | callable]): A sequence of data transforms.
dataset_info (DatasetInfo): A class containing all dataset info.
test_mode (bool): Store True when building test or
validation dataset. Default: False.
"""
def __init__(self,
ann_file,
img_prefix,
data_cfg,
pipeline,
dataset_info=None,
test_mode=False):
if dataset_info is None:
warnings.warn(
'dataset_info is missing. '
'Check https://github.com/open-mmlab/mmpose/pull/663 '
'for details.', DeprecationWarning)
cfg = Config.fromfile('configs/_base_/datasets/lamb.py')
dataset_info = cfg._cfg_dict['dataset_info']
super().__init__(
ann_file,
img_prefix,
data_cfg,
pipeline,
dataset_info=dataset_info,
test_mode=test_mode)
self.use_gt_bbox = data_cfg['use_gt_bbox']
self.bbox_file = data_cfg['bbox_file']
self.det_bbox_thr = data_cfg.get('det_bbox_thr', 0.0)
self.use_nms = data_cfg.get('use_nms', True)
self.soft_nms = data_cfg['soft_nms']
self.nms_thr = data_cfg['nms_thr']
self.oks_thr = data_cfg['oks_thr']
self.vis_thr = data_cfg['vis_thr']
self.ann_info['use_different_joint_weights'] = False
self.db = self._get_db()
print(f'=> num_images: {self.num_images}')
print(f'=> load {len(self.db)} samples')
def _get_db(self):
"""Load dataset."""
assert self.use_gt_bbox
gt_db = self._load_coco_keypoint_annotations()
return gt_db
def _load_coco_keypoint_annotations(self):
"""Ground truth bbox and keypoints."""
gt_db = []
for img_id in self.img_ids:
gt_db.extend(self._load_coco_keypoint_annotation_kernel(img_id))
return gt_db
def _load_coco_keypoint_annotation_kernel(self, img_id):
"""load annotation from COCOAPI.
Note:
bbox:[x1, y1, w, h]
Args:
img_id: coco image id
Returns:
dict: db entry
"""
img_ann = self.coco.loadImgs(img_id)[0]
width = img_ann['width']
height = img_ann['height']
num_joints = self.ann_info['num_joints']
ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=False)
objs = self.coco.loadAnns(ann_ids)
# sanitize bboxes
valid_objs = []
for obj in objs:
if 'bbox' not in obj:
continue
x, y, w, h = obj['bbox']
x1 = max(0, x)
y1 = max(0, y)
x2 = min(width - 1, x1 + max(0, w))
y2 = min(height - 1, y1 + max(0, h))
if ('area' not in obj or obj['area'] > 0) and x2 > x1 and y2 > y1:
obj['clean_bbox'] = [x1, y1, x2 - x1, y2 - y1]
valid_objs.append(obj)
objs = valid_objs
bbox_id = 0
rec = []
for obj in objs:
if 'keypoints' not in obj:
continue
if max(obj['keypoints']) == 0:
continue
if 'num_keypoints' in obj and obj['num_keypoints'] == 0:
continue
joints_3d = np.zeros((num_joints, 3), dtype=np.float32)
joints_3d_visible = np.zeros((num_joints, 3), dtype=np.float32)
keypoints = np.array(obj['keypoints']).reshape(-1, 3)
joints_3d[:, :2] = keypoints[:, :2]
joints_3d_visible[:, :2] = np.minimum(1, keypoints[:, 2:3])
image_file = osp.join(self.img_prefix, self.id2name[img_id])
rec.append({
'image_file': image_file,
'bbox': obj['clean_bbox'][:4],
'rotation': 0,
'joints_3d': joints_3d,
'joints_3d_visible': joints_3d_visible,
'dataset': self.dataset_name,
'bbox_score': 1,
'bbox_id': bbox_id
})
bbox_id = bbox_id + 1
return rec
@deprecated_api_warning(name_dict=dict(outputs='results'))
def evaluate(self, results, res_folder=None, metric='mAP', **kwargs):
"""Evaluate coco keypoint results. The pose prediction results will be
saved in ``${res_folder}/result_keypoints.json``.
Note:
- batch_size: N
- num_keypoints: K
- heatmap height: H
- heatmap width: W
Args:
results (list[dict]): Testing results containing the following
items:
- preds (np.ndarray[N,K,3]): The first two dimensions are \
coordinates, score is the third dimension of the array.
- boxes (np.ndarray[N,6]): [center[0], center[1], scale[0], \
scale[1],area, score]
- image_paths (list[str]): For example, ['data/coco/val2017\
/000000393226.jpg']
- heatmap (np.ndarray[N, K, H, W]): model output heatmap
- bbox_id (list(int)).
res_folder (str, optional): The folder to save the testing
results. If not specified, a temp folder will be created.
Default: None.
metric (str | list[str]): Metric to be performed. Defaults: 'mAP'.
Returns:
dict: Evaluation results for evaluation metric.
"""
metrics = metric if isinstance(metric, list) else [metric]
allowed_metrics = ['mAP']
for metric in metrics:
if metric not in allowed_metrics:
raise KeyError(f'metric {metric} is not supported')
if res_folder is not None:
tmp_folder = None
res_file = osp.join(res_folder, 'result_keypoints.json')
else:
tmp_folder = tempfile.TemporaryDirectory()
res_file = osp.join(tmp_folder.name, 'result_keypoints.json')
kpts = defaultdict(list)
for result in results:
preds = result['preds']
boxes = result['boxes']
image_paths = result['image_paths']
bbox_ids = result['bbox_ids']
batch_size = len(image_paths)
for i in range(batch_size):
image_id = self.name2id[image_paths[i][len(self.img_prefix):]]
kpts[image_id].append({
'keypoints': preds[i],
'center': boxes[i][0:2],
'scale': boxes[i][2:4],
'area': boxes[i][4],
'score': boxes[i][5],
'image_id': image_id,
'bbox_id': bbox_ids[i]
})
kpts = self._sort_and_unique_bboxes(kpts)
# rescoring and oks nms
num_joints = self.ann_info['num_joints']
vis_thr = self.vis_thr
oks_thr = self.oks_thr
valid_kpts = []
for image_id in kpts.keys():
img_kpts = kpts[image_id]
for n_p in img_kpts:
box_score = n_p['score']
kpt_score = 0
valid_num = 0
for n_jt in range(0, num_joints):
t_s = n_p['keypoints'][n_jt][2]
if t_s > vis_thr:
kpt_score = kpt_score + t_s
valid_num = valid_num + 1
if valid_num != 0:
kpt_score = kpt_score / valid_num
# rescoring
n_p['score'] = kpt_score * box_score
if self.use_nms:
nms = soft_oks_nms if self.soft_nms else oks_nms
keep = nms(list(img_kpts), oks_thr, sigmas=self.sigmas)
valid_kpts.append([img_kpts[_keep] for _keep in keep])
else:
valid_kpts.append(img_kpts)
self._write_coco_keypoint_results(valid_kpts, res_file)
# do evaluation only if the ground truth keypoint annotations exist
if 'annotations' in self.coco.dataset:
info_str = self._do_python_keypoint_eval(res_file)
name_value = OrderedDict(info_str)
if tmp_folder is not None:
tmp_folder.cleanup()
else:
            warnings.warn(f'Due to the absence of ground truth keypoint '
                          f'annotations, the quantitative evaluation can not '
                          f'be conducted. The prediction results have been '
                          f'saved at: {osp.abspath(res_file)}')
name_value = {}
return name_value
def _write_coco_keypoint_results(self, keypoints, res_file):
"""Write results into a json file."""
data_pack = [{
'cat_id': self._class_to_coco_ind[cls],
'cls_ind': cls_ind,
'cls': cls,
'ann_type': 'keypoints',
'keypoints': keypoints
} for cls_ind, cls in enumerate(self.classes)
if not cls == '__background__']
results = self._coco_keypoint_results_one_category_kernel(data_pack[0])
with open(res_file, 'w') as f:
json.dump(results, f, sort_keys=True, indent=4)
def _coco_keypoint_results_one_category_kernel(self, data_pack):
"""Get coco keypoint results."""
cat_id = data_pack['cat_id']
keypoints = data_pack['keypoints']
cat_results = []
for img_kpts in keypoints:
if len(img_kpts) == 0:
continue
_key_points = np.array(
[img_kpt['keypoints'] for img_kpt in img_kpts])
key_points = _key_points.reshape(-1,
self.ann_info['num_joints'] * 3)
result = [{
'image_id': img_kpt['image_id'],
'category_id': cat_id,
'keypoints': key_point.tolist(),
'score': float(img_kpt['score']),
'center': img_kpt['center'].tolist(),
'scale': img_kpt['scale'].tolist()
} for img_kpt, key_point in zip(img_kpts, key_points)]
cat_results.extend(result)
return cat_results
def _do_python_keypoint_eval(self, res_file):
"""Keypoint evaluation using COCOAPI."""
coco_det = self.coco.loadRes(res_file)
coco_eval = COCOeval(self.coco, coco_det, 'keypoints', self.sigmas)
coco_eval.params.useSegm = None
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()
stats_names = [
'AP', 'AP .5', 'AP .75', 'AP (M)', 'AP (L)', 'AR', 'AR .5',
'AR .75', 'AR (M)', 'AR (L)'
]
info_str = list(zip(stats_names, coco_eval.stats))
return info_str
def _sort_and_unique_bboxes(self, kpts, key='bbox_id'):
"""sort kpts and remove the repeated ones."""
for img_id, persons in kpts.items():
num = len(persons)
kpts[img_id] = sorted(kpts[img_id], key=lambda x: x[key])
for i in range(num - 1, 0, -1):
if kpts[img_id][i][key] == kpts[img_id][i - 1][key]:
del kpts[img_id][i]
return kpts
This is the network config file: litehrnet_18_coco_256x192.py
_base_ = [
    '../../../../_base_/default_runtime.py',
    '../../../../_base_/datasets/lamb.py'
]
evaluation = dict(interval=1, metric='mAP', save_best='AP')
optimizer = dict(
    type='Adam',
    lr=5e-4,
)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.001,
    step=[170, 200])
total_epochs = 210
channel_cfg = dict(
    num_output_channels=18,
    dataset_joints=18,
    dataset_channel=[
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17],
    ],
    inference_channel=[
        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17
    ])
# model settings
model = dict(
    type='TopDown',
    pretrained=None,
    backbone=dict(
        type='LiteHRNet',
        in_channels=3,
        extra=dict(
            stem=dict(stem_channels=32, out_channels=32, expand_ratio=1),
            num_stages=3,
            stages_spec=dict(
                num_modules=(2, 4, 2),
                num_branches=(2, 3, 4),
                num_blocks=(2, 2, 2),
                module_type=('LITE', 'LITE', 'LITE'),
                with_fuse=(True, True, True),
                reduce_ratios=(8, 8, 8),
                num_channels=(
                    (40, 80),
                    (40, 80, 160),
                    (40, 80, 160, 320),
                )),
            with_head=True,
        )),
    keypoint_head=dict(
        type='TopdownHeatmapSimpleHead',
        in_channels=40,
        out_channels=channel_cfg['num_output_channels'],
        num_deconv_layers=0,
        extra=dict(final_conv_kernel=1, ),
        loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)),
    train_cfg=dict(),
    test_cfg=dict(
        flip_test=True,
        post_process='default',
        shift_heatmap=True,
        modulate_kernel=11))
data_cfg = dict(
    image_size=[256, 256],
    heatmap_size=[64, 64],
    num_output_channels=channel_cfg['num_output_channels'],
    num_joints=channel_cfg['dataset_joints'],
    dataset_channel=channel_cfg['dataset_channel'],
    inference_channel=channel_cfg['inference_channel'],
    use_gt_bbox=True,
    soft_nms=False,
    nms_thr=1.0,
    oks_thr=0.9,
    vis_thr=0.2,
    det_bbox_thr=0.0,
    bbox_file='',
)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='TopDownGetBboxCenterScale', padding=1.25),
    dict(type='TopDownRandomShiftBboxCenter', shift_factor=0.16, prob=0.3),
    dict(type='TopDownRandomFlip', flip_prob=0.5),
    dict(
        type='TopDownHalfBodyTransform',
        num_joints_half_body=8,
        prob_half_body=0.3),
    dict(
        type='TopDownGetRandomScaleRotation', rot_factor=30,
        scale_factor=0.25),
    dict(type='TopDownAffine'),
    dict(type='ToTensor'),
    dict(
        type='NormalizeTensor',
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]),
    dict(type='TopDownGenerateTarget', sigma=2),
    dict(
        type='Collect',
        keys=['img', 'target', 'target_weight'],
        meta_keys=[
            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
            'rotation', 'bbox_score', 'flip_pairs'
        ]),
]
val_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='TopDownGetBboxCenterScale', padding=1.25),
    dict(type='TopDownAffine'),
    dict(type='ToTensor'),
    dict(
        type='NormalizeTensor',
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]),
    dict(
        type='Collect',
        keys=['img'],
        meta_keys=[
            'image_file', 'center', 'scale', 'rotation', 'bbox_score',
            'flip_pairs'
        ]),
]
test_pipeline = val_pipeline
data_root = 'datacoco'
data = dict(
    samples_per_gpu=4,
    workers_per_gpu=2,
    train=dict(
        type='AnimalLambPoseDataset',
        ann_file=f'{data_root}/annotations/keypoints_train.json',
        img_prefix=f'{data_root}/train/',
        data_cfg=data_cfg,
        pipeline=train_pipeline,
        dataset_info={{_base_.dataset_info}}),
    val=dict(
        type='AnimalLambPoseDataset',
        ann_file=f'{data_root}/annotations/keypoints_val.json',
        img_prefix=f'{data_root}/val/',
        data_cfg=data_cfg,
        pipeline=val_pipeline,
        dataset_info={{_base_.dataset_info}}),
    test=dict(
        type='AnimalLambPoseDataset',
        ann_file=f'{data_root}/annotations/keypoints_val.json',
        img_prefix=f'{data_root}/val/',
        data_cfg=data_cfg,
        pipeline=test_pipeline,
        dataset_info={{_base_.dataset_info}}),
)
Hi. I noticed that in the config file you set the evaluation interval to 1, which means the model is evaluated after every epoch.
Are the above results from the first few epochs? If so, the model may simply not perform well yet, and a zero AP is possible. You may wait a few more epochs and see whether the results improve.
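(For reference, the interval is the same evaluation dict already shown in the config; the value 10 below is only an illustrative choice, not something prescribed in this thread.)

# evaluate every 10 epochs instead of every epoch
evaluation = dict(interval=10, metric='mAP', save_best='AP')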
Okay, I'll turn it up a little and see the results, thank you~
When I ran 210 epochs, the AP was still 0
@GXYao7 Sorry for the late reply. Have you checked the losses during the training? Are they normal or abnormal?
I found that from the beginning to the end of training, my loss stayed at a low value without any ups and downs, and the data_time is basically 0, which may be why the AP is 0. Do you have any suggestions? Could this be a problem with my dataset?
the data_time is basically 0
This is normal; it indicates that the time spent loading data samples is relatively small.
my loss stayed at a low value without any ups and downs
The loss being a low value is normal, but it is abnormal that it did not change at all. You may check whether the data is loaded correctly. For example, double-check your data_root:
data_root = 'datacoco'
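One quick way to rule out a loading problem is to check that every image referenced in the annotation file actually exists under the image prefix. A minimal sketch assuming a standard COCO-style JSON, using the paths from the config above (adjust them to your actual layout):

import json
import os.path as osp

data_root = 'datacoco'
ann_file = f'{data_root}/annotations/keypoints_train.json'
img_prefix = f'{data_root}/train/'

with open(ann_file) as f:
    coco = json.load(f)

# list every image whose file is not found under img_prefix
missing = [img['file_name'] for img in coco['images']
           if not osp.exists(osp.join(img_prefix, img['file_name']))]
print(f"{len(missing)} of {len(coco['images'])} images are missing")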
Okay, I'll check that, thank you~
Hi, I've run into the same problem. May I ask how you solved it?
I'm sorry, I haven't found where the problem is yet either.
This is probably a problem with the area values in your validation set's JSON. When computing AP or AR, the boxes are first filtered by area and maxDets, and only then are AP/AR actually computed. Bad area values in your JSON can produce exactly the situation you are seeing.
Okay, thank you! I'll go check my JSON file.
Thank you! Following your advice, the problem is finally solved.
How large does area need to be to be normal? What was wrong with your area values?
The problem probably comes from converting COCO-style keypoint labels into the mmpose data format with a conversion script copied from the internet. The scripts I have seen simply set area to 0 during conversion, so when mmpose computes the metrics it filters out all the ground-truth boxes by box area. In a correct conversion, a box's area should be the box's w*h.
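A minimal sketch of that fix, assuming a COCO-style annotation JSON whose annotations carry an [x, y, w, h] bbox (as in the docstring of the dataset class above); the file names here are placeholders:

import json

with open('keypoints_val.json') as f:
    coco = json.load(f)

for ann in coco['annotations']:
    x, y, w, h = ann['bbox']
    ann['area'] = w * h  # instead of the hard-coded 0 from the copied script

with open('keypoints_val_fixed.json', 'w') as f:
    json.dump(coco, f)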
Then how do I fix it? The generated area values already look quite large to me.
Just generate whatever area the label actually has; an area being "too large" is not a problem.
That's strange then; my AP is still 0. Could it be related to the values of sigmas? When I debug, I can see that dt in cocoeval becomes 0.
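The sigmas length is easy to rule out: COCOeval computes OKS with one sigma per keypoint, so with 18 keypoints the dataset info should define 18 sigmas, and a length mismatch would likely break matching. A small check, reusing the same config loading that the dataset class above already does (this assumes sigmas is defined under dataset_info in lamb.py):

from mmcv import Config

cfg = Config.fromfile('configs/_base_/datasets/lamb.py')
sigmas = cfg._cfg_dict['dataset_info']['sigmas']
assert len(sigmas) == 18, f'expected 18 sigmas for 18 keypoints, got {len(sigmas)}'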
Was a solution found for this issue?
Sharing in case it helps anyone: I found this link and learned that my image_id and id were mixed up, so I was pointing to the wrong annotation, which is why I got AP = 0 everywhere. After fixing this, training converges!
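For anyone hitting the same thing, a minimal consistency check between images and annotations in a COCO-style JSON (the file name is a placeholder):

import json

with open('keypoints_val.json') as f:
    coco = json.load(f)

image_ids = {img['id'] for img in coco['images']}
ann_ids = [ann['id'] for ann in coco['annotations']]

# every annotation must reference an existing image id, and annotation ids must be unique
orphans = [a['id'] for a in coco['annotations'] if a['image_id'] not in image_ids]
print(f'{len(orphans)} annotations reference a missing image_id')
print(f'duplicate annotation ids: {len(ann_ids) - len(set(ann_ids))}')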