Bert-Chinese-Text-Classification-Pytorch
Bert-Chinese-Text-Classification-Pytorch copied to clipboard
请问怎么解决连续加载两个模型,第一个模型参数干扰第二个模型的问题
请问一下,我已经训练好了两个模型,现在想用这两个模型做预测,在加载第一个模型的时候没问题,但是加载完成第一个模型之后,再去加载第二个模型的时候,就会报错,提示的错误是受到第一个模型参数的影响,请问这个应该怎么初始化torch才能让两个模型加载的时候不受相互影响嫩 class Model(nn.Module): def init(self, config): super(Model, self).init() self.bert = BertModel.from_pretrained(config.bert_path) for param in self.bert.parameters(): param.requires_grad = True self.fc = nn.Linear(config.hidden_size, config.num_classes)
def forward(self, x):
context = x[0] # 输入的句子
mask = x[2] # 对padding部分进行mask,和句子一个size,padding部分用0表示,如:[1, 1, 1, 1, 0, 0]
_, pooled = self.bert(context, attention_mask=mask, output_all_encoded_layers=False)
out = self.fc(pooled)
return out
def __init__(self):
self.model_dict = dict()
for mod_dir in os.listdir(model_path):
with torch.no_grad():
mod_dir = os.path.join(model_path, mod_dir)
if not os.path.isdir(mod_dir):
continue
config = load_mod.Config(mod_dir)
model = load_mod.Model(config).to(config.device)
model.load_state_dict(torch.load(config.model_path, map_location='cuda' if torch.cuda.is_available() else 'cpu'))
model.eval()
self.model_dict[config.model_name] = {'model': model, 'config': config}
Traceback (most recent call last):
File "D:/workspace/python/Bert-Classification-Pytorch/Evaluate/model_utils.py", line 82, in