JARVIS-HybridNet
Predict2D over batches
Hey Timo,
Do you have code to run predict2D over batches?
The only code I can find loops over frames.
https://github.com/JARVIS-MoCap/JARVIS-HybridNet/blob/master/jarvis/prediction/predict2D.py#L91
Thanks!
In case there isn't something already, here's what I wrote. It assumes a jarvisPredictor object has already been set up, the same way predict2D.py does:
import numpy as np
import torch
from torchvision import transforms


def get_2D_keypoints_batch(frames):
    """
    Given a batch of images (frames), return the 2D keypoints and confidence
    scores for each image. Uses the jarvisPredictor object to perform keypoint
    detection on the images.

    Args:
        frames (numpy.ndarray): A batch of images with shape
            (batch_size, height, width, channels).

    Returns:
        tuple: A tuple containing the 2D keypoints and confidence scores for
            each image in the batch. The keypoints have shape
            (batch_size, num_keypoints, 2) and the confidence scores have
            shape (batch_size, num_keypoints).
    """
    batch_size = frames.shape[0]
    # Convert the input frames to a PyTorch tensor and perform some pre-processing
    imgs = torch.from_numpy(frames).cuda().float().permute(0, 3, 1, 2) / 255.0
    # Get the size of the input images as (width, height)
    img_size = torch.tensor([imgs.shape[3], imgs.shape[2]], device=torch.device("cuda"))
    # Compute the downsampling scale for the center detection
    downsampling_scale = torch.tensor(
        [
            imgs.shape[3] / float(jarvisPredictor.center_detect_img_size),
            imgs.shape[2] / float(jarvisPredictor.center_detect_img_size),
        ],
        device=torch.device("cuda"),
    ).float()
    # Resize the input images to the size expected by the center detection network
    imgs_resized = transforms.functional.resize(
        imgs,
        [
            jarvisPredictor.center_detect_img_size,
            jarvisPredictor.center_detect_img_size,
        ],
    )
    # Normalize the resized images
    imgs_resized = (
        imgs_resized - jarvisPredictor.transform_mean
    ) / jarvisPredictor.transform_std
    # Run the center detection network on the resized images
    outputs = jarvisPredictor.centerDetect(imgs_resized)
    # Flatten the center detection heatmaps and take the argmax per channel
    heatmaps_gpu = outputs[1].view(outputs[1].shape[0], outputs[1].shape[1], -1)
    m = heatmaps_gpu.argmax(2).view(heatmaps_gpu.shape[0], heatmaps_gpu.shape[1], 1)
    # Recover (x, y) coordinates from the flattened argmax index; the heatmap
    # width is outputs[1].shape[3], so x = m % width and y = m // width
    preds = torch.cat((m % outputs[1].shape[3], m // outputs[1].shape[3]), dim=2)
    maxvals = heatmaps_gpu.gather(2, m)
    # Carried over from the per-frame code; not used further in this function
    num_cams_detect = torch.numel(maxvals[maxvals > 50])
    maxvals = maxvals / 255.0
    # Convert the detected centers to full-resolution image coordinates
    centerHMs = preds.squeeze() * downsampling_scale * 2
    centerHMs[:, 0] = torch.clamp(
        centerHMs[:, 0], jarvisPredictor.bbox_hw, img_size[0] - jarvisPredictor.bbox_hw
    )
    centerHMs[:, 1] = torch.clamp(
        centerHMs[:, 1], jarvisPredictor.bbox_hw, img_size[1] - jarvisPredictor.bbox_hw
    )
    # Crop the input images to the bounding boxes around the detected centers
    imgs_cropped = torch.zeros(
        (
            batch_size,
            3,
            jarvisPredictor.bounding_box_size,
            jarvisPredictor.bounding_box_size,
        ),
        device=torch.device("cuda"),
    )
    centerHMs = centerHMs.int().cpu().numpy()
    for i in range(batch_size):
        imgs_cropped[i] = imgs[
            i,
            :,
            centerHMs[i, 1]
            - jarvisPredictor.bbox_hw : centerHMs[i, 1]
            + jarvisPredictor.bbox_hw,
            centerHMs[i, 0]
            - jarvisPredictor.bbox_hw : centerHMs[i, 0]
            + jarvisPredictor.bbox_hw,
        ]
    # Normalize the cropped images
    imgs_cropped = (
        imgs_cropped - jarvisPredictor.transform_mean
    ) / jarvisPredictor.transform_std
    # Run the keypoint detection network on the cropped images
    outputs = jarvisPredictor.keypointDetect(imgs_cropped)
    # Flatten the keypoint heatmaps and take the argmax per keypoint channel
    heatmaps = outputs[1].view(outputs[1].shape[0], outputs[1].shape[1], -1)
    m = heatmaps.argmax(2).view(heatmaps.shape[0], heatmaps.shape[1], 1)
    points2D = (
        torch.cat((m % outputs[1].shape[3], m // outputs[1].shape[3]), dim=2).squeeze()
        * 2
    )
    confidences = heatmaps.gather(2, m).squeeze()
    confidences = torch.clamp(confidences, max=255.0) / 255.0
    # Shift the keypoints from crop coordinates back to full-image coordinates
    points2D = (
        points2D.cpu().numpy() + np.expand_dims(centerHMs, 1) - jarvisPredictor.bbox_hw
    )
    # Convert the PyTorch tensors to numpy arrays and return as a tuple
    return points2D, confidences.cpu().numpy()
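For completeness, here's a minimal usage sketch. The video path and batch size are placeholders, and it assumes jarvisPredictor has already been constructed the same way predict2D.py sets it up:

import cv2
import numpy as np

# Hypothetical example: grab a batch of frames from a video with OpenCV and
# run the batched prediction on it. "video.mp4" and the batch size of 8 are
# placeholders, not values from the repo.
cap = cv2.VideoCapture("video.mp4")
batch = []
while len(batch) < 8:
    ret, frame = cap.read()
    if not ret:
        break
    batch.append(frame)
cap.release()

frames = np.stack(batch)  # shape: (batch_size, height, width, channels)
points2D, confidences = get_2D_keypoints_batch(frames)
print(points2D.shape, confidences.shape)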
Hi Tim, you're right, there is currently no code to perform batched predictions, so thank you very much for sharing your implementation! I'll add it to the repo if you don't mind. Alternatively, you can open a pull request so you get some credit for your contribution; just let me know which you prefer :)
Hi, I tested this batch processing for 2D video prediction and it generates the CSV files correctly. However, the visualization module doesn't work anymore, since info.yaml now contains only the path to the recordings instead of the recordings themselves. Can visualize2D be modified to account for this?
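For anyone running into the same problem, here's a rough sketch of the kind of fallback visualize2D would need. The info.yaml layout, the "recordings" key, and the load_recording_frames helper are my guesses for illustration, not the actual JARVIS format or API:

import os
import cv2
import yaml

# Hypothetical helper: resolve a recording entry from info.yaml whether it
# stores the recording data directly or only a path to the recording on disk.
# The "recordings" key is an assumption about the yaml layout.
def load_recording_frames(info_yaml_path, recording_key):
    with open(info_yaml_path) as f:
        info = yaml.safe_load(f)
    entry = info["recordings"][recording_key]
    if isinstance(entry, str) and os.path.exists(entry):
        # New behavior: entry is a path, so read the frames with OpenCV
        cap = cv2.VideoCapture(entry)
        frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(frame)
        cap.release()
        return frames
    # Old behavior: entry already contains the recording data
    return entry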