
Confusion about loss calculation

Open · distillation-dcf opened this issue on May 12, 2023 · 0 comments

Hi!

In the forward() function of model.py, the text loss and image loss are computed as follows:

```python
labels = torch.cat((text[:, 1:], image_input_ids), dim=1).contiguous().long()  # shape: (bs, 127+1024=1151)
loss_text = F.cross_entropy(
    text_logits,
    labels[:, :self.text_seq_length])  # shape: (bs, 128)
loss_img = F.cross_entropy(
    image_logits,
    labels[:, self.text_seq_length:])  # shape: (bs, 1023)
```

Here, text[:, 1:] removes the label for the first [BOS] text token, leaving only 128 - 1 = 127 text token labels in labels. But the CE loss then pairs text logits with seq_len = 128 against labels[:, :self.text_seq_length], which has shape (bs, 128). So I suspect the very first image token (the one immediately following all text tokens) is pulled into the text loss computation by mistake.
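To make the suspected off-by-one concrete, here is a minimal, self-contained sketch with synthetic tensors; bs, text_seq_length, and image_seq_length are assumed values chosen to match the shapes quoted above, not real model inputs:

```python
import torch

bs, text_seq_length, image_seq_length = 2, 128, 1024

text = torch.randint(0, 100, (bs, text_seq_length))            # position 0 would be [BOS]
image_input_ids = torch.randint(100, 200, (bs, image_seq_length))

# Dropping the [BOS] label leaves 127 text labels, so labels has 127 + 1024 = 1151 columns.
labels = torch.cat((text[:, 1:], image_input_ids), dim=1).contiguous().long()
assert labels.shape == (bs, 1151)

# Slicing the first text_seq_length = 128 columns picks up 127 text labels
# plus the FIRST image label -- the behavior described above.
text_labels = labels[:, :text_seq_length]
assert text_labels.shape == (bs, 128)
assert torch.equal(text_labels[:, -1], image_input_ids[:, 0])  # last "text" label is an image token
```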

Am I understanding the code correctly? Does the text token length used in the CE loss calculation affect the training process?

distillation-dcf · May 12, 2023, 08:05