torchkeras icon indicating copy to clipboard operation
torchkeras copied to clipboard

notebook 示例中使用 save_pretrained 保存 model/checkpoints 的方法有问题

Open lalalabox opened this issue 1 year ago • 0 comments

问题描述

在若干个 notebook 示例中,使用 save_pretrained 方法保存模型是不对的,可以重新加载这些模型进行检查,发现权重没有变化。

例如ChatGLM2_LoRA注释版.ipynb

# 仅仅保存lora可训练参数
# 覆盖了KerasModel中的load_ckpt和save_ckpt方法
def save_ckpt(self, ckpt_path='checkpoint', accelerator = None):
    unwrap_net = accelerator.unwrap_model(self.net)
    unwrap_net.save_pretrained(ckpt_path)

解决方法

👉根据accelerate官方文档中,保存model/checkpoints的方法👈,应该修正为

# 仅仅保存lora可训练参数
# 覆盖了KerasModel中的load_ckpt和save_ckpt方法
def save_ckpt(self, ckpt_path="checkpoint", accelerator=None):
    unwrap_net = accelerator.unwrap_model(self.net)
    unwrap_net.save_pretrained(
        ckpt_path,
        is_main_process=accelerator.is_main_process,
        save_function=accelerator.save,
        state_dict=accelerator.get_state_dict(self.net),
    )

希望后来者少走弯路,我检查了好久才找到这个问题😂

lalalabox avatar Sep 16 '24 12:09 lalalabox