Bert-Chinese-Text-Classification-Pytorch

How can I load two models in succession without the first model's parameters interfering with the second?

Open shengtaovvv opened this issue 3 years ago • 0 comments

I have trained two models and now want to use both of them for prediction. Loading the first model works, but loading the second one afterwards raises an error, and the message looks as if the second model is being affected by the first model's parameters. How should I initialize torch so that the two models do not affect each other when they are loaded?

The model class:

class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        self.bert = BertModel.from_pretrained(config.bert_path)
        for param in self.bert.parameters():
            param.requires_grad = True
        self.fc = nn.Linear(config.hidden_size, config.num_classes)

    def forward(self, x):
        context = x[0]  # input sentence
        mask = x[2]  # mask over the padding; same size as the sentence, with padding positions set to 0, e.g. [1, 1, 1, 1, 0, 0]
        _, pooled = self.bert(context, attention_mask=mask, output_all_encoded_layers=False)
        out = self.fc(pooled)
        return out



def __init__(self):
    self.model_dict = dict()
    for mod_dir in os.listdir(model_path):
        with torch.no_grad():
            mod_dir = os.path.join(model_path, mod_dir)
            if not os.path.isdir(mod_dir):
                continue
            config = load_mod.Config(mod_dir)
            model = load_mod.Model(config).to(config.device)
            model.load_state_dict(torch.load(config.model_path, map_location='cuda' if torch.cuda.is_available() else 'cpu'))
            model.eval()
            self.model_dict[config.model_name] = {'model': model, 'config': config}
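
One way to narrow this down is to compare each checkpoint's classifier head with the class count its config reports before calling `load_state_dict`. The sketch below is only an illustration: `checkpoint_num_classes` and `check_head_matches` are hypothetical helpers, not part of the repository, and the stand-in state dict replaces `torch.load(config.model_path, map_location="cpu")` so that the snippet runs without the trained files.

```python
import torch.nn as nn


def checkpoint_num_classes(state_dict, head_key="fc.weight"):
    """Return the number of output classes stored in the checkpoint's fc layer."""
    # nn.Linear stores its weight as [out_features, in_features].
    return state_dict[head_key].shape[0]


def check_head_matches(state_dict, num_classes, name="model"):
    """Raise a readable error if the checkpoint head and the configured class count differ."""
    saved = checkpoint_num_classes(state_dict)
    if saved != num_classes:
        raise ValueError(
            f"{name}: checkpoint has {saved} classes but config.num_classes is {num_classes}"
        )


# Stand-in for torch.load(config.model_path, map_location="cpu"):
# a checkpoint whose head was trained with 10 classes.
state_dict = {"fc." + key: value for key, value in nn.Linear(768, 10).state_dict().items()}

check_head_matches(state_dict, 10, name="second model")  # passes silently
```

Inside the loop above, such a check would go between `torch.load(...)` and `model.load_state_dict(...)`, with `config.num_classes` and `config.model_name` as the arguments. A failure would then name the model directory whose config and checkpoint disagree.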

Traceback (most recent call last):
  File "D:/workspace/python/Bert-Classification-Pytorch/Evaluate/model_utils.py", line 82, in <module>
    bm = Bert_Model()
  File "D:/workspace/python/Bert-Classification-Pytorch/Evaluate/model_utils.py", line 23, in __init__
    model.load_state_dict(torch.load(config.model_path, map_location='cuda' if torch.cuda.is_available() else 'cpu'))
  File "D:\Program Files (x86)\Python\Python37\lib\site-packages\torch\nn\modules\module.py", line 1052, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Model:
	size mismatch for fc.weight: copying a param with shape torch.Size([10, 768]) from checkpoint, the shape in current model is torch.Size([9, 768]).
	size mismatch for fc.bias: copying a param with shape torch.Size([10]) from checkpoint, the shape in current model is torch.Size([9]).
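
The message itself says the checkpoint's `fc` layer has 10 outputs while the freshly built model has 9. One possible reading is that the `num_classes` used to build the second model does not match the checkpoint being loaded into it, rather than that state is carried over from the first model. The sketch below reproduces the same error in isolation under that assumption. `TinyHead` is a hypothetical stand-in that keeps only the `fc` layer, so no BERT weights are needed.

```python
import torch.nn as nn


class TinyHead(nn.Module):
    """Hypothetical stand-in for the classifier: only the fc layer, no BERT."""

    def __init__(self, hidden_size, num_classes):
        super().__init__()
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, pooled):
        return self.fc(pooled)


# Two "checkpoints" trained with different label sets.
ckpt_a = TinyHead(768, 9).state_dict()
ckpt_b = TinyHead(768, 10).state_dict()

# Loading checkpoint B into a model built with A's class count
# raises the same kind of RuntimeError as in the traceback.
try:
    TinyHead(768, 9).load_state_dict(ckpt_b)
    mismatch_error = None
except RuntimeError as exc:
    mismatch_error = str(exc)

# Building each model from its own checkpoint's head size loads both cleanly,
# one after the other, in the same process.
models = {}
for name, ckpt in {"a": ckpt_a, "b": ckpt_b}.items():
    model = TinyHead(768, ckpt["fc.weight"].shape[0])
    model.load_state_dict(ckpt)
    model.eval()
    models[name] = model
```

If this reading applies here, the thing to check would be how the `Config` for the second model directory arrives at `num_classes` (for example, which class list it reads) and whether `config.model_path` points at the checkpoint trained with that label set.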

shengtaovvv • Mar 19 '21 07:03