bert_classfication
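After switching to new training data, the number of classes changed and loading the checkpoint now fails. The checkpoint was trained with 18 classes; the new data has 2.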
RuntimeError: Error(s) in loading state_dict for RobertaForSequenceClassification:
size mismatch for classifier.out_proj.weight: copying a param with shape torch.Size([18, 768]) from checkpoint, the shape in current model is torch.Size([2, 768]).
size mismatch for classifier.out_proj.bias: copying a param with shape torch.Size([18]) from checkpoint, the shape in current model is torch.Size([2]).
You may consider adding ignore_mismatched_sizes=True in the model from_pretrained method.
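A minimal sketch of what the error message suggests, assuming the checkpoint is loaded through the Hugging Face transformers `from_pretrained` API. So that it runs without downloading anything, it builds a tiny, randomly initialised 18-class RoBERTa model as a stand-in for the real checkpoint. The config sizes and the temporary directory are illustrative assumptions, not values from this repo; in practice you would pass your own checkpoint path.

```python
import tempfile

from transformers import RobertaConfig, RobertaForSequenceClassification

# Tiny stand-in for the 18-class checkpoint (hypothetical sizes, random weights).
config = RobertaConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=1,
    num_attention_heads=2,
    intermediate_size=64,
    max_position_embeddings=64,
    num_labels=18,
)
old_model = RobertaForSequenceClassification(config)

with tempfile.TemporaryDirectory() as checkpoint_dir:
    # Save the 18-class model so there is a checkpoint to reload.
    old_model.save_pretrained(checkpoint_dir)

    # Reload with 2 classes. Without ignore_mismatched_sizes=True this
    # raises the size-mismatch RuntimeError shown above.
    new_model = RobertaForSequenceClassification.from_pretrained(
        checkpoint_dir,
        num_labels=2,
        ignore_mismatched_sizes=True,
    )

# The output layer now matches the new number of classes.
out_proj_weight_shape = tuple(new_model.classifier.out_proj.weight.shape)
out_proj_bias_shape = tuple(new_model.classifier.out_proj.bias.shape)
print(out_proj_weight_shape, out_proj_bias_shape)
```

With this flag the mismatched `classifier.out_proj` weights are not copied from the checkpoint and are newly initialised instead, so the model still has to be fine-tuned on the 2-class data before its predictions mean anything. The encoder weights are loaded as usual.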