Confusion about loss calculation
Hi!

In the `forward()` function of `model.py`, the text loss and the image loss are computed by:
```python
labels = torch.cat((text[:, 1:], image_input_ids), dim=1).contiguous().long()  # shape: (bs, 127+1024=1151)
loss_text = F.cross_entropy(
    text_logits,
    labels[:, :self.text_seq_length])  # shape: (bs, 128)
loss_img = F.cross_entropy(
    image_logits,
    labels[:, self.text_seq_length:])  # shape: (bs, 1023)
```
Here `text[:, 1:]` should drop the label for the first `[BOS]` text token, so only 128-1=127 text-token labels remain in `labels`. But the text loss is then computed from text logits with `seq_len=128` against `labels[:, :self.text_seq_length]`, which has shape `(bs, 128)`. So I suspect that the very first image token (the one right after all the text tokens) is pulled into the text-loss computation by mistake.
Am I understanding the code correctly? Does the text-token length used in the CE loss calculation affect the training process?