MaskGIT-pytorch
Isn't loss only supposed to be calculated on masked tokens?
In the training loop we have:
imgs = imgs.to(device=args.device)
logits, target = self.model(imgs)
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
loss.backward()
However, the forward pass of the transformer that produces those outputs is:
_, z_indices = self.encode_to_z(x)
...
a_indices = mask * z_indices + (~mask) * masked_indices
a_indices = torch.cat((sos_tokens, a_indices), dim=1)
target = torch.cat((sos_tokens, z_indices), dim=1)
logits = self.transformer(a_indices)
return logits, target
which means the returned target consists of the original, unmasked image tokens.
The MaskGIT paper seems to suggest that the loss is calculated only on the masked tokens.
I've attempted both strategies for a simple MaskGIT on CIFAR10, but the generation quality still seems bad. There seem to be tricks in the authors' training scheme that aren't spelled out in the paper.
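For reference, here is a minimal sketch of what a masked-only loss could look like. It assumes a hypothetical forward pass that also returns the boolean mask it sampled (True marking positions replaced by the mask token), which the repo's current forward does not return:

import torch.nn.functional as F

# Hypothetical: forward also returns the sampled boolean mask,
# with True marking positions replaced by the mask token.
logits, target, masked = self.model(imgs)  # logits: (B, L, V), target: (B, L), masked: (B, L)
# Cross-entropy only over the positions the model actually had to predict.
loss = F.cross_entropy(logits[masked], target[masked])
loss.backward()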
I have the same issue. Why is the loss calculated on all tokens?
@Lamikins I believe the training issues come from an error in the masking formula. I've amended the error: https://github.com/dome272/MaskGIT-pytorch/pull/16.
@EmaadKhwaja return logits[~mask], target[~mask] seems a bit problematic; we should calculate the loss on the masked tokens, i.e. return logits[mask], target[mask].
@xuesongnie it's because the calculated mask is applied to the wrong values. The other option would be to do r = math.floor((1 - self.gamma(np.random.uniform())) * z_indices.shape[1]), but I don't like that because it differs from how the formula appears in the paper.
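For context, here is a rough sketch of the masking step as it reads in the paper (my own sketch under that reading, not the repo's code), where gamma is the mask-ratio schedule and the sampled positions are the ones replaced by the mask token:

import math
import numpy as np
import torch

def sample_mask(z_indices, gamma, mask_token_id):
    # Sketch of a paper-style masking step (an assumption, not this repo's code).
    # gamma(u) is the mask-ratio schedule, e.g. gamma(u) = cos(u * pi / 2),
    # so r = ceil(gamma(u) * N) tokens get masked.
    N = z_indices.shape[1]
    r = math.ceil(gamma(np.random.uniform()) * N)
    sample = torch.rand(z_indices.shape, device=z_indices.device).topk(r, dim=1).indices
    masked = torch.zeros_like(z_indices, dtype=torch.bool)
    masked.scatter_(dim=1, index=sample, value=True)  # True = this position is masked
    a_indices = torch.where(masked, torch.full_like(z_indices, mask_token_id), z_indices)
    return a_indices, masked

With this convention, logits[masked] and target[masked] are exactly the positions a masked-only loss would cover.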
Hi, bro. I find that performance is poor after modifying to return logits[mask], target[mask]. It is weird. I guess the embedding layer also needs to be trained on the corresponding unmasked tokens.
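One possible middle ground (my own sketch, not something from the paper or this repo): compute the per-token loss everywhere but down-weight the visible positions, so the embeddings of unmasked tokens still receive gradient while the masked positions dominate the objective. This again assumes the boolean masked tensor is available from the forward pass:

import torch
import torch.nn.functional as F

# Weighted loss sketch (an assumption, not the paper's recipe):
# masked positions get full weight, visible positions a small one.
per_token = F.cross_entropy(
    logits.reshape(-1, logits.size(-1)),
    target.reshape(-1),
    reduction="none",
).reshape(target.shape)                    # (B, L)
weights = torch.full_like(per_token, 0.1)  # visible tokens: small weight
weights[masked] = 1.0                      # masked tokens: full weight
loss = (per_token * weights).sum() / weights.sum()
loss.backward()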