mlm_bert_traning
mlm_bert_traning copied to clipboard
基于mlm方式的带有纠错功能的拼音转汉字bert预训练模型,pinyin correcter,基于pytorch框架实现
mlm_bert_traning
基于mlm方式的带有纠错功能的拼音转汉字bert预训练模型,基于pytorch框架实现
依赖
python>3.6
torch==1.4.0
tranformers==3.1.0
scikit-leran==0.23.2
目的
将可能包含有错误的拼音解码成正确的汉字序列,可用于ASR(语音识别)的拼音输出进行纠错。原本代码是有给出模型预训练数据,这个是根据我的样本训练好的(因为数据太大放不上github,有需要请联系我或在issue中留言),应用到不同任务时效果可能会不适用,需要根据自己的训练样本重新训练模型。
训练
运行run.py,将其中的训练数据路径和测试数据路径改为你们的文件路径,文件格式类似data/trainpath和data/evalpath文件格式保持一致。
测试
运行test.py,将拼音序列输入,输出为每个位置的前5个最可能token以及对应的概率。
特色
自动根据样本基于模糊音算法(具体算法可在utils中手动修改)生成错误拼音进行训练,而不用手动标注,极大节省人力物力。
原理
原理如上图所示
- 构造训练样本。首先根据输入的正确拼音序列,任务将其变成错误的拼音,即模糊音替代算法,使得模型具备拼音纠错能力。
- 计算拼音loss。在bert的前6层的输出后添加全连接层,仅针对被替换的拼音进行计算loss,希望模型在前6层能学会纠错。
- 计算汉字loss。在bert的后6层的输出后,针对所有汉字进行计算loss,希望模型能学会拼音和汉字间的对应关系。
模糊音算法
模糊音算法采用两种结合的方式:相似声母和韵母,拼音编辑距离为1,进行生成。每部分生成的比例可手动设置,默认为MASK: 相似: 编辑=5:3:2。其中相似声母和韵母算法需根据具体场景进行调整。