WeTextProcessing
WeTextProcessing copied to clipboard
新增规则后，转换结果出错的问题
您好，我们新增了一个规则，用于把高铁车次中的“高”转换为 G、“动”转换为 D。在 rules 中新增的规则如下：
from tn.processor import Processor
from pynini import string_file
from pynini.lib.pynutil import delete, insert
class FlightTrainCode(Processor):
    """Inverse-normalize spoken Chinese train numbers, e.g. 动二 -> D2.

    The head character is mapped to its letter through ``train_tou.tsv``
    (e.g. 高 -> G). The following one to four Chinese digits are converted
    to Arabic digits through the ITN digit/zero tables. The first digit
    must be non-zero.

    NOTE(review): 高一/高二/高三 are also everyday words for senior-high-school
    grades. This rule will rewrite them to G1/G2/G3 as well, so consider
    requiring more context before enabling it broadly.
    """

    def __init__(self):
        super().__init__(name='flighttraincode')
        self.build_tagger()
        self.build_verbalizer()

    def build_tagger(self):
        """Build the tagger emitting ``head: "<letter>" number: "<digits>"``."""
        # NOTE: these are relative paths. They resolve against the process's
        # current working directory, not against this module.
        digit = string_file('itn/chinese/data/number/digit.tsv')  # 1 ~ 9
        zero = string_file('itn/chinese/data/number/zero.tsv')  # 0
        # Head character -> letter, e.g. 高 -> G.
        train_head = string_file('itn/chinese/data/train/train_tou.tsv')
        # A TSV saved with Windows (CRLF) line endings leaves a trailing '\r'
        # on every output label ("D\r"). That '\r' leaks into the token and
        # into the final text. A terminal then renders it as a carriage
        # return, so the result looks reordered: "明天的动二" shows as
        # "2 天的D" instead of "明天的D2". Delete any '\r' here. The data
        # file should also be saved with LF line endings.
        train_head = train_head @ self.build_rule(delete('\r'), '', '')
        # Train numbers have one to four digits; the leading digit is 1-9.
        digit_4 = digit + (digit | zero)**3
        digit_3 = digit + (digit | zero)**2
        digit_2 = digit + (digit | zero)
        digit_1 = digit
        number = digit_4 | digit_3 | digit_2 | digit_1
        tagger = (insert('head: "') + train_head + insert('"')
                  + insert(' number: "') + number + insert('"'))
        self.tagger = self.add_tokens(tagger)

    def build_verbalizer(self):
        """Concatenate head and number back into written form, e.g. D2."""
        head = delete('head: "') + self.SIGMA + delete('"')
        number = delete(' number: "') + self.SIGMA + delete('"')
        verbalizer = head + number
        self.verbalizer = self.delete_tokens(verbalizer)
itn/chinese/inverse_normalizer.py 的修改如下：
from tn.processor import Processor
from itn.chinese.rules.cardinal import Cardinal
from itn.chinese.rules.char import Char
from itn.chinese.rules.date import Date
from itn.chinese.rules.fraction import Fraction
from itn.chinese.rules.math import Math
from itn.chinese.rules.measure import Measure
from itn.chinese.rules.money import Money
from itn.chinese.rules.whitelist import Whitelist
from itn.chinese.rules.time import Time
from itn.chinese.rules.postprocessor import PostProcessor
from itn.chinese.rules.license_plate import LicensePlate
from itn.chinese.rules.flighttraincode import FlightTrainCode
from pynini.lib.pynutil import add_weight, delete
from importlib_resources import files
class InverseNormalizer(Processor):
    """Chinese inverse text normalizer (spoken form -> written form).

    Combines the per-category ITN rules into one weighted tagger and one
    verbalizer, then builds/caches the FSTs under the name ``zh_itn``.
    """

    def __init__(self,
                 cache_dir=None,
                 overwrite_cache=False,
                 enable_standalone_number=True,
                 enable_0_to_9=False,
                 enable_million=False):
        super().__init__(name='inverse_normalizer', ordertype='itn')
        self.convert_number = enable_standalone_number
        self.enable_0_to_9 = enable_0_to_9
        self.enable_million = enable_million
        # Fall back to the installed ``itn`` package directory for the cache.
        self.build_fst('zh_itn',
                       files("itn") if cache_dir is None else cache_dir,
                       overwrite_cache)

    def build_tagger(self):
        """Union every category tagger, each with its own weight, into one."""
        # Presumably the lowest total weight wins when paths compete. Char
        # (weight 100 per character) then acts as a last-resort fallback so
        # any input can still be tagged.
        date_rule = add_weight(Date().tagger, 1.02)
        whitelist_rule = add_weight(Whitelist().tagger, 1.01)
        fraction_rule = add_weight(Fraction().tagger, 1.05)
        measure_rule = add_weight(
            Measure(enable_0_to_9=self.enable_0_to_9).tagger, 1.05)
        money_rule = add_weight(
            Money(enable_0_to_9=self.enable_0_to_9).tagger, 1.04)
        time_rule = add_weight(Time().tagger, 1.05)
        cardinal_rule = add_weight(
            Cardinal(self.convert_number, self.enable_0_to_9,
                     self.enable_million).tagger, 1.06)
        math_rule = add_weight(Math().tagger, 1.10)
        plate_rule = add_weight(LicensePlate().tagger, 1.0)
        train_rule = add_weight(FlightTrainCode().tagger, 1.07)
        char_rule = add_weight(Char().tagger, 100)

        combined = (date_rule | whitelist_rule | fraction_rule | measure_rule
                    | money_rule | time_rule | cardinal_rule | math_rule
                    | plate_rule | train_rule | char_rule).optimize()
        # Delete the single space left at the very end of the tagged string.
        strip_last_space = self.build_rule(delete(' '), '', '[EOS]')
        self.tagger = combined.star @ strip_last_space

    def build_verbalizer(self):
        """Union the category verbalizers and pipe each token through PostProcessor."""
        cardinal = Cardinal(self.convert_number, self.enable_0_to_9,
                            self.enable_million).verbalizer
        combined = (cardinal
                    | Char().verbalizer
                    | Date().verbalizer
                    | Fraction().verbalizer
                    | Math().verbalizer
                    | Measure(enable_0_to_9=self.enable_0_to_9).verbalizer
                    | Money(enable_0_to_9=self.enable_0_to_9).verbalizer
                    | Time().verbalizer
                    | LicensePlate().verbalizer
                    | FlightTrainCode().verbalizer
                    | Whitelist().verbalizer).optimize()
        post = PostProcessor(remove_interjections=True).processor
        self.verbalizer = (combined @ post).star
识别结果出现了顺序错乱的问题。输入：
python -m itn --text "明天的动二" --overwrite_cache
输出如下：
" number: "2" }明" } char { value: "天" } char { value: "的" } flighttraincode { head: "D
2 天的D