
How to track all pixels?

Open dadadadawjb opened this issue 6 months ago • 2 comments

Hi! Thank you for this amazing work!

I'm trying to use the track head to track every pixel of the first image through the 49 subsequent images of a sequence. When I passed a (518*518, 2) grid of query points to model.track_head(), I hit a CUDA out-of-memory error. Is there currently an efficient way to do whole-image (dense) pixel tracking with the codebase? Any guidance would be greatly appreciated!
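For scale, here is a back-of-the-envelope count of what a full 518*518 query grid implies. It is only a sketch, assuming all 50 frames (the first image plus the 49 subsequent ones) go through the track head in a single pass and that outputs are float32.

```python
resolution = 518
num_frames = 50  # assumption: the first image plus the 49 subsequent images
num_queries = resolution * resolution  # one query point per pixel

# Memory for one set of final track coordinates alone: frames x queries x (x, y), float32.
track_bytes = num_frames * num_queries * 2 * 4

print(num_queries)  # 268324
print(track_bytes)  # 107329600 bytes, roughly 107 MB
```

The output coordinates by themselves are modest, although track_list holds one such tensor per refinement iteration. The out-of-memory error presumably comes from intermediate per-query features inside the track head, which grow with the number of queries and frames. That is an assumption about the head's internals, not something measured here.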

My code snippet:

import torch
import torch.nn.functional as F

def run_VGGT(model, images, dtype, resolution=518):
    # images: [S, 3, H, W], the S frames of a single sequence

    assert len(images.shape) == 4
    assert images.shape[1] == 3

    # hard-coded to use 518 for VGGT
    images = F.interpolate(images, size=(resolution, resolution), mode="bilinear", align_corners=False)

    with torch.no_grad():
        with torch.cuda.amp.autocast(dtype=dtype):
            images = images[None]  # add batch dimension -> [1, S, 3, H, W]
            aggregated_tokens_list, ps_idx = model.aggregator(images)

        # one (x, y) query per pixel of the first frame, in row-major order
        x_coords = torch.arange(resolution).view(1, -1).expand(resolution, resolution)
        y_coords = torch.arange(resolution).view(-1, 1).expand(resolution, resolution)
        coords = torch.stack([x_coords, y_coords], dim=-1)
        # the track head takes float pixel coordinates of shape (N, 2)
        query_points = coords.reshape((-1, 2)).float().to(images.device)
        # a single call with all 518*518 queries runs out of GPU memory
        track_list, vis_score, conf_score = model.track_head(aggregated_tokens_list, images, ps_idx, query_points=query_points[None])

    return track_list, vis_score, conf_score
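The grid above flattens the pixel grid in row-major order, so query k corresponds to pixel (x = k % W, y = k // W). Per-frame tracks of shape (S, H*W, 2) can therefore be reshaped into a dense (S, H, W, 2) field. The NumPy sketch below shows that bookkeeping on a tiny 3x4 image, using random stand-in values instead of real track_head output. The names H, W, S and the flow computation are illustrative only.

```python
import numpy as np

H, W, S = 3, 4, 2  # tiny image and frame count for illustration (the thread uses 518x518, 50 frames)

# Same layout as the torch code: x varies along columns, y along rows.
x_grid, y_grid = np.meshgrid(np.arange(W), np.arange(H), indexing="xy")  # each (H, W)
query_points = np.stack([x_grid, y_grid], axis=-1).reshape(-1, 2).astype(np.float32)  # (H*W, 2)

# Flat query index k maps back to pixel (x, y) = (k % W, k // W).
k = 6
x, y = int(query_points[k, 0]), int(query_points[k, 1])
assert (x, y) == (k % W, k // W)

# Stand-in for tracks of shape (S, H*W, 2); real values would come from track_head.
tracks = np.random.rand(S, H * W, 2).astype(np.float32)
dense_tracks = tracks.reshape(S, H, W, 2)  # dense_tracks[s, y, x] is where pixel (x, y) lands in frame s

# Displacement of every pixel relative to its query position in the first frame.
flow = dense_tracks - query_points.reshape(1, H, W, 2)
print(query_points[k], dense_tracks.shape, flow.shape)
```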

dadadadawjb, Jun 23 '25

Hi,

Honestly, I think this might be too heavy for now. Can you have a look at https://github.com/facebookresearch/co-tracker?

jytime, Jun 23 '25

Got it! Thank you for the suggestion!

dadadadawjb, Jun 24 '25

Hi! I'm working on a similar setup, tracking all pixels through a sequence of images, and have run into the same problem. Have you solved it? Any suggestions would be helpful! @dadadadawjb

ShirleyQSY, Nov 24 '25

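@ShirleyQSY Hi! If you still want to use VGGT for dense tracking, one approach I took is to process the query pixels block-wise: run the aggregator once, then call the track head on a fixed-size block of queries at a time so that peak GPU memory stays bounded.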

import numpy as np
import torch
import torch.nn.functional as F
import tqdm

def run_VGGT(model, images, dtype, resolution=518):
    # images: [S, 3, H, W], the S frames of a single sequence

    assert len(images.shape) == 4
    assert images.shape[1] == 3

    # hard-coded to use 518 for VGGT
    images = F.interpolate(images, size=(resolution, resolution), mode="bilinear", align_corners=False)

    with torch.no_grad():
        with torch.cuda.amp.autocast(dtype=dtype):
            images = images[None]  # add batch dimension -> [1, S, 3, H, W]
            # run the aggregator once; its tokens are reused for every block of queries
            aggregated_tokens_list, ps_idx = model.aggregator(images)

        # one (x, y) query per pixel of the first frame, in row-major order
        x_coords = torch.arange(resolution).view(1, -1).expand(resolution, resolution)
        y_coords = torch.arange(resolution).view(-1, 1).expand(resolution, resolution)
        coords = torch.stack([x_coords, y_coords], dim=-1)
        query_points = coords.reshape((-1, 2)).float()  # (H*W, 2) float pixel coordinates, kept on CPU
        tracks = []
        block_size = 512  # queries per track_head call; lower it if memory is still tight
        for i in tqdm.trange(0, query_points.shape[0], block_size):
            query_points_block = query_points[i:i + block_size].to(images.device)
            track_list, vis_score, conf_score = model.track_head(aggregated_tokens_list, images, ps_idx, query_points=query_points_block[None])
            # take the last iteration's track: [S, block, 2]; cast to float32 so NumPy can hold it
            tracks.append(track_list[-1][0].float().cpu().numpy())
        tracks = np.concatenate(tracks, axis=1)  # [S, H*W, 2]

    return tracks
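The loop above is an instance of a general pattern: apply an expensive function to fixed-size slices and concatenate the results. Below is a minimal pure-Python sketch of that pattern. run_in_blocks and fake_track_fn are hypothetical names for illustration, with fake_track_fn standing in for one track_head call.

```python
import math
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def run_in_blocks(items: Sequence[T], block_size: int,
                  fn: Callable[[Sequence[T]], List[R]]) -> List[R]:
    """Apply fn to consecutive slices of at most block_size items and concatenate the results.

    Peak working memory is governed by block_size rather than len(items).
    """
    if block_size <= 0:
        raise ValueError("block_size must be positive")
    results: List[R] = []
    for start in range(0, len(items), block_size):
        block = items[start:start + block_size]
        out = fn(block)
        # fn must return exactly one result per input item, in the same order
        if len(out) != len(block):
            raise ValueError("fn returned the wrong number of results")
        results.extend(out)
    return results


# Hypothetical stand-in for one track_head call: "tracks" each query to itself plus one pixel.
def fake_track_fn(block):
    return [(x + 1, y + 1) for x, y in block]


queries = [(x, y) for y in range(5) for x in range(7)]  # row-major 7x5 grid, 35 queries
tracked = run_in_blocks(queries, block_size=8, fn=fake_track_fn)
print(len(tracked), tracked[:3])  # 35 [(1, 1), (2, 1), (3, 1)]

# Number of track_head calls for the full 518x518 grid with block_size=512.
print(math.ceil(518 * 518 / 512))  # 525
```

The trade-off is the one in the VGGT loop: a larger block_size means fewer calls (525 at 512 for a 518x518 grid) but a higher peak memory per call. Since the aggregator runs only once, the extra cost of chunking is limited to the repeated track-head passes.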

Alternatively, I found that AllTracker also works quite well for tracking all pixels. Hope this helps!

dadadadawjb, Nov 24 '25

@dadadadawjb Thank you so much! I will try these methods. Hope they will work lol.

ShirleyQSY, Nov 24 '25