Tracking head running out of memory

Open · ricshaw opened this issue 9 months ago · 3 comments

Hi. I am trying to run point tracking on 60 input images, but I run out of memory on a 32 GB V100 GPU.

Even with only 10 query points, it seems to run out of memory either in base_track_predictor.py at the line fmaps = self.fmap_norm(fmaps.permute(0, 1, 3, 4, 2)), where fmaps has shape torch.Size([1, 60, 128, 259, 259]),

or in vggt/heads/track_modules/blocks.py at the line corrs = compute_corr_level(fmap1, fmap2s, C).
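A rough estimate of the tensor sizes involved helps explain where the memory goes. The sketch below assumes 4-byte float32 elements (LayerNorm is one of the ops autocast runs in float32) and assumes the finest correlation level is a dense dot product between every query feature and every feature-map location, i.e. a volume of shape (B, S, N, H*W) as in CoTracker-style correlation blocks. The helper names are hypothetical.

```python
GIB = 1024 ** 3


def tensor_bytes(shape, bytes_per_elem=4):
    """Bytes needed by a dense tensor of the given shape."""
    numel = 1
    for dim in shape:
        numel *= dim
    return numel * bytes_per_elem


def corr_volume_bytes(batch, frames, queries, height, width, bytes_per_elem=4):
    """Assumed dense correlation volume of shape (B, S, N, H*W)."""
    return tensor_bytes((batch, frames, queries, height * width), bytes_per_elem)


# Shape reported in the issue: (B, S, C, H, W) = (1, 60, 128, 259, 259)
fmaps_bytes = tensor_bytes((1, 60, 128, 259, 259))
corr_10_bytes = corr_volume_bytes(1, 60, 10, 259, 259)
corr_2000_bytes = corr_volume_bytes(1, 60, 2000, 259, 259)

for label, nbytes in [
    ("fmaps (float32)", fmaps_bytes),
    ("corr, 10 queries", corr_10_bytes),
    ("corr, 2000 queries", corr_2000_bytes),
]:
    print(f"{label:>20}: {nbytes / GIB:6.2f} GiB")
```

Under these assumptions the 259 x 259 feature maps take about 1.9 GiB, and any same-sized intermediate (such as the normalized output) roughly as much again. The correlation volume is only about 0.15 GiB for 10 queries but about 30 GiB for 2000, which alone would exhaust a 32 GB card. If the OOM already happens with 10 queries, that suggests most of the memory is held by other tensors.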

Ideally, I want to track around 2000 query points through all images. Is it okay to split the forward pass of tracking into chunks of images? Something like this:

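chnk = 10  # frames per chunk, including the query frame; must be >= 2
S = images.shape[1]
tracks = []
# The query points live in frame 0, so every chunk starts with frame 0;
# range() also covers the last partial chunk when S is not a multiple of chnk.
for start in range(1, S, chnk - 1):
    idx = [0] + list(range(start, min(start + chnk - 1, S)))
    images_chnk = images[:, idx]
    # the tokens must be sliced to the same frames as the images
    tokens_chnk = [t[:, idx] for t in aggregated_tokens_list]
    track_list, vis_score, conf_score = model.track_head(tokens_chnk, images_chnk, ps_idx, query_points)
    # keep frame 0 only from the first chunk
    tracks.append(track_list[-1] if start == 1 else track_list[-1][:, 1:])
tracks = torch.cat(tracks, 1)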

But this also slows it down significantly. Is there a better way to do it?
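One more option, not discussed in this thread, is to batch the query points instead of the frames, since a dense correlation volume of shape (B, S, N, H*W) grows linearly with the number of queries N. The sketch below, assuming that dense layout and float32 elements, picks the largest query batch whose volume stays under a byte budget and returns index ranges. The function names and the 4 GiB budget are illustrative assumptions, not part of VGGT.

```python
def max_queries_per_batch(budget_bytes, frames, height, width, bytes_per_elem=4):
    """Largest query count whose (1, S, N, H*W) correlation volume fits the budget."""
    bytes_per_query = frames * height * width * bytes_per_elem
    return max(1, budget_bytes // bytes_per_query)


def query_batches(num_queries, batch_size):
    """Half-open (start, end) index ranges covering all queries in order."""
    return [
        (start, min(start + batch_size, num_queries))
        for start in range(0, num_queries, batch_size)
    ]


budget = 4 * 1024 ** 3  # illustrative 4 GiB budget for the correlation volume
size = max_queries_per_batch(budget, frames=60, height=259, width=259)
batches = query_batches(2000, size)
print(size, len(batches), batches[-1])
```

Each (start, end) range would be passed as query_points[None, start:end] to model.track_head over all frames, and the per-batch tracks concatenated along the query dimension. Two caveats: the feature maps are recomputed for every batch unless they are cached, and if the tracker attends across tracks, batched results may differ slightly from a single pass.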

ricshaw, Mar 26 '25 06:03

Hi.

I tried to reproduce this, but it seems to work correctly for me. I can dig into it further, though chunking may be a good workaround at this stage.

Another suggestion: could you try deleting the unused tensors and calling torch.cuda.empty_cache() after the deletion? In most cases this frees a considerable amount of memory.
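The suggestion above relies on Python reference semantics: del only removes a name, and a tensor's memory returns to PyTorch's caching allocator only once the last reference to it is gone; torch.cuda.empty_cache() then releases cached blocks that are no longer in use. A common pitfall is deleting a container, such as a predictions dict, while local variables still point at the same tensors. The sketch below illustrates this with a hypothetical FakeTensor class and weakref, so it runs without a GPU.

```python
import gc
import weakref


class FakeTensor:
    """Hypothetical stand-in for a large GPU tensor."""

    def __init__(self, name):
        self.name = name


depth_map = FakeTensor("depth_map")
predictions = {"depth": depth_map}  # the dict holds a second reference
probe = weakref.ref(depth_map)      # observes the object without keeping it alive

del predictions                     # removes the dict, not the object
gc.collect()
alive_after_dict_del = probe() is not None  # True: `depth_map` still refers to it

del depth_map                       # removes the last strong reference
gc.collect()
freed = probe() is None             # True: the object has been reclaimed

print(alive_after_dict_del, freed)
```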

Best, Jianyuan

jytime, Mar 26 '25 13:03

Here's some code that runs out of memory for me:

import gc

import numpy as np
import torch
from vggt.models.vggt import VGGT
from vggt.utils.pose_enc import pose_encoding_to_extri_intri

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16

model = VGGT.from_pretrained("facebook/VGGT-1B").to(device)

images = torch.rand(60, 3, 518, 518).to(device)
H, W = images.shape[-2:]

with torch.no_grad():
    with torch.autocast(device_type="cuda", dtype=dtype):
        predictions = {}
        images = images[None]  # add batch dimension
        aggregated_tokens_list, ps_idx = model.aggregator(images)
        predictions["images"] = images
        # Predict pose
        predictions["pose_enc"] = model.camera_head(aggregated_tokens_list)[-1]
        # Predict Depth Maps
        depth_map, depth_conf = model.depth_head(aggregated_tokens_list, images, ps_idx)
        predictions["depth"] = depth_map
        predictions["depth_conf"] = depth_conf
        # Predict Point Maps
        point_map, point_conf = model.point_head(aggregated_tokens_list, images, ps_idx)
        predictions["world_points"] = point_map  # B, N, H, W, 3
        predictions["world_points_conf"] = point_conf  # B, N, H, W

        extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images.shape[-2:])
        predictions["extrinsic"] = extrinsic  # (1, N, 3, 4)
        predictions["intrinsic"] = intrinsic  # (1, N, 3, 3)

        # Convert tensors to numpy (cast to float32 first: numpy has no bfloat16)
        for key in predictions.keys():
            if isinstance(predictions[key], torch.Tensor):
                predictions[key] = predictions[key].float().cpu().numpy().squeeze(0)  # remove batch dimension

        # Clean up: delete the dict and the locals that still reference GPU tensors
        del predictions, depth_map, depth_conf, point_map, point_conf, extrinsic, intrinsic
        gc.collect()
        torch.cuda.empty_cache()

        idx = np.random.choice(H * W, 2000, replace=False)
        row_idx, col_idx = np.unravel_index(idx, (H, W))
        query_points = np.stack((col_idx, row_idx), 1)  # (x, y) pixel coordinates
        # Keep the coordinates in float32; half-precision casts round large pixel values
        query_points = torch.from_numpy(query_points).float().to(device)
        track_list, vis_score, conf_score = model.track_head(aggregated_tokens_list, images, ps_idx, query_points=query_points[None])
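Casting pixel coordinates to a half-precision dtype rounds them before the model sees them. bfloat16 keeps only 8 significant bits, so integers above 256 are no longer exact. The pure-Python sketch below emulates float32-to-bfloat16 conversion with round-to-nearest-even by rounding away the low 16 bits of the float32 bit pattern; it handles finite values only and is an illustration, not PyTorch's implementation. float16 (selected on a V100 by the capability check) has 11 significant bits, so integer coordinates up to 2048 survive, but fractional ones are still rounded.

```python
import struct


def to_bfloat16(x: float) -> float:
    """Round a finite float32 value to the nearest bfloat16 (ties to even)."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    # bfloat16 is the upper 16 bits of a float32; round the lower 16 away
    lsb = (bits >> 16) & 1
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000
    return struct.unpack(">f", struct.pack(">I", bits))[0]


for coord in [100.25, 100.5, 259.0, 517.0]:
    print(coord, "->", to_bfloat16(coord))
```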

ricshaw, Mar 26 '25 14:03

I can confirm the same issue with similar code: the first loop iteration passes, but the second one runs out of memory:

        depths = []
        # always include the first frame as a shared reference
        first_frame = frames[0][None, ...]
        # list both context managers with a comma; `a and b` would enter only b
        with torch.no_grad(), torch.amp.autocast('cuda', dtype=self.dtype):
            # frame 0 is prepended to every batch, so step by batch_size - 1
            # to cover each remaining frame exactly once
            for i in range(1, len(frames), self.batch_size - 1):
                batch = frames[i:i + self.batch_size - 1]
                # add first frame to the batch
                batch = np.concatenate((first_frame, batch), axis=0)

                output = self.infer_batch(batch, near, far, num_denoising_steps, guidance_scale, window_size, overlap, **kwargs)

                depth = output["depth"].squeeze(0)  # [B,H,W,1]

                depth = (depth - depth.min()) / (depth.max() - depth.min())
                depth *= 3900
                depth[depth < 1e-5] = 1e-5
                depth = 10000.0 / depth
                depth = depth.clip(near, far)

                # keep the reference frame's depth only from the first batch
                if i > 1:
                    depth = depth[1:]
                depths.append(depth.detach().cpu())
                # clean up GPU memory
                del output, depth
                gc.collect()
                torch.cuda.empty_cache()

        depths = torch.cat(depths, dim=0)

dsvilarkovic, Mar 26 '25 15:03
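The form with torch.no_grad() and torch.amp.autocast(...) does not enter both context managers. Python evaluates a and b first; a context-manager object is truthy, so the expression evaluates to b and only autocast is entered. Gradient tracking stays on, and autograd can keep intermediate activations alive. Comma-separated managers are all entered, in order. The sketch below demonstrates this with a hypothetical tracked context manager standing in for no_grad and autocast.

```python
from contextlib import contextmanager

events = []


@contextmanager
def tracked(name):
    """Hypothetical context manager that records when it is entered and exited."""
    events.append(f"enter {name}")
    try:
        yield
    finally:
        events.append(f"exit {name}")


# `a and b` evaluates to b because a is truthy, so only b is entered
with tracked("no_grad") and tracked("autocast"):
    pass
buggy = list(events)

events.clear()
# comma-separated context managers are all entered, in order
with tracked("no_grad"), tracked("autocast"):
    pass
correct = list(events)

print(buggy)
print(correct)
```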