BERT-pytorch

Why specify `ignore_index=0` in the NLLLoss function in BERTTrainer?

Jasmine969 opened this issue 2 years ago • 1 comment

trainer/pretrain.py

class BERTTrainer:
    def __init__(self, ...):
        ... 
        # Using Negative Log Likelihood Loss function for predicting the masked_token
        self.criterion = nn.NLLLoss(ignore_index=0)
        ...

I cannot understand why `ignore_index=0` is specified when calculating NLLLoss. In the NSP task, if the ground truth of is_next is False (label 0) but BERT predicts True, that sample is ignored and contributes nothing to the NLLLoss: its loss is effectively 0, and the batch loss becomes nan if every label in the batch is 0. So what is the purpose of `ignore_index=0`?
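A minimal pure-Python sketch of the effect described above. It re-implements the mean-reduction semantics documented for `nn.NLLLoss` (targets equal to `ignore_index` are dropped from both the sum and the count) rather than calling PyTorch; the helper name `nll_loss_mean` is hypothetical. Class 0 is "not next" and class 1 is "is next".

```python
import math

def nll_loss_mean(log_probs, targets, ignore_index=-100):
    """Mean negative log-likelihood over the non-ignored targets.

    Sketch of nn.NLLLoss(reduction="mean") semantics: a sample whose target
    equals ignore_index is skipped entirely, so if every sample is ignored
    the mean is 0/0, returned here as nan.
    """
    total, count = 0.0, 0
    for row, t in zip(log_probs, targets):
        if t == ignore_index:
            continue  # ignored sample: adds nothing to the sum or the count
        total += -row[t]
        count += 1
    return total / count if count else float("nan")

# The model predicts "is next" with probability 0.9 for both samples.
pred = [math.log(0.1), math.log(0.9)]
log_probs = [pred, pred]

# With ignore_index=0, the misclassified label-0 sample is not penalized:
print(nll_loss_mean(log_probs, [0, 1], ignore_index=0))  # -log(0.9), about 0.105
# If every NSP label in the batch is 0, nothing is left to average:
print(nll_loss_mean(log_probs, [0, 0], ignore_index=0))  # nan
# Without ignore_index, the wrong prediction is penalized as expected:
print(nll_loss_mean(log_probs, [0, 0]))                  # -log(0.1), about 2.303
```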

====================

Well, I've found that `ignore_index=0` is useful for the MLM task, but I still don't agree that the NSP task should share the same NLLLoss with MLM.
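One way to keep the MLM behaviour without breaking NSP is to use one criterion per task. The pure-Python sketch below assumes that MLM labels use 0 for positions that were not masked (and for padding), so only masked positions are scored, while NSP label 0 ("not next") is treated as a real class. In PyTorch terms this roughly corresponds to `nn.NLLLoss(ignore_index=0)` for MLM and `nn.NLLLoss()` for NSP; `pretrain_loss` and `nll_loss_mean` are hypothetical helpers, not code from this repository, and MLM positions are flattened into one list for simplicity.

```python
import math

def nll_loss_mean(log_probs, targets, ignore_index=-100):
    # Mean NLL over targets != ignore_index (nan if all are ignored)
    total, count = 0.0, 0
    for row, t in zip(log_probs, targets):
        if t != ignore_index:
            total += -row[t]
            count += 1
    return total / count if count else float("nan")

def pretrain_loss(mlm_log_probs, mlm_labels, nsp_log_probs, nsp_labels):
    """Hypothetical combined pre-training loss with a separate criterion per task."""
    # MLM: label 0 is assumed to mark unmasked/padded positions, so skip them
    mlm_loss = nll_loss_mean(mlm_log_probs, mlm_labels, ignore_index=0)
    # NSP: label 0 ("not next") is a genuine class and must be scored
    nsp_loss = nll_loss_mean(nsp_log_probs, nsp_labels)
    return mlm_loss + nsp_loss

# Toy example: 3-token vocabulary, only position 1 was masked (true token id 2).
uniform = [math.log(1 / 3)] * 3
mlm_log_probs = [uniform, uniform, uniform]
nsp_log_probs = [[math.log(0.1), math.log(0.9)]]  # wrongly predicts "is next"

print(pretrain_loss(mlm_log_probs, [0, 2, 0], nsp_log_probs, [0]))  # log(3) - log(0.1), about 3.401
```

With a single shared `ignore_index=0` criterion, the NSP term in this example would instead be nan, because its only label is 0.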

Jasmine969 • Jul 07 '22 02:07

See #32: change `self.criterion = nn.NLLLoss(ignore_index=0)` to `self.criterion = nn.NLLLoss()`.

MingchangLi • Jan 10 '23 16:01