FlagAI

RoBERTa-base-ch-ner model loading error

Open monster476 opened this issue 2 years ago • 3 comments

Description

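I want to use the RoBERTa-base-ch-ner model for a task on my own dataset and train it from scratch. I only need five labels, so I changed the targets to target = ["O","B-SKI", "I-SKI", "B-CER", "I-CER"], but I get a dimension mismatch error, because the original model has 7 classes and mine now has 5. Is it mandatory to load the pretrained model? And how should I handle the case where I have more than 7 classes?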

Alternatives

No response

monster476 avatar May 23 '23 14:05 monster476

For a SequenceClassification task, change class_num.

BAAI-OpenPlatform avatar May 26 '23 14:05 BAAI-OpenPlatform

The code is as follows:

target = ["O","B-SKI", "I-SKI", "B-CER", "I-CER"]

save_dir = "./checkpoints_ner/"

auto_loader = AutoLoader(task_name,
                         model_name="RoBERTa-base-ch",
                         model_dir=model_dir,
                         class_num=len(target))
model = auto_loader.get_model()
tokenizer = auto_loader.get_tokenizer()

trainer = Trainer(env_type="pytorch",
                  experiment_name="roberta_ner",
                  batch_size=4,
                  gradient_accumulation_steps=1,
                  lr=2e-5,
                  weight_decay=1e-3,
                  epochs=100,
                  log_interval=10,
                  eval_interval=100,
                  load_dir=None,
                  pytorch_device=device,
                  save_dir=save_dir,
                  save_interval=100)

I have already changed the number of classes, but I still get an error. How can I fix it?

Traceback (most recent call last):
  File "main.py", line 26, in <module>
    auto_loader = AutoLoader(task_name,
  File "/root/miniconda3/lib/python3.8/site-packages/flagai/auto_model/auto_loader.py", line 205, in __init__
    self.model = getattr(LazyImport(self.model_name[0]),
  File "/root/miniconda3/lib/python3.8/site-packages/flagai/model/base_model.py", line 151, in from_pretrain
    return load_local(checkpoint_path, only_download_config=only_download_config)
  File "/root/miniconda3/lib/python3.8/site-packages/flagai/model/base_model.py", line 87, in load_local
    model.load_weights(checkpoint_path)
  File "/root/miniconda3/lib/python3.8/site-packages/flagai/model/bert_model.py", line 546, in load_weights
    load_extend_layer_weight(
  File "/root/miniconda3/lib/python3.8/site-packages/flagai/model/bert_model.py", line 333, in load_extend_layer_weight
    self.load_state_dict(checkpoints_save, strict=False)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for BertForSequenceLabeling:
    size mismatch for final_dense.weight: copying a param with shape torch.Size([7, 768]) from checkpoint, the shape in current model is torch.Size([5, 768]).
    size mismatch for final_dense.bias: copying a param with shape torch.Size([7]) from checkpoint, the shape in current model is torch.Size([5]).

monster476 avatar May 26 '23 15:05 monster476

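class_num does not match the checkpoint. If you want to change class_num, you need to load only the other parts of the checkpoint and modify the task head yourself.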
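
A minimal sketch of what "load the other parts" could look like, not FlagAI's own API: filter the checkpoint so that any entry whose shape differs from the model is dropped before loading. Note that in PyTorch, load_state_dict(..., strict=False) only tolerates missing or unexpected keys; it still raises on a size mismatch, which is why the traceback above appears even with strict=False. The helper name drop_mismatched, the FakeTensor stand-in, and the "encoder.weight" parameter name are all hypothetical; only final_dense.weight / final_dense.bias and their shapes come from the traceback.

```python
from collections import namedtuple

# Stand-in for a tensor: only the .shape attribute matters for filtering.
# With PyTorch, the values would be real tensors from the checkpoint file
# and from model.state_dict().
FakeTensor = namedtuple("FakeTensor", ["shape"])


def drop_mismatched(checkpoint_state, model_state):
    """Keep only checkpoint entries whose name and shape match the model."""
    kept, skipped = {}, []
    for name, value in checkpoint_state.items():
        target = model_state.get(name)
        if target is not None and tuple(target.shape) == tuple(value.shape):
            kept[name] = value
        else:
            skipped.append(name)
    return kept, skipped


# Head shapes taken from the traceback: 7-class checkpoint, 5-class model.
checkpoint_state = {
    "encoder.weight": FakeTensor((768, 768)),  # hypothetical parameter name
    "final_dense.weight": FakeTensor((7, 768)),
    "final_dense.bias": FakeTensor((7,)),
}
model_state = {
    "encoder.weight": FakeTensor((768, 768)),  # hypothetical parameter name
    "final_dense.weight": FakeTensor((5, 768)),
    "final_dense.bias": FakeTensor((5,)),
}

kept, skipped = drop_mismatched(checkpoint_state, model_state)
print("loaded from checkpoint:", sorted(kept))
print("left at fresh initialization:", sorted(skipped))

# With a real PyTorch model, the filtered dict would then be passed to
#   model.load_state_dict(kept, strict=False)
# so the encoder is restored and only the task head starts untrained.
```

The same filtering applies whether the new label set is smaller or larger than 7: the task head is simply re-initialized at the new size and learned during fine-tuning.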

ftgreat avatar Jun 13 '23 09:06 ftgreat

Closing for now. If you still have problems, please reopen the issue. Thanks.

ftgreat avatar Jun 22 '23 11:06 ftgreat