Bias in Multiblock Mask Collator
I've been sampling the multi-block mask collator and plotting the resulting masks to understand how they look, and I believe I've found a bias that may have a significant impact on the training of any model using this class.

The patterns below appear consistently for batch sizes >= 128 and show that many patches near the centre of the image are never covered by enc_masks. Note that this behaviour only occurs for allow_overlap=False.

Here are four examples sampled with the code below, with no cherry-picking. Each image is a batch of 128 enc masks, generated using the default arguments for the multi-block mask collator. Each pixel represents a patch and is white iff that patch is included in at least one mask in its batch. Repro code:
```python
import torch
import matplotlib.pyplot as plt

# In the ijepa repo the collator lives in src/masks/multiblock.py
from src.masks.multiblock import MaskCollator

collator = MaskCollator()  # default arguments

# Batch of 128 images; the bias shows up for batch sizes >= 128.
batch = [torch.randn(3, 224, 224) for _ in range(128)]
batch, enc_masks, pred_masks = collator(batch)

def display_mask(mask):
    # mask is a tensor of patch indices from 0 to 195;
    # it can be an individual mask or a whole batch of masks.
    # Display a 14x14 grid where each cell is on iff the
    # corresponding patch index appears anywhere in the mask.
    grid = torch.zeros(14, 14)
    for i in range(196):
        grid[i // 14, i % 14] = 1 if i in mask else 0
    plt.imshow(grid, cmap='gray')
    plt.show()

# Change the second index from ':' to an integer to visualise individual masks.
display_mask(enc_masks[0][:])
```
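To quantify the bias rather than just visualise the union of one batch, here is a minimal sketch (assuming the same src.masks.multiblock import path and, as in the repro above, that enc_masks[0] is a [batch_size, num_keep] tensor of patch indices) that accumulates per-patch inclusion frequencies over many collator calls:

```python
import torch
import matplotlib.pyplot as plt

from src.masks.multiblock import MaskCollator  # assumed import path, as above

collator = MaskCollator()  # defaults, allow_overlap=False
counts = torch.zeros(196)
n_masks = 0

# Count how often each of the 196 patches appears in an enc mask.
for _ in range(100):
    images = [torch.randn(3, 224, 224) for _ in range(128)]
    _, enc_masks, _ = collator(images)
    for mask in enc_masks[0]:  # one index tensor per image in the batch
        counts[mask] += 1      # indices within a single mask are unique
        n_masks += 1

# Per-patch inclusion frequency; central patches should be visibly darker.
plt.imshow((counts / n_masks).reshape(14, 14), cmap='gray')
plt.colorbar()
plt.show()
```

If the observation above is right, re-running the same loop with MaskCollator(allow_overlap=True) should make the central hole disappear.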