deep-high-resolution-net.pytorch
Inferencing HRNet with Volleyball dataset
I'm trying to run inference with HRNet on the Volleyball dataset to extract pose keypoints. Since the dataset is annotated with manual bounding boxes (BBs), I made slight modifications to the demo/inference.py
file. The code runs, but the output is off.
These are the keypoints of the player in the leftmost position. I thought the problem might be with the BBs, but they look spot on!
My guess is that the problem lies in the center/scale function, but I was not able to pin it down. How can I solve it? Has anyone tried to run inference on the Volleyball dataset as well?
Here is my main function:
def main(cfg_path: str, annot: str, data_path: Path, device: torch.device) -> None:
    """Run HRNet pose inference over Volleyball-dataset frames using pre-annotated boxes.

    Args:
        cfg_path: Path to the HRNet YAML config file.
        annot: Path to a pickled annotation file mapping
            (video_id, frame_id) -> {image_id: boxes}.
        data_path: Root directory scanned recursively for ``*.jpg`` frames,
            laid out as ``<video_id>/<frame_id>/<image_id>.jpg``.
        device: Torch device the model runs on.

    NOTE(review): the function as pasted ends after drawing keypoints on
    ``image_debug`` — presumably the save/display of the debug image follows
    in the original file.
    """
    # transformation
    # Standard ImageNet normalization applied to each person crop before the model.
    pose_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    # read BB annotation file
    # SECURITY: pickle.load executes arbitrary code — only load trusted annotation files.
    with open(annot, "rb") as file:
        annot_data = pickle.load(file)
    # read config file
    # NOTE(review): yaml.load yields a plain dict, but the upstream HRNet repo
    # wraps its config in a yacs CfgNode (attribute access, e.g. cfg.MODEL.EXTRA).
    # Confirm PoseHighResolutionNet actually accepts a plain dict — if it expects
    # a CfgNode this construction is already wrong.
    with open(cfg_path, "r") as file:
        cfg = yaml.load(file, Loader=yaml.FullLoader)
    pose_model = PoseHighResolutionNet(cfg)
    if len(cfg['MODEL']['PRETRAINED']) != 0:
        logging.info('loading model from {}'.format(cfg['MODEL']['PRETRAINED']))
        # NOTE(review): strict=False silently ignores missing/mismatched keys —
        # a partially-loaded checkpoint would also produce "off" keypoints.
        # Consider logging the return value of load_state_dict to verify.
        pose_model.load_state_dict(torch.load(cfg['MODEL']['PRETRAINED']), strict=False)
    else:
        # NOTE(review): message mentions TEST.MODEL_FILE but the check above
        # reads MODEL.PRETRAINED — one of the two is inconsistent.
        logging.info('expected model defined in config at TEST.MODEL_FILE')
    logging.info("Pose model loaded successfully")
    pose_model.to(device)
    pose_model.eval()
    # Frames are addressed as <video_id>/<frame_id>/<image_id>.jpg.
    image_list = list(data_path.glob("**/*.jpg"))
    image_iter = trange(len(image_list))
    for im in image_iter:
        image_path = image_list[im]
        video_id = int(image_path.parent.parent.stem)
        frame_id = int(image_path.parent.stem)
        image_id = int(image_path.stem)
        # Skip frames that have no box annotation.
        if image_id not in annot_data[(video_id, frame_id)].keys():
            continue
        # NOTE(review): image is loaded as RGB via PIL; the upstream demo reads
        # frames with cv2 (BGR). Verify which channel order
        # get_pose_estimation_prediction expects — a swapped order degrades output.
        image = np.array(Image.open(image_path).convert("RGB"))
        image_debug = image.copy()
        pred_boxes = annot_data[(video_id, frame_id)][image_id]
        # pose estimation : for multiple people
        centers = []
        scales = []
        for i, box in enumerate(pred_boxes):
            # For Debugging
            # NOTE(review): box is drawn as two opposite corners (x1,y1),(x2,y2).
            # If the Volleyball annotations are actually (x, y, w, h), the corners
            # — and therefore the center/scale below — are wrong, which would
            # exactly match the "keypoints are off" symptom. Confirm the format.
            # cv2.rectangle also requires integer coordinates — confirm the
            # annotation values are ints.
            image_debug = cv2.rectangle(np.array(image_debug), (box[0], box[1]), (box[2], box[3]), (0, 255, 0),
                                        thickness=3)
            # box_to_center_scale (from demo/inference.py) expects the box as a
            # pair of corner tuples; convert from a flat 4-sequence if needed.
            if not isinstance(box[0], tuple):
                box = [(box[0], box[1]), (box[2], box[3])]
            # IMAGE_SIZE[0] / IMAGE_SIZE[1] are the model input width/height used
            # to derive the crop aspect ratio.
            center, scale = box_to_center_scale(box, cfg['MODEL']['IMAGE_SIZE'][0], cfg['MODEL']['IMAGE_SIZE'][1])
            centers.append(center)
            scales.append(scale)
        # One batched call for all people detected in this frame.
        pose_preds = get_pose_estimation_prediction(cfg, pose_model, image, centers, scales,
                                                    transform=pose_transform, device=device)
        for coords in pose_preds:
            # Draw each point on image
            for coord in coords:
                x_coord, y_coord = int(coord[0]), int(coord[1])
                cv2.circle(image_debug, (x_coord, y_coord), 4, (255, 0, 0), 2)