deep-high-resolution-net.pytorch
About the multi-scale MPII test result: I only got 91.6% instead of the 92.3% reported in the paper.
I have implemented multi-scale testing and verified that the MPII validation set accuracy is 90.75%. I then went on to apply it to the test set, and the accuracy I got is only:
| Head | Shoulder | Elbow | Wrist | Hip | Knee | Ankle | UBody | Total |
|------|----------|-------|-------|------|------|-------|-------|-------|
| 98.3 | 96.5 | 92.4 | 88.3 | 90.6 | 88.3 | 84.1 | 92.4 | 91.6 |

AUC: 61.6
This is below the 92.3% reported in the paper. Below is the code I used for the multi-scale testing:
```python
import multiprocessing
import os
import time

import cv2
import numpy as np
import torch
import torchvision.transforms as transforms

# AverageMeter, accuracy, flip_back, get_affine_transform, get_final_preds,
# save_debug_images, _print_name_value and logger come from this repo's lib/.


def read_scaled_image(image_file, s, center, scale, image_size, COLOR_RGB, DATA_FORMAT, image_transform):
    """Re-read one image from disk and warp it to the given test scale s."""
    if DATA_FORMAT == 'zip':
        from utils import zipreader
        data_numpy = zipreader.imread(image_file, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
    else:
        data_numpy = cv2.imread(image_file, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
    if COLOR_RGB:
        data_numpy = cv2.cvtColor(data_numpy, cv2.COLOR_BGR2RGB)
    trans = get_affine_transform(center, s * scale, 0, image_size)
    image_warped = cv2.warpAffine(data_numpy, trans, tuple(image_size), flags=cv2.INTER_LINEAR)
    return image_transform(image_warped)


def validate(config, val_loader, val_dataset, model, criterion, output_dir,
             tb_log_dir, writer_dict=None, test_scale=None):
    batch_time = AverageMeter()
    losses = AverageMeter()
    acc = AverageMeter()

    # switch to evaluate mode
    model.eval()

    num_samples = len(val_dataset)
    all_preds = np.zeros((num_samples, config.MODEL.NUM_JOINTS, 3), dtype=np.float32)
    all_boxes = np.zeros((num_samples, 6))
    image_path = []
    filenames = []
    imgnums = []
    idx = 0
    PRINT_FREQ = min(config.PRINT_FREQ // 10, 5)
    image_size = np.array(config.MODEL.IMAGE_SIZE)
    final_test_scale = test_scale if test_scale is not None else config.TEST.SCALE_FACTOR

    with torch.no_grad():
        end = time.time()

        def scale_back_output(output_hm, s, output_size):
            """Undo test scale s: pad (s < 1) or crop (s > 1) the heatmap, resize it
            back to output_size, and normalize it so the scales can be averaged."""
            hm_size = [output_hm.size(3), output_hm.size(2)]  # [W, H]
            if s != 1.0:
                hm_w_margin = int(abs(1.0 - s) * hm_size[0] / 2.0)
                hm_h_margin = int(abs(1.0 - s) * hm_size[1] / 2.0)
                if s < 1.0:
                    hm_padding = torch.nn.ZeroPad2d((hm_w_margin, hm_w_margin, hm_h_margin, hm_h_margin))
                    resized_hm = hm_padding(output_hm)
                else:
                    # dim 2 is height, dim 3 is width (my original post had the two
                    # sizes swapped, which only works for square heatmaps)
                    resized_hm = output_hm[:, :,
                                           hm_h_margin:hm_size[1] - hm_h_margin,
                                           hm_w_margin:hm_size[0] - hm_w_margin]
                resized_hm = torch.nn.functional.interpolate(
                    resized_hm,
                    size=(output_size[1], output_size[0]),  # (H, W)
                    mode='bilinear',
                    align_corners=False
                )
            else:
                resized_hm = output_hm
                if hm_size[0] != output_size[0] or hm_size[1] != output_size[1]:
                    resized_hm = torch.nn.functional.interpolate(
                        resized_hm,
                        size=(output_size[1], output_size[0]),  # (H, W)
                        mode='bilinear',
                        align_corners=False
                    )
            # normalize each heatmap to unit sum so different scales are comparable
            resized_hm = resized_hm / (torch.sum(resized_hm, dim=[2, 3], keepdim=True) + 1e-9)
            return resized_hm

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        image_transform = transforms.Compose([transforms.ToTensor(), normalize])
        thread_pool = multiprocessing.Pool(multiprocessing.cpu_count())

        start_time = time.time()
        for i, (input, target, target_weight, meta) in enumerate(val_loader):
            target = target.cuda(non_blocking=True)
            target_weight = target_weight.cuda(non_blocking=True)

            outputs = []
            for s in sorted(final_test_scale, reverse=True):
                print("Test Scale", s)
                if s != 1.0:
                    # re-read and re-warp the batch images at this scale in parallel
                    image_files = meta["image"]
                    centers = meta["center"].numpy()
                    scales = meta["scale"].numpy()
                    images_resized = thread_pool.starmap(
                        read_scaled_image,
                        [(image_file, s, center, scale, image_size,
                          config.DATASET.COLOR_RGB, config.DATASET.DATA_FORMAT, image_transform)
                         for (image_file, center, scale) in zip(image_files, centers, scales)])
                    images_resized = torch.stack(images_resized, dim=0)
                else:
                    images_resized = input

                model_outputs = model(images_resized)
                hm_size = [model_outputs.size(3), model_outputs.size(2)]

                if config.TEST.FLIP_TEST:
                    print("Test Flip")
                    input_flipped = images_resized.flip(3)
                    output_flipped = model(input_flipped)
                    if isinstance(output_flipped, list):
                        output_flipped = output_flipped[-1]
                    output_flipped = flip_back(output_flipped.cpu().numpy(), val_dataset.flip_pairs)
                    output_flipped = torch.from_numpy(output_flipped.copy()).cuda()
                    # feature is not aligned, shift flipped heatmap for higher accuracy
                    if config.TEST.SHIFT_HEATMAP:
                        output_flipped[:, :, :, 1:] = output_flipped.clone()[:, :, :, 0:-1]
                    model_outputs = 0.5 * (model_outputs + output_flipped)

                outputs.append(scale_back_output(model_outputs, s, hm_size))

            target_size = [target.size(3), target.size(2)]
            if hm_size[0] != target_size[0] or hm_size[1] != target_size[1]:
                target = torch.nn.functional.interpolate(
                    target,
                    size=(hm_size[1], hm_size[0]),  # (H, W)
                    mode='bilinear',
                    align_corners=False
                )
                target = torch.nn.functional.normalize(target, dim=[2, 3], p=2)

            for indv_output in outputs:
                _, avg_acc, _, _ = accuracy(indv_output.cpu().numpy(), target.cpu().numpy())
                print("Indv Accuracy", avg_acc)

            # average the (normalized) heatmaps over all test scales
            output = torch.stack(outputs, dim=0).mean(dim=0)
            loss = criterion(output, target, target_weight)

            num_images = input.size(0)
            # measure accuracy and record loss
            losses.update(loss.item(), num_images)
            _, avg_acc, cnt, pred = accuracy(output.cpu().numpy(), target.cpu().numpy())
            print("Avg Accuracy", avg_acc)
            acc.update(avg_acc, cnt)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            c = meta['center'].numpy()
            s = meta['scale'].numpy()
            score = meta['score'].numpy()

            preds, maxvals = get_final_preds(config, output.clone().cpu().numpy(), c, s)

            all_preds[idx:idx + num_images, :, 0:2] = preds[:, :, 0:2]
            all_preds[idx:idx + num_images, :, 2:3] = maxvals
            # double check this all_boxes part
            all_boxes[idx:idx + num_images, 0:2] = c[:, 0:2]
            all_boxes[idx:idx + num_images, 2:4] = s[:, 0:2]
            all_boxes[idx:idx + num_images, 4] = np.prod(s * 200, 1)
            all_boxes[idx:idx + num_images, 5] = score
            image_path.extend(meta['image'])

            idx += num_images

            if i % PRINT_FREQ == 0:
                msg = 'Test: [{0}/{1}]\t' \
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t' \
                      'Accuracy {acc.val:.3f} ({acc.avg:.3f})'.format(
                          i, len(val_loader), batch_time=batch_time, loss=losses, acc=acc)
                logger.info(msg)

                prefix = '{}_{}'.format(os.path.join(output_dir, 'val'), i)
                save_debug_images(config, input, meta, target, pred * 4, output, prefix)

        total_duration = time.time() - start_time
        logger.info("Total test time: {:.1f}".format(total_duration))

        name_values, perf_indicator = val_dataset.evaluate(
            config, all_preds, output_dir, all_boxes, image_path, filenames, imgnums)

        model_name = config.MODEL.NAME
        if isinstance(name_values, list):
            for name_value in name_values:
                _print_name_value(name_value, model_name)
        else:
            _print_name_value(name_values, model_name)

        if writer_dict:
            writer = writer_dict['writer']
            global_steps = writer_dict['valid_global_steps']
            writer.add_scalar('valid_loss', losses.avg, global_steps)
            writer.add_scalar('valid_acc', acc.avg, global_steps)
            if isinstance(name_values, list):
                for name_value in name_values:
                    writer.add_scalars('valid', dict(name_value), global_steps)
            else:
                writer.add_scalars('valid', dict(name_values), global_steps)
            writer_dict['valid_global_steps'] = global_steps + 1

    return perf_indicator
```
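For reference, the function can be invoked from the test script like this (the config/loader variable names follow this repo's tools/test.py and are otherwise assumptions):

```python
# Multi-scale + flip testing; the scales are passed explicitly via test_scale.
perf_indicator = validate(
    config, valid_loader, valid_dataset, model, criterion,
    final_output_dir, tb_log_dir,
    writer_dict=None,
    test_scale=[1.0, 1.3, 1.2, 1.1, 0.9, 0.8],
)
```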
How did you use the test set to get the results?
@sunke123 First of all, MPII released the test set ground truth, so you can download it from http://human-pose.mpi-inf.mpg.de/#download. Also, if you want to use the pred.mat generated by this HRNet repo, you need to modify evaluatePCKh.m as below:
```matlab
% Evaluate performance by comparing predictions to ground truth annotations.

%%% OPTIONS %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% IDs of prediction sets to include in results
PRED_IDS = [1, 2, 3, 4, 5, 6];
% Subset of the data that the predictions correspond to ('val' or 'train')
plotcurve = false;
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

addpath('eval')

fprintf('# MPII single-person pose evaluation script\n')

range = 0:0.01:0.5;

tableDir = './latex'; if (~exist(tableDir,'dir')), mkdir(tableDir); end
plotsDir = './plots'; if (~exist(plotsDir,'dir')), mkdir(plotsDir); end
tableTex = cell(length(PRED_IDS)+1, 1);

% load ground truth
p = getExpParams(-1);
load([p.gtDir '/annolist_dataset_v12'], 'annolist');
load([p.gtDir '/mpii_human_pose_v1_u12'], 'RELEASE');
annolist_test = annolist(RELEASE.img_train == 0);

% evaluate on the "single person" subset only
single_person_test = RELEASE.single_person(RELEASE.img_train == 0);

% convert to annotation list with a single pose per entry
[annolist_test_flat, single_person_test_flat] = flatten_annolist(annolist_test, single_person_test);

% represent ground truth as a matrix 2x14xN_images
gt = annolist2matrix(annolist_test_flat(single_person_test_flat == 1));

% compute head size
headSize = getHeadSizeAll(annolist_test_flat(single_person_test_flat == 1));

pckAll = zeros(length(range), 16, length(PRED_IDS));

for i = 1:length(PRED_IDS)
    % load predictions
    p = getExpParams(PRED_IDS(i));
    try
        load(p.predFilename, 'preds');
    catch
        preds = h5read(p.predFilename, '/preds');
    end
    if size(preds, 1) == 2
        preds = permute(preds, [3, 2, 1]);
    end

    % Check that there are the same number of predictions and ground truth
    % annotations. If this assertion fails, a likely cause is a mismatch in
    % subsets (e.g. predictions are for the training set but ground truth
    % annotations are for the validation set).
    fprintf('%d\n', length(preds))
    fprintf('%d\n', length(gt))
    assert(length(preds) == length(gt));

    pred_flat = annolist_test_flat(single_person_test_flat == 1);
    for idx = 1:length(preds)
        for pidx = 1:length(pred_flat(idx).annorect.annopoints.point)
            joint = pred_flat(idx).annorect.annopoints.point(pidx).id + 1;
            xy = preds(idx, joint, :);
            pred_flat(idx).annorect.annopoints.point(pidx).x = xy(1);
            pred_flat(idx).annorect.annopoints.point(pidx).y = xy(2);
        end
    end
    pred = annolist2matrix(pred_flat);

    % only gt is allowed to have NaN
    pred(isnan(pred)) = inf;

    % compute distance to ground truth joints
    dist = getDistPCKh(pred, gt, headSize);

    % compute PCKh
    pck = computePCK(dist, range);

    % plot results
    [row, header] = genTablePCK(pck(end,:), p.name);
    tableTex{1} = header;
    tableTex{i+1} = row;

    pckAll(:,:,i) = pck;

    auc = area_under_curve(scale01(range), pck(:,end));
    fprintf('%s, AUC: %1.1f\n', p.name, auc);
end

% Save results
fid = fopen([tableDir '/pckh.tex'], 'wt'); assert(fid ~= -1);
for i = 1:length(tableTex), fprintf(fid, '%s\n', tableTex{i}); end; fclose(fid);

% plot curves
bSave = true;
if (plotcurve)
    plotCurveNew(squeeze(pckAll(:,end,:)), range, PRED_IDS, 'PCKh total, MPII', [plotsDir '/pckh-total-mpii'], bSave, range(1:5:end));
    plotCurveNew(squeeze(mean(pckAll(:,[1 6],:),2)), range, PRED_IDS, 'PCKh ankle, MPII', [plotsDir '/pckh-ankle-mpii'], bSave, range(1:5:end));
    plotCurveNew(squeeze(mean(pckAll(:,[2 5],:),2)), range, PRED_IDS, 'PCKh knee, MPII', [plotsDir '/pckh-knee-mpii'], bSave, range(1:5:end));
    plotCurveNew(squeeze(mean(pckAll(:,[3 4],:),2)), range, PRED_IDS, 'PCKh hip, MPII', [plotsDir '/pckh-hip-mpii'], bSave, range(1:5:end));
    plotCurveNew(squeeze(mean(pckAll(:,[7 12],:),2)), range, PRED_IDS, 'PCKh wrist, MPII', [plotsDir '/pckh-wrist-mpii'], bSave, range(1:5:end));
    plotCurveNew(squeeze(mean(pckAll(:,[8 11],:),2)), range, PRED_IDS, 'PCKh elbow, MPII', [plotsDir '/pckh-elbow-mpii'], bSave, range(1:5:end));
    plotCurveNew(squeeze(mean(pckAll(:,[9 10],:),2)), range, PRED_IDS, 'PCKh shoulder, MPII', [plotsDir '/pckh-shoulder-mpii'], bSave, range(1:5:end));
    plotCurveNew(squeeze(mean(pckAll(:,[13 14],:),2)), range, PRED_IDS, 'PCKh head, MPII', [plotsDir '/pckh-head-mpii'], bSave, range(1:5:end));
end

display('Done.')
```
Thank you very much for your reply and sharing. I had been using evaluateAP.m previously and tried many times but failed every time; the reason is that the number of ground-truth entries is not consistent with the number of predictions (using test.json / 7247). With your script I can now successfully get the test set result of 91.8%:

| | Head | Shoulder | Elbow | Wrist | Hip | Knee | Ankle | UBody | Total |
|---|------|----------|-------|-------|------|------|-------|-------|-------|
| hrnet-test | 98.5 | 96.6 | 92.3 | 88.5 | 90.8 | 88.9 | 84.5 | 92.5 | 91.8 |

AUC: 63.0
1. Could you please tell me how to set "test_scale"? Is test_scale=[1,2,3,4,5,6] right? I get:
RuntimeError: invalid argument 2: input and output sizes should be greater than 0, but got input (H: 0, W: 0) output (H: 64, W: 64).
2. I find that model_outputs = model(images_resized) is a list. Is model_outputs = model(images_resized)[-1] right?
3. Thank you for your reply!!
Or is it test_scale=[0.5, 0.6, 0.7, 0.8, 0.9]? What are your settings?
@onepiece666
- The test_scale I used is [1.0, 1.3, 1.2, 1.1, 0.9, 0.8], which should be the scales mentioned in the paper.
- If you are using the original source code of HRNet, the output should not be a list, as there is only a single set of prediction heatmaps. It will be a list if you are doing intermediate supervision (multiple prediction heatmaps/loss functions); see the sketch after this list.
- You are welcome!
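As a minimal sketch (reusing the names from my code above), the forward pass can be made robust to both cases:

```python
# If the model returns intermediate predictions (a list of heatmap tensors),
# keep only the final set; otherwise use the single output tensor as-is.
model_outputs = model(images_resized)
if isinstance(model_outputs, list):
    model_outputs = model_outputs[-1]
```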
I ran the code, but I get the error Reference to non-existent field 'predFilename'. How should I run the code? I'm not familiar with MATLAB. :sweat:
@Sunstin I ran the code but I cannot get the 91.8 result. Could you tell me in more detail how you did the testing?
@touhou-ayaya Have you downloaded the evaluation script from MPII (http://human-pose.mpi-inf.mpg.de/results/mpii_human_pose/evalMPII.zip)? After downloading and unzipping it, copy my updated evaluation script above, and you can directly evaluate MPII accuracy from the pred.mat generated by this repo.
What accuracy did you get? Are you using the pre-trained model?
I used their provided pre-trained model and got 91.6.
@MaxChu719 I didn't change the original code, and I use the pre-trained model.
@Sunstin That means your result is using only the flip test. You can use the multi-scale code I posted above to run the multi-scale test, which should give you 91.8%.
@MaxChu719 I'll try. Thank you very much for your advice.
@MaxChu719
Thanks for your reply, and I am sorry for my delay in replying :worried:. I have tried to run the code following your answer.
I downloaded and unzipped evalMPII.zip, which gives a directory 'eval'. Then I created a directory ground_truth inside eval and put the ground-truth files into it (including annolist_dataset_v12.mat, groups_v12.mat, mpii_human_pose_v1_u12.mat, test.h5, train.h5 and valid.h5); the prediction result (pred.mat) is also in eval. But when I run evaluatePCKh.m, I get the error Reference to non-existent field 'predFilename'. I tried creating a directory preds, putting the prediction result into it and moving preds into eval, but got the same error. If you don't mind, could you show the directory tree of your evaluation setup?
(My English is bad. :worried:)
I think the error means that there is no definition of predFilename; you can check the file getExpParams.m and give it a definition according to your pred.mat path.
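For example, a minimal getExpParams.m might look like this (the paths and the experiment name here are assumptions; adapt them to where your pred.mat and ground-truth files actually live):

```matlab
function p = getExpParams(predidx)
% Hypothetical minimal getExpParams.m (paths are assumptions, adjust them):
% predFilename must point at the pred.mat generated by the HRNet repo,
% and gtDir at the directory holding the MPII ground-truth .mat files.
p.name = 'hrnet';               % label used in the printed PCKh table
p.predFilename = './pred.mat';  % assumed location of the predictions
p.gtDir = './ground_truth';     % assumed location of the ground truth
end
```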
Hello, thank you for sharing your code. I have tested the model file (trained on the MPII train set) that the repo provides on the MPII validation set, but I got different results from the MATLAB and Python code; the Python version matches the result shown in the paper. So is there any difference? I also tried the multi-scale code you provide on the validation set, and it too differs from the result in the paper.
@XiyueSun Thank you very much :smile:. I modified getExpParams.m as you suggested and the code now runs successfully :smile:
@XiyueSun My multi-scale code does successfully reproduce the validation accuracy; however, it cannot reproduce the test accuracy.
Python version? How can I test the model provided by the code with the Python version?
You mean you got the gt_test.mat?
Yep, I get the result in .mat format. You CAN test the Python version on the MPII validation set, but for the test set there are no annotations to evaluate against for calculating the result, as far as I know.
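For the validation set, the Python evaluation can be run with something like the command from the repo README (config and model paths here are the README defaults and may differ on your machine):

```
python tools/test.py \
    --cfg experiments/mpii/hrnet/w32_256x256_adam_lr1e-3.yaml \
    TEST.MODEL_FILE models/pytorch/pose_mpii/pose_hrnet_w32_256x256.pth
```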
Hello, I also met the same problem: the number of ground-truth entries is not consistent with the number of predictions when I run the provided evaluatePCKh.m code above. Do you have any idea how to figure it out? Thx!
Have you solved this problem? I have the same issue that the lengths of gt and preds are inconsistent with this code; I don't know how you solved it.
@MaxChu719 Hello, why can't I find the test set ground truth at http://human-pose.mpi-inf.mpg.de/#download? Please help me!
@touhou-ayaya Can you tell me how you changed getExpParams.m?
It seems they removed it. I am uploading the ground-truth file to my Google Drive and will post the link here once finished.