LLMLingua

Remove Duplicate Declaration of Loss Function

Open • Speuce opened this issue 2 years ago • 0 comments

I noticed an unnecessary duplicate declaration of `loss_fct` here.

Relevant code:

        # First declaration of loss_fct (never used)
        loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
        # Shift logits/labels so that position t predicts token t + 1
        shift_logits = response.logits[..., :-1, :].contiguous()
        shift_labels = input_ids[..., past_length + 1 : end].contiguous()
        # Flatten the tokens, keeping only non-padded positions
        active = (attention_mask[:, past_length:end] == 1)[..., :-1].view(-1)
        active_logits = shift_logits.view(-1, shift_logits.size(-1))[active]
        active_labels = shift_labels.view(-1)[active]
        # Second, identical declaration; only this one is actually used
        loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
        loss = loss_fct(active_logits, active_labels)
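
For context, `CrossEntropyLoss(reduction="none")` returns one loss value per token instead of a single averaged scalar, which is what the surrounding code relies on. A minimal, self-contained illustration of that behavior:

        import torch

        # reduction="none" yields per-token losses rather than a mean
        loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
        logits = torch.randn(5, 10)            # 5 tokens, vocabulary of 10
        labels = torch.randint(0, 10, (5,))    # one target id per token
        print(loss_fct(logits, labels).shape)  # torch.Size([5])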

As the excerpt shows, the first `loss_fct` is never used before it is overwritten by the second declaration, so it is safe to remove the first one.
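
For reference, a minimal sketch of the excerpt with the redundant first assignment dropped; behavior is unchanged, since only the second `loss_fct` was ever used:

        shift_logits = response.logits[..., :-1, :].contiguous()
        shift_labels = input_ids[..., past_length + 1 : end].contiguous()
        # Flatten the tokens, keeping only non-padded positions
        active = (attention_mask[:, past_length:end] == 1)[..., :-1].view(-1)
        active_logits = shift_logits.view(-1, shift_logits.size(-1))[active]
        active_labels = shift_labels.view(-1)[active]
        # Single declaration, placed next to its only use
        loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
        loss = loss_fct(active_logits, active_labels)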

Speuce • Dec 28 '23 20:12