CPT icon indicating copy to clipboard operation
CPT copied to clipboard

使用自定义数据集在bart-base-chinese的继续pretrain

Open Aureole-1210 opened this issue 2 years ago • 2 comments

我想要在自己的数据集上使用Huggingface已经开源的bart-base-chinese的继续pretrain流程,但是在training.py中load_checkpoint加载模型步骤遇到了一个问题。 load_checkpoint函数中,需要得到一个tracker file,如果不存在这个文件便会有警告“will not load any checkpoints and will start from random”,但是我希望从bart-base-chinese的基础上进行pretrain,请问这个tracker file应该如何设置?以及后面torch.load是应该直接加载pytorch_model.bin吗?但是它似乎不是代码里提及的model_optim_rng.pt。

Aureole-1210 avatar Feb 23 '22 12:02 Aureole-1210

你可以在模型初始化的时候加载预训练好的参数。比如训练BART的时候,在对应的https://github.com/fastnlp/CPT/blob/24eceed8b11a821f1ce8648ac2372714bc43c7a2/pretrain/megatron/model/bart_model.py#L45 这一行改成使用from_pretrained的方式加载模型参数。

如果要训练CPT也是类似的,修改cpt_model.py就行。

choosewhatulike avatar Feb 23 '22 13:02 choosewhatulike

好的,感谢回复!

Aureole-1210 avatar Feb 24 '22 08:02 Aureole-1210