Remove Duplicate Declaration of Loss Function
I noticed there is an unnecessary duplicate declaration of loss_fct here.
Relevant code:
loss_fct = torch.nn.CrossEntropyLoss(reduction="none")  # first declaration (never used)
shift_logits = response.logits[..., :-1, :].contiguous()
shift_labels = input_ids[..., past_length + 1 : end].contiguous()
# Flatten the tokens
active = (attention_mask[:, past_length:end] == 1)[..., :-1].view(-1)
active_logits = shift_logits.view(-1, shift_logits.size(-1))[active]
active_labels = shift_labels.view(-1)[active]
loss_fct = torch.nn.CrossEntropyLoss(reduction="none")  # duplicate declaration; this is the one actually used below
loss = loss_fct(active_logits, active_labels)
As you can see, loss_fct is not used before it is declared a second time, so the first declaration can safely be removed without changing behavior.
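For reference, a minimal sketch of the proposed change, keeping the surrounding logic exactly as it is in the snippet above and only dropping the redundant first declaration:

shift_logits = response.logits[..., :-1, :].contiguous()
shift_labels = input_ids[..., past_length + 1 : end].contiguous()
# Flatten the tokens
active = (attention_mask[:, past_length:end] == 1)[..., :-1].view(-1)
active_logits = shift_logits.view(-1, shift_logits.size(-1))[active]
active_labels = shift_labels.view(-1)[active]
loss_fct = torch.nn.CrossEntropyLoss(reduction="none")  # single declaration, right before use
loss = loss_fct(active_logits, active_labels)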