vggt
How to avoid attention dilution and camera-head compression in the global attention pass when batch processing

For VGGT inference with a fixed batch size and overlapping frames between batches, how can I avoid attention dilution and camera-head compression in the global attention pass when batch processing? Below is my current pipeline.
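For context, the snippets below assume the frames were split into overlapping windows roughly as sketched here. The variables `n_total`, `step`, `overlap`, and `n_batches` are the names reused by the masking code further down; `image_paths` and `batch_size` are placeholders for my setup, not anything from the VGGT API:

```python
# Hypothetical chunking setup: split n_total frames into overlapping windows.
n_total = len(image_paths)    # image_paths: placeholder list of input frames
batch_size = 12               # frames per VGGT forward pass
overlap = 2                   # frames shared between consecutive windows
step = batch_size - overlap
n_batches = max(1, -(-(n_total - overlap) // step))  # ceiling division

batches = [
    list(range(i * step, min(i * step + batch_size, n_total)))
    for i in range(n_batches)
]
```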
```python
import numpy as np
import torch

from vggt.utils.geometry import unproject_depth_map_to_point_map
from vggt.utils.pose_enc import pose_encoding_to_extri_intri

# Concatenate the surviving per-frame tokens (2048-D) into one global sequence.
tokens_per_frame = [t for t in tokens_per_frame if t is not None]
aggregated_tokens = torch.cat(tokens_per_frame, dim=0).unsqueeze(0)
aggregated_tokens = global_norm_layer(aggregated_tokens)

# Rebuild positional encodings for the full concatenated sequence.
pos = None
if getattr(model.aggregator, "position_getter", None) is not None:
    try:
        total_patches = aggregated_tokens.shape[1]
        pos = model.aggregator.position_getter(1, H, W, device=aggregated_tokens.device)
        pos = pos.to(dtype)
        pos = pos.repeat(1, n_total, 1)[:, :total_patches, :]
        if torch.isnan(pos).any():
            print("NaN detected in positional encodings; using zeros")
            pos = torch.zeros_like(pos)
    except Exception as e:
        print(f"Position getter failed: {e}")
        pos = None

aggregated_tokens = aggregated_tokens.to(dtype)
original_tokens = aggregated_tokens.clone()  # Keep the 2048-D tokens for camera_head
aggregated_tokens = global_proj_in(aggregated_tokens)
```
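`tokens_per_frame`, `global_norm_layer`, and `global_proj_in` come from earlier in my script; the collection step looks roughly like the sketch below. I am assuming the aggregator returns `(aggregated_tokens_list, patch_start_idx)` with last-layer tokens of shape `[B, S, P, 2048]`, and that `images` is a preloaded `[N, 3, H, W]` tensor; overlap frames keep the tokens from the first batch that saw them:

```python
# Sketch: collect last-layer aggregator tokens per frame from each window.
tokens_per_frame = [None] * n_total
with torch.inference_mode():
    for batch_indices in batches:
        images_batch = images[batch_indices].unsqueeze(0)  # [1, S, 3, H, W]
        aggregated_tokens_list, patch_start_idx = model.aggregator(images_batch)
        last_tokens = aggregated_tokens_list[-1][0]        # [S, P, 2048]
        for j, frame_idx in enumerate(batch_indices):
            if tokens_per_frame[frame_idx] is None:
                tokens_per_frame[frame_idx] = last_tokens[j]
```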
```python
def run_block(block, tokens, pos_arg, idx):
    return block(tokens, pos=pos_arg, global_merging=idx)

with torch.inference_mode():
    for block_idx in range(model.aggregator.depth):
        if block_idx >= args.merging:
            aggregated_tokens = run_block(
                model.aggregator.global_blocks[block_idx], aggregated_tokens, pos, block_idx
            )
            max_amplification = aggregated_tokens.abs().max().item()
            if max_amplification > 10.0:  # Threshold taken from my logs
                print(
                    f"Attention dilution detected: token max amplified to "
                    f"{max_amplification:.2f} (potential instability in later blocks)"
                )
```
```python
pre_head_max = aggregated_tokens.abs().max().item()
aggregated_tokens = original_tokens  # Restore the 2048-D tokens for camera_head
post_projection_max = aggregated_tokens.abs().max().item()  # Max of the restored 2048-D tokens
if post_projection_max < pre_head_max * 0.5:  # Compression threshold
    print(
        f"Projection/camera head compression: max reduced from {pre_head_max:.2f} "
        f"to {post_projection_max:.2f} (may dilute pose_enc)"
    )
```
```python
with torch.inference_mode():
    # num_patches counts everything per frame: patch tokens plus the 5 special tokens.
    num_patches = aggregated_tokens.shape[1] // n_total
    aggregated_tokens = aggregated_tokens.view(1, n_total, num_patches, -1)  # [1, n_total, num_patches, 2048]
    predictions_global = model.camera_head([aggregated_tokens])  # Pass as a list, matching VGGT.forward
    global_pose_enc = predictions_global[-1]  # Shape: [1, 12, 9] in my runs
    global_focal_mean = global_pose_enc[..., 7:9].mean().item()  # Mean of fov_h and fov_w
    if global_focal_mean < batch_focal_mean * 0.8:  # Warn if the global FoV is much smaller
        print(
            f"FoV dilution warning: global focal mean {global_focal_mean:.4f} < 80% of "
            f"batch mean {batch_focal_mean:.4f} (expect inflated intrinsics)"
        )
    extrinsic_all, intrinsic_all = pose_encoding_to_extri_intri(
        predictions_global[-1],
        (vggt_fixed_resolution_height, vggt_fixed_resolution_width),  # image_size_hw expects (H, W)
    )
extrinsic_all = extrinsic_all[0].detach().float().cpu().numpy()
intrinsic_all = intrinsic_all[0].detach().float().cpu().numpy()
```
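`batch_focal_mean` is the FoV baseline from the per-batch passes. I compute it roughly as below, assuming `pose_enc_per_batch` holds the camera-head output (one `[1, S, 9]` pose encoding per chunk); the name is mine:

```python
# Baseline FoV statistic from the per-batch camera-head outputs.
batch_focal_means = [
    enc[..., 7:9].mean().item() for enc in pose_enc_per_batch if enc is not None
]
batch_focal_mean = float(np.mean(batch_focal_means))
```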
```python
# Drop the overlap frames duplicated at the start of every batch after the first.
frame_mask = np.ones(n_total, dtype=bool)
for i in range(1, n_batches):
    start = i * step
    end = min(start + overlap, n_total)
    if start < n_total:
        frame_mask[start:end] = False
kept_frame_indices = np.nonzero(frame_mask)[0].tolist()
extrinsic_kept = extrinsic_all[kept_frame_indices]
intrinsic_kept = intrinsic_all[kept_frame_indices]

# Validate intrinsics; intrinsic_kept is a NumPy array at this point, so the
# fallbacks must be NumPy as well.
for i in range(len(intrinsic_kept)):
    fx = intrinsic_kept[i, 0, 0].item()
    fy = intrinsic_kept[i, 1, 1].item()
    if not (np.isfinite(fx) and np.isfinite(fy)):
        print(f" - Frame {i} has invalid intrinsics (fx={fx}, fy={fy}), replacing with defaults")
        if 'intrinsic' in locals() and intrinsic is not None:
            intrinsic_kept[i] = intrinsic[i].detach().float().cpu().numpy()
        else:
            intrinsic_kept[i] = np.array(
                [[1000.0, 0.0, 196.0], [0.0, 1000.0, 259.0], [0.0, 0.0, 1.0]],
                dtype=np.float32,
            )
```
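A quick consistency check on the masking, under the chunking sketch at the top (this assumes no window is clipped at the end of the sequence):

```python
# With no clipping, every batch after the first drops exactly `overlap` frames.
expected_kept = n_total - (n_batches - 1) * overlap
if len(kept_frame_indices) != expected_kept:
    print(
        f"Kept {len(kept_frame_indices)} frames, expected {expected_kept}; "
        f"check step/overlap against n_total"
    )
```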
```python
# Assemble the per-frame depth maps for the kept frames; NaN-fill missing ones.
depth_maps = []
for frame_idx in kept_frame_indices:
    depth = depth_tensors_per_frame[frame_idx]
    if depth is None:
        depth = torch.full((H, W), float('nan'), device='cpu')
    else:
        depth = depth.to(torch.float32).cpu()
        if depth.ndim == 3 and depth.shape[-1] == 1:
            depth = depth.squeeze(-1)
    depth_maps.append(depth)

depth_maps = torch.stack(depth_maps, dim=0)  # [n_kept, H, W]
depth_maps = depth_maps.unsqueeze(-1)        # [n_kept, H, W, 1]
points3d_per_kept_frame = unproject_depth_map_to_point_map(
    depth_maps.numpy(),
    extrinsic_kept,
    intrinsic_kept,
)
```
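Downstream, I flatten the per-frame point maps and drop the NaN placeholders before merging (plain NumPy, nothing VGGT-specific):

```python
# Merge per-frame point maps into one cloud, dropping NaN placeholder points.
points = points3d_per_kept_frame.reshape(-1, 3)
valid = np.isfinite(points).all(axis=1)
point_cloud = points[valid]
print(f"Merged {point_cloud.shape[0]} valid 3D points from {len(kept_frame_indices)} frames")
```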