How to track all pixels of an image?
Hi! Thank you for this amazing work!
I'm trying to use the track head to track every pixel of the first image through a sequence of 49 subsequent images. When I passed a (518×518, 2) grid of query points to model.track_head(), I hit a CUDA out-of-memory error. Is there currently an efficient way to do whole-image pixel tracking in the codebase? Any guidance would be greatly appreciated!
My code snippet:
```python
import torch
import torch.nn.functional as F

def run_VGGT(model, images, dtype, resolution=518):
    # images: [B, 3, H, W]
    assert len(images.shape) == 4
    assert images.shape[1] == 3
    # hard-coded to use 518 for VGGT
    images = F.interpolate(images, size=(resolution, resolution), mode="bilinear", align_corners=False)
    with torch.no_grad():
        with torch.cuda.amp.autocast(dtype=dtype):
            images = images[None]  # add batch dimension
            aggregated_tokens_list, ps_idx = model.aggregator(images)
            # build a dense (x, y) grid covering every pixel
            x_coords = torch.arange(resolution).view(1, -1).expand(resolution, resolution)
            y_coords = torch.arange(resolution).view(-1, 1).expand(resolution, resolution)
            coords = torch.stack([x_coords, y_coords], dim=-1)
            query_points = coords.reshape((-1, 2)).to(images.device)
            track_list, vis_score, conf_score = model.track_head(
                aggregated_tokens_list, images, ps_idx, query_points=query_points[None]
            )
```
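For scale, that grid is 518 × 518 = 268,324 query points handed to the track head in a single call, which is why it runs out of memory. A quick sanity check of the grid construction in plain NumPy (independent of VGGT, same layout as the torch code above):

```python
import numpy as np

resolution = 518
# same (x, y) grid the snippet builds with torch.arange/expand
xs, ys = np.meshgrid(np.arange(resolution), np.arange(resolution))
query_points = np.stack([xs, ys], axis=-1).reshape(-1, 2)
print(query_points.shape)  # (268324, 2) -- every pixel, x varying fastest
```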
Hi,
Honestly, I think this might be too heavy for now. Can you have a look at https://github.com/facebookresearch/co-tracker?
Got it! Thank you for the suggestion!
Hi! I have worked with a similar setting, tracking all pixels through a sequence of images, and ran into the same problem. Have you solved it since? Any suggestions would be helpful! @dadadadawjb
@ShirleyQSY Hi! If you still want to use VGGT for dense tracking, one possible approach I took is to process the pixels block-wise.
```python
import numpy as np
import torch
import torch.nn.functional as F
import tqdm

def run_VGGT(model, images, dtype, resolution=518):
    # images: [B, 3, H, W]
    assert len(images.shape) == 4
    assert images.shape[1] == 3
    # hard-coded to use 518 for VGGT
    images = F.interpolate(images, size=(resolution, resolution), mode="bilinear", align_corners=False)
    with torch.no_grad():
        with torch.cuda.amp.autocast(dtype=dtype):
            images = images[None]  # add batch dimension
            aggregated_tokens_list, ps_idx = model.aggregator(images)
            # build a dense (x, y) grid covering every pixel
            x_coords = torch.arange(resolution).view(1, -1).expand(resolution, resolution)
            y_coords = torch.arange(resolution).view(-1, 1).expand(resolution, resolution)
            coords = torch.stack([x_coords, y_coords], dim=-1)
            query_points = coords.reshape((-1, 2))
            # run the track head on small blocks of query points to avoid OOM
            tracks = []
            block_size = 512
            for i in tqdm.trange(0, query_points.shape[0], block_size):
                query_points_block = query_points[i:i + block_size].to(images.device)
                track_list, vis_score, conf_score = model.track_head(
                    aggregated_tokens_list, images, ps_idx, query_points=query_points_block[None]
                )
                tracks.append(track_list[-1][0].cpu().numpy())  # take the last iteration's track
            tracks = np.concatenate(tracks, axis=1)  # [S, H*W, 2]
    return tracks
```
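If you need dense per-pixel maps (e.g. flow-style output) rather than a flat point list, the concatenated tracks can be reshaped back onto the image grid, since query_points enumerates pixels row-major with x varying fastest. A sketch with a dummy array standing in for the real output (the shapes are assumptions matching the code above):

```python
import numpy as np

S, H, W = 50, 518, 518  # number of frames, grid height/width (assumed)
tracks = np.zeros((S, H * W, 2), dtype=np.float32)  # stand-in for run_VGGT's output
# query_points was built row-major (x fastest), so a plain reshape recovers
# a dense map: track_maps[s, y, x] = predicted (x, y) of pixel (x, y) in frame s
track_maps = tracks.reshape(S, H, W, 2)
```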
Alternatively, I found that AllTracker also works quite well for tracking all pixels. Hope this helps!
@dadadadawjb Thank you so much! I will try these methods. Hope they will work lol.