torchkeras
torchkeras copied to clipboard
notebook 示例中使用 save_pretrained 保存 model/checkpoints 的方法有问题
问题描述
在若干个 notebook 示例中,使用 save_pretrained 方法保存模型是不对的,可以重新加载这些模型进行检查,发现权重没有变化。
# 仅仅保存lora可训练参数
# 覆盖了KerasModel中的load_ckpt和save_ckpt方法
def save_ckpt(self, ckpt_path='checkpoint', accelerator = None):
unwrap_net = accelerator.unwrap_model(self.net)
unwrap_net.save_pretrained(ckpt_path)
解决方法
👉根据accelerate官方文档中,保存model/checkpoints的方法👈,应该修正为
# 仅仅保存lora可训练参数
# 覆盖了KerasModel中的load_ckpt和save_ckpt方法
def save_ckpt(self, ckpt_path="checkpoint", accelerator=None):
unwrap_net = accelerator.unwrap_model(self.net)
unwrap_net.save_pretrained(
ckpt_path,
is_main_process=accelerator.is_main_process,
save_function=accelerator.save,
state_dict=accelerator.get_state_dict(self.net),
)
希望后来者少走弯路,我检查了好久才找到这个问题😂