CPT
CPT copied to clipboard
使用自定义数据集在bart-base-chinese的继续pretrain
我想要在自己的数据集上使用Huggingface已经开源的bart-base-chinese的继续pretrain流程,但是在training.py中load_checkpoint加载模型步骤遇到了一个问题。 load_checkpoint函数中,需要得到一个tracker file,如果不存在这个文件便会有警告“will not load any checkpoints and will start from random”,但是我希望从bart-base-chinese的基础上进行pretrain,请问这个tracker file应该如何设置?以及后面torch.load是应该直接加载pytorch_model.bin吗?但是它似乎不是代码里提及的model_optim_rng.pt。
你可以在模型初始化的时候加载预训练好的参数。比如训练BART的时候,在对应的https://github.com/fastnlp/CPT/blob/24eceed8b11a821f1ce8648ac2372714bc43c7a2/pretrain/megatron/model/bart_model.py#L45 这一行改成使用from_pretrained的方式加载模型参数。
如果要训练CPT也是类似的,修改cpt_model.py就行。
好的,感谢回复!