MaskGIT-pytorch

Isn't loss only supposed to be calculated on masked tokens?

Open · EmaadKhwaja opened this issue 2 years ago · 6 comments

In the training loop we have:

imgs = imgs.to(device=args.device)
logits, target = self.model(imgs)
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
loss.backward()

However, the output of the transformer is:

_, z_indices = self.encode_to_z(x)
...
a_indices = mask * z_indices + (~mask) * masked_indices

a_indices = torch.cat((sos_tokens, a_indices), dim=1)

target = torch.cat((sos_tokens, z_indices), dim=1)

logits = self.transformer(a_indices)

return logits, target

which means the returned target contains the original, unmasked image tokens at every position.

The MaskGIT paper seems to suggest that the loss is only calculated on the masked tokens:

[Screenshot of the masked cross-entropy loss from the MaskGIT paper: the negative log-likelihood is summed only over the masked positions.]
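
For what it's worth, a minimal sketch of what a masked-only loss could look like (masked_positions here is a hypothetical boolean tensor, True where a token was replaced by the mask token; the sos position would be marked False):

import torch.nn.functional as F

def masked_cross_entropy(logits, target, masked_positions):
    # logits: (B, N, vocab_size), target: (B, N), masked_positions: (B, N) bool
    # boolean indexing flattens to (num_masked, vocab_size) and (num_masked,)
    return F.cross_entropy(logits[masked_positions], target[masked_positions])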

EmaadKhwaja · Nov 08 '22 23:11

I've attempted both strategies for a simple MaskGIT on CIFAR-10, but the generation quality still seems bad either way. I suspect there are tricks in the training scheme that the authors don't spell out in the paper.

darius-lam · Jan 05 '23 04:01

I have the same question. Why is the loss calculated on all tokens?

xuesongnie · Aug 30 '23 04:08

@Lamikins I believe the training issues come from an error in the masking formula. I've amended the error: https://github.com/dome272/MaskGIT-pytorch/pull/16.

@xuesongnie

EmaadKhwaja · Sep 03 '23 15:09

@EmaadKhwaja return logits[~mask], target[~mask] seems a bit problematic; we should calculate the loss on the masked tokens, i.e. return logits[mask], target[mask].

xuesongnie · Sep 03 '23 16:09

@xuesongnie It's because the computed mask is applied to the wrong values. The other option would be to do r = math.floor((1 - self.gamma(np.random.uniform())) * z_indices.shape[1]), but I don't like that because it's different from how the formula appears in the paper.
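
To spell out the two conventions (just an illustration; gamma is the paper's mask-scheduling function, and the names are mine):

import math
import numpy as np

def mask_counts(gamma, n_tokens):
    u = np.random.uniform()
    # paper convention: gamma(u) is the fraction of tokens that get the mask token
    n_masked = math.ceil(gamma(u) * n_tokens)
    # the remaining tokens keep their original value;
    # this equals math.floor((1 - gamma(u)) * n_tokens)
    n_kept = n_tokens - n_masked
    return n_masked, n_kept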

EmaadKhwaja · Sep 03 '23 16:09

Hi, bro. I'm seeing poor performance after modifying it to return logits[mask], target[mask], which is weird. I guess the embedding layer also needs to be trained on the corresponding unmasked tokens.
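
One way to sanity-check that guess (purely a sketch, not something from the paper or this repo) would be to keep the masked-token loss as the main term and add a down-weighted term over the unmasked positions:

import torch.nn.functional as F

def mvtm_loss(logits, target, masked_positions, unmasked_weight=0.1):
    # logits: (B, N, vocab_size), target: (B, N), masked_positions: (B, N) bool
    # assumes both the masked and unmasked subsets are non-empty
    masked_loss = F.cross_entropy(logits[masked_positions], target[masked_positions])
    unmasked_loss = F.cross_entropy(logits[~masked_positions], target[~masked_positions])
    return masked_loss + unmasked_weight * unmasked_loss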

xuesongnie · Sep 16 '23 15:09