SwissArmyTransformer

A question: how should the loss be written when using mp_size=2?

Open kunden0612 opened this issue 1 year ago • 1 comment

```python
from torch.nn import CrossEntropyLoss

logits, *mems = model(inputs_ids, position_ids, attention_mask)
# print(logits.shape)
loss_func = CrossEntropyLoss(ignore_index=-100)
loss = loss_func(logits.view(-1, logits.size(-1)).float(), labels.view(-1))
```

This is how I wrote the loss computation, but it fails with the following error:

```
/opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/native/cuda/Loss.cu:242: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [15,0,0] Assertion `t >= 0 && t < n_classes` failed.
```

kunden0612 · Aug 24 '23, 02:08
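
With mp_size=2 the vocabulary projection is typically column-parallel, so each model-parallel rank only holds a slice of the vocabulary logits; a plain CrossEntropyLoss then sees label ids outside the local slice, which matches the `t >= 0 && t < n_classes` assertion. Below is a minimal sketch of the usual fix, assuming SAT exposes a Megatron-style vocab_parallel_cross_entropy under its mpu module (the exact import path and helper name are assumptions, not confirmed from this issue):

```python
from sat import mpu  # assumption: SAT re-exports Megatron-style mpu helpers here

# Forward pass is unchanged.
logits, *mems = model(inputs_ids, position_ids, attention_mask)

# Under model parallelism each rank holds only vocab_size / mp_size logits,
# so a regular CrossEntropyLoss cannot be applied to them directly. The
# parallel cross entropy takes the vocab-sliced logits plus full-vocab label
# ids and returns a per-token loss; ignored positions are masked by hand
# because it has no ignore_index argument.
loss_mask = (labels != -100).float()
safe_labels = labels.clone()
safe_labels[labels == -100] = 0  # any valid id; masked out below

per_token_loss = mpu.vocab_parallel_cross_entropy(
    logits.float().contiguous(), safe_labels)
loss = (per_token_loss * loss_mask).sum() / loss_mask.sum()
```

If -100 never appears in the labels, the masking lines can be dropped and the labels passed directly.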