
Inferencing HRNet with Volleyball dataset

AvivSham opened this issue 3 years ago · 0 comments

I'm trying to run HRNet inference on the Volleyball dataset to extract pose keypoints. Since the dataset comes with manually annotated bounding boxes (BBs), I made slight modifications to the demo/inference.py file. The code runs, but the output is off:

[image]

These are the keypoints for the leftmost player. I thought the problem might be with the BBs, but they look spot on:

[image]

My guess is that the problem lies in the center/scale computation, but I was not able to pin it down. How can I solve this? Has anyone else tried running inference on the Volleyball dataset?
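
For context, here is roughly what box_to_center_scale in demo/inference.py does (paraphrased from the repo, so treat it as a sketch rather than the exact source). The aspect-ratio padding and the pixel_std = 200 convention are the parts I suspect:

def box_to_center_scale(box, model_image_width, model_image_height):
    # box is [(x1, y1), (x2, y2)]
    bottom_left, top_right = box
    box_width = top_right[0] - bottom_left[0]
    box_height = top_right[1] - bottom_left[1]

    center = np.array([bottom_left[0] + box_width * 0.5,
                       bottom_left[1] + box_height * 0.5], dtype=np.float32)

    # pad the box to the model's input aspect ratio, then express its size
    # in units of pixel_std = 200 (a convention inherited from
    # SimpleBaseline-style pose code)
    aspect_ratio = model_image_width * 1.0 / model_image_height
    pixel_std = 200
    if box_width > aspect_ratio * box_height:
        box_height = box_width * 1.0 / aspect_ratio
    elif box_width < aspect_ratio * box_height:
        box_width = box_height * aspect_ratio
    scale = np.array([box_width / pixel_std, box_height / pixel_std],
                     dtype=np.float32)
    scale = scale * 1.25  # enlarge the crop by 25%

    return center, scale

Note that MODEL.IMAGE_SIZE is [width, height] in the repo's configs, so passing IMAGE_SIZE[0] and IMAGE_SIZE[1] in that order matches the function's signature.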

Here is my main function (imports included for completeness):

import logging
import pickle
from pathlib import Path

import cv2
import numpy as np
import torch
import yaml
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

# PoseHighResolutionNet comes from lib/models/pose_hrnet.py;
# box_to_center_scale and get_pose_estimation_prediction are the helpers
# already defined in my modified copy of demo/inference.py
from models.pose_hrnet import PoseHighResolutionNet


def main(cfg_path: str, annot: str, data_path: Path, device: torch.device) -> None:
    # ImageNet normalization, as used by the HRNet demo
    pose_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # read BB annotation file
    with open(annot, "rb") as file:
        annot_data = pickle.load(file)

    # read config file
    with open(cfg_path, "r") as file:
        cfg = yaml.load(file, Loader=yaml.FullLoader)

    pose_model = PoseHighResolutionNet(cfg)
    if cfg['MODEL']['PRETRAINED']:
        logging.info('loading model from {}'.format(cfg['MODEL']['PRETRAINED']))
        # map_location keeps this working when a GPU checkpoint is loaded on
        # CPU; strict=False silently drops mismatched keys, so it is worth
        # checking the returned missing/unexpected key lists
        pose_model.load_state_dict(
            torch.load(cfg['MODEL']['PRETRAINED'], map_location=device),
            strict=False)
    else:
        logging.warning('MODEL.PRETRAINED is empty; no weights were loaded')
    logging.info("Pose model loaded successfully")

    pose_model.to(device)
    pose_model.eval()
    image_list = list(data_path.glob("**/*.jpg"))
    for image_path in tqdm(image_list):
        # dataset layout: <video_id>/<frame_id>/<image_id>.jpg
        video_id = int(image_path.parent.parent.stem)
        frame_id = int(image_path.parent.stem)
        image_id = int(image_path.stem)

        if image_id not in annot_data[(video_id, frame_id)]:
            continue
        image = np.array(Image.open(image_path).convert("RGB"))
        image_debug = image.copy()

        pred_boxes = annot_data[(video_id, frame_id)][image_id]

        # pose estimation: one center/scale pair per annotated player
        centers = []
        scales = []

        for box in pred_boxes:
            # draw the raw box for debugging; coordinates are (x1, y1, x2, y2)
            cv2.rectangle(image_debug,
                          (int(box[0]), int(box[1])),
                          (int(box[2]), int(box[3])),
                          (0, 255, 0), thickness=3)

            # box_to_center_scale expects [(x1, y1), (x2, y2)]
            if not isinstance(box[0], tuple):
                box = [(box[0], box[1]), (box[2], box[3])]

            center, scale = box_to_center_scale(
                box, cfg['MODEL']['IMAGE_SIZE'][0], cfg['MODEL']['IMAGE_SIZE'][1])
            centers.append(center)
            scales.append(scale)

        pose_preds = get_pose_estimation_prediction(
            cfg, pose_model, image, centers, scales,
            transform=pose_transform, device=device)

        # draw each predicted keypoint on the debug image
        # (the image is RGB, so (255, 0, 0) is red)
        for coords in pose_preds:
            for coord in coords:
                x_coord, y_coord = int(coord[0]), int(coord[1])
                cv2.circle(image_debug, (x_coord, y_coord), 4, (255, 0, 0), 2)
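
A debugging idea I'm considering: visualizing the exact crop the network sees for a single box. This is a sketch, not repo code; show_model_input is a hypothetical helper of mine, and it assumes get_affine_transform from lib/utils/transforms.py is importable (it is the same function the demo's preprocessing uses internally):

def show_model_input(image, center, scale, image_size):
    # reproduce the affine crop that the demo feeds to the network
    from utils.transforms import get_affine_transform
    trans = get_affine_transform(center, scale, 0, image_size)
    crop = cv2.warpAffine(image, trans,
                          (int(image_size[0]), int(image_size[1])),
                          flags=cv2.INTER_LINEAR)
    Image.fromarray(crop).show()

# e.g. show_model_input(image, centers[0], scales[0], cfg['MODEL']['IMAGE_SIZE'])

If the player is not roughly centered and filling that crop, the center/scale computation is indeed where things go wrong.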

AvivSham · Mar 14 '21