
The MLM loss computation in the function _get_batch_loss_bert seems wrong in the d2l PyTorch code

Open lyconghk opened this issue 1 year ago • 2 comments

In my opinion, the BERT pretraining batch loss in the function _get_batch_loss_bert is not correct. The details are as follows:

The CrossEntropyLoss is initialized with the default reduction='mean':

```python
loss = nn.CrossEntropyLoss()
```

In the function _get_batch_loss_bert, mlm_l and nsp_l are computed with this same loss instance:

```python
mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) *\
        mlm_weights_X.reshape(-1, 1)
```

Since the reduction is 'mean', the result of loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) is already a scalar tensor, so the elementwise product with mlm_weights_X cannot mask out the padded prediction positions; the weighting no longer affects which positions contribute to the MLM loss.
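A minimal reproduction of the problem with toy shapes (not the book's actual data pipeline; the tensor sizes here are made up for illustration):

```python
import torch
from torch import nn

loss = nn.CrossEntropyLoss()  # default reduction='mean'

vocab_size = 5
mlm_Y_hat = torch.randn(2, 3, vocab_size)      # (batch, num_pred_positions, vocab)
mlm_Y = torch.randint(0, vocab_size, (2, 3))   # target token ids
mlm_weights_X = torch.tensor([[1., 1., 0.],
                              [1., 0., 0.]])   # 0 marks padded prediction slots

l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1))
print(l.shape)       # torch.Size([]) -- a scalar already averaged over ALL
                     # positions, including the padded ones
masked = l * mlm_weights_X.reshape(-1, 1)
print(masked.shape)  # torch.Size([6, 1]), but every entry is just l or 0,
                     # so no per-position masking has actually happened
```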

lyconghk avatar Jan 06 '24 11:01 lyconghk

Agree with you @lyconghk. Have you come up with a better way to apply mlm_weights_X in the mlm_l calculation?

The weight parameter of PyTorch's CrossEntropyLoss does not seem to support mlm_weights_X in the way the MXNet version does. I guess that is why the PyTorch version of _get_batch_loss_bert calculates mlm_l this way: it tries to reduce the impact of padded tokens on mlm_l, but it does not use mlm_weights_X correctly.
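For context, the weight argument of nn.CrossEntropyLoss is a per-class rescaling tensor of shape (num_classes,), not a per-example or per-token mask, so it cannot carry mlm_weights_X, which marks which prediction positions are real versus padding. A small illustration (the values here are placeholders):

```python
import torch
from torch import nn

vocab_size = 5
# One weight per vocabulary id (i.e. per class), applied to every example
# that has that class as its target -- not per prediction position.
class_weights = torch.ones(vocab_size)
loss = nn.CrossEntropyLoss(weight=class_weights)
```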

gab-chen avatar Jan 28 '24 20:01 gab-chen

How about just using torch.nn.functional to compute the two cross-entropy losses for MLM and NSP, and removing the loss input parameter from the function _get_batch_loss_bert? For example:

```python
from torch.nn import functional as F

mlm_l = F.cross_entropy(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1),
                        reduction='none')
nsp_l = F.cross_entropy(nsp_Y_hat, nsp_Y, reduction='mean')
```
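For completeness, a sketch of how the whole helper might look with this change, assuming the model's forward call and argument names match the book's PyTorch BERT pretraining code (the exact signature may differ):

```python
import torch
from torch.nn import functional as F

def _get_batch_loss_bert(net, vocab_size, tokens_X, segments_X, valid_lens_x,
                         pred_positions_X, mlm_weights_X, mlm_Y, nsp_Y):
    # Forward pass (assumed to return encodings, MLM logits, NSP logits)
    _, mlm_Y_hat, nsp_Y_hat = net(tokens_X, segments_X,
                                  valid_lens_x.reshape(-1), pred_positions_X)
    # Unreduced per-position MLM loss, then mask out padded prediction slots
    # before averaging over the valid ones
    mlm_l = F.cross_entropy(mlm_Y_hat.reshape(-1, vocab_size),
                            mlm_Y.reshape(-1), reduction='none')
    mlm_l = (mlm_l * mlm_weights_X.reshape(-1)).sum() / \
            (mlm_weights_X.sum() + 1e-8)  # small epsilon avoids division by zero
    # Next sentence prediction loss, a plain mean over the batch
    nsp_l = F.cross_entropy(nsp_Y_hat, nsp_Y, reduction='mean')
    l = mlm_l + nsp_l
    return mlm_l, nsp_l, l
```

With reduction='none' the MLM loss keeps one entry per prediction position, so multiplying by mlm_weights_X now zeroes out the padded positions as intended before the average is taken.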

lyconghk avatar Jan 29 '24 02:01 lyconghk