Tracking head running out of memory
Hi. I am trying to run point tracking on 60 input images, but on a 32 GB V100 GPU I run out of memory.
Even with only 10 query points, it runs out of memory in base_track_predictor.py at the line fmaps = self.fmap_norm(fmaps.permute(0, 1, 3, 4, 2)), where fmaps has shape torch.Size([1, 60, 128, 259, 259]),
or in vggt/heads/track_modules/blocks.py at corrs = compute_corr_level(fmap1, fmap2s, C).
Ideally, I want to track around 2000 query points through all images. Is it okay to split the forward pass of tracking into chunks of images? Something like this:
chnk = 10
tracks = []
for i in range(images.shape[1] // chnk):
    images_chnk = images[:, i*chnk:(i+1)*chnk]
    track_list, vis_score, conf_score = model.track_head(aggregated_tokens_list, images_chnk, ps_idx, query_points)
    tracks.append(track_list[-1])
tracks = torch.cat(tracks, 1)
But this also slows it down significantly. Is there a better way to do it?
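Alternatively, would it make sense to keep all frames and chunk the query points instead? A rough, untested sketch (assuming track_head accepts an arbitrary number of query points):

q_chunk = 256  # query points per forward pass
track_chunks = []
for j in range(0, query_points.shape[1], q_chunk):
    qp = query_points[:, j:j + q_chunk]
    track_list, vis_score, conf_score = model.track_head(aggregated_tokens_list, images, ps_idx, query_points=qp)
    track_chunks.append(track_list[-1])
tracks = torch.cat(track_chunks, 2)  # concatenate along the query-point dimension (adjust dim to the output layout)

Though I guess this would only help with the correlation stage, since the fmap_norm allocation depends on the number of frames rather than the number of queries.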
Hi.
I tried to reproduce this, but it seems to work correctly for me. I can dig into this further, though chunking may already give you reasonable results at this stage.
Another suggestion: would you mind deleting the unused tensors and calling torch.cuda.empty_cache() after the deletion? This frees a significant amount of memory in most cases.
Best, Jianyuan
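For example, something along these lines right before calling the track head (illustrative only; the tensor names here are placeholders for whatever intermediates your pipeline still keeps around):

    # Drop intermediate tensors you no longer need before running the track head.
    del depth_map, depth_conf, point_map, point_conf
    gc.collect()                 # free the Python references first
    torch.cuda.empty_cache()     # then return cached blocks to the CUDA allocator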
Here's some code that runs out of memory for me:
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
images = torch.rand(([60, 3, 518, 518])).to(device)
H, W = images.shape[-2:]
with torch.no_grad():
with torch.cuda.amp.autocast(dtype=dtype):
predictions = {}
images = images[None] # add batch dimension
aggregated_tokens_list, ps_idx = model.aggregator(images)
predictions["images"] = images
# Predict pose
predictions["pose_enc"] = model.camera_head(aggregated_tokens_list)[-1]
# Predict Depth Maps
depth_map, depth_conf = model.depth_head(aggregated_tokens_list, images, ps_idx)
predictions["depth"] = depth_map
predictions["depth_conf"] = depth_conf
# Predict Point Maps
point_map, point_conf = model.point_head(aggregated_tokens_list, images, ps_idx)
predictions["world_points"] = point_map # B, N, H, W, 3
predictions["world_points_conf"] = point_conf # B, N, H, W
extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images.shape[-2:])
predictions["extrinsic"] = extrinsic # (1, N, 3, 4)
predictions["intrinsic"] = intrinsic # (1, N, 3, 3)
# Convert tensors to numpy
for key in predictions.keys():
if isinstance(predictions[key], torch.Tensor):
predictions[key] = predictions[key].cpu().numpy().squeeze(0) # remove batch dimension
# Clean up
del predictions
gc.collect()
torch.cuda.empty_cache()
idx = np.random.choice(H*W, 2000, replace=False)
row_idx, col_idx = np.unravel_index(idx, (H,W))
query_points = np.stack((col_idx, row_idx), 1)
query_points = torch.from_numpy(query_points).type(dtype).to(device)
track_list, vis_score, conf_score = model.track_head(aggregated_tokens_list, images, ps_idx, query_points=query_points[None])
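In case it helps narrow down which stage blows up, something like the following can log the peak allocation around each head call (just a debugging sketch; log_peak is a throwaway helper, not part of VGGT):

    def log_peak(label):
        # Report peak CUDA allocation since the last reset, then reset the counter.
        print(f"{label}: peak {torch.cuda.max_memory_allocated() / 1024**3:.2f} GiB")
        torch.cuda.reset_peak_memory_stats()

    torch.cuda.reset_peak_memory_stats()
    aggregated_tokens_list, ps_idx = model.aggregator(images)
    log_peak("aggregator")
    depth_map, depth_conf = model.depth_head(aggregated_tokens_list, images, ps_idx)
    log_peak("depth_head")
    track_list, vis_score, conf_score = model.track_head(aggregated_tokens_list, images, ps_idx, query_points=query_points[None])
    log_peak("track_head")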
Confirming I see the same issue with similar code: the first loop iteration passes, but on the second I get an OOM:
depths = []
with torch.no_grad(), torch.amp.autocast('cuda', dtype=self.dtype):
    for i in range(0, len(frames), self.batch_size):
        # batch = frames[i:i + self.batch_size].to(self.device)
        # always include the first frame
        first_frame = frames[0][None, ...]
        batch = frames[i:i + self.batch_size - 1]
        # prepend the first frame to the batch
        batch = np.concatenate((first_frame, batch), axis=0)
        output = self.infer_batch(batch, near, far, num_denoising_steps, guidance_scale, window_size, overlap, **kwargs)
        depth = output["depth"].squeeze(0)  # [B, H, W, 1]
        depth = (depth - depth.min()) / (depth.max() - depth.min())
        depth *= 3900
        depth[depth < 1e-5] = 1e-5
        depth = 10000.0 / depth
        depth = depth.clip(near, far)
        depths.append(depth.cpu().detach())
        # clean up GPU memory
        del output
        torch.cuda.empty_cache()
        gc.collect()
depths = torch.cat(depths, dim=0)
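One small thing that may be worth trying in the loop above (a minor tweak, not verified to fix the OOM): also delete the intermediate depth tensor, and call gc.collect() before torch.cuda.empty_cache(), since empty_cache can only release cached blocks whose tensors have already been freed:

        # inside the loop, after depths.append(...):
        del output, depth
        gc.collect()              # drop unreferenced tensors first
        torch.cuda.empty_cache()  # then release the cached blocks back to CUDA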