WeTextProcessing icon indicating copy to clipboard operation
WeTextProcessing copied to clipboard

新增规则后转换结果错误

Open · jeeveenn opened this issue 3 months ago · 1 comment

您好,我们新增了一个规则,用于把高铁/动车车次中的“高”转换为“G”、“动”转换为“D”。在 rules 中新增的规则如下:

from tn.processor import Processor
from pynini import string_file
from pynini.lib.pynutil import delete, insert


class FlightTrainCode(Processor):
    r"""ITN rule for train codes, e.g. "动二" -> "D2".

    The tagger matches a spoken head character, mapped to a code letter via
    ``train_tou.tsv`` (e.g. 高 -> G, 动 -> D). It then matches 1-4 digits
    whose first digit comes from ``digit.tsv`` (1-9) and whose remaining
    digits may also come from ``zero.tsv`` (0). It emits a token like::

        flighttraincode { head: "D" number: "2" }

    The verbalizer removes the field markup and concatenates the two
    values, giving "D2".
    """

    def __init__(self):
        super().__init__(name='flighttraincode')
        self.build_tagger()
        self.build_verbalizer()

    @staticmethod
    def _string_file_crlf_safe(path):
        r"""Load a TSV mapping like ``string_file``, tolerating CRLF and a BOM.

        ``string_file`` keeps a trailing carriage return when the file was
        saved with Windows (CRLF) line endings. A line ``动<TAB>D`` then maps
        to the output ``D\r``. That stray ``\r`` travels through the tagged
        and verbalized text, and on a terminal ``明天的D\r2`` is displayed as
        ``2 天的D``, which looks like a token-ordering bug.

        Column 1 is the input and column 2 the output. A line with a single
        column is an identity mapping. Blank lines are skipped. Any extra
        (weight) column is ignored.
        """
        from pynini import string_map

        pairs = []
        # 'utf-8-sig' also drops a leading BOM that Windows editors may add.
        with open(path, encoding='utf-8-sig') as tsv:
            for raw in tsv:
                line = raw.rstrip('\r\n')
                if not line.strip():
                    continue  # e.g. a trailing empty line
                fields = line.split('\t')
                src = fields[0]
                dst = fields[1] if len(fields) > 1 else src
                pairs.append((src, dst))
        return string_map(pairs)

    def build_tagger(self):
        # NOTE(review): paths are relative to the working directory
        # (presumably the repo root when run as `python -m itn`) -- confirm
        # for other entry points.
        digit = string_file('itn/chinese/data/number/digit.tsv')  # 1 ~ 9
        zero = string_file('itn/chinese/data/number/zero.tsv')  # 0
        # User-maintained mapping from the spoken head character to the code
        # letter, e.g. 高 -> G, 动 -> D. It is loaded CRLF-safely because a
        # Windows-saved file otherwise injects '\r' into the output.
        train_head = self._string_file_crlf_safe(
            'itn/chinese/data/train/train_tou.tsv')

        # 1-4 digit train number whose first digit is non-zero.
        any_digit = digit | zero
        number = (digit + any_digit ** 3
                  | digit + any_digit ** 2
                  | digit + any_digit
                  | digit)

        # NOTE(review): if 高 is in train_tou.tsv and digit.tsv contains the
        # Chinese numerals, ordinary words such as "高一"/"高二" (senior high
        # school grades) are likely to match as well and become "G1"/"G2".
        # Consider requiring more context if that is unwanted.
        tagger = (insert('head: "') + train_head + insert('"')
                  + insert(' number: "') + number + insert('"'))
        self.tagger = self.add_tokens(tagger)

    def build_verbalizer(self):
        # Strip the field markup and keep the values:
        # 'head: "D" number: "2"' -> 'D2'.
        head = delete('head: "') + self.SIGMA + delete('"')
        number = delete(' number: "') + self.SIGMA + delete('"')
        verbalizer = head + number
        self.verbalizer = self.delete_tokens(verbalizer)

itn/chinese/inverse_normalizer.py的修改如下:

from tn.processor import Processor
from itn.chinese.rules.cardinal import Cardinal
from itn.chinese.rules.char import Char
from itn.chinese.rules.date import Date
from itn.chinese.rules.fraction import Fraction
from itn.chinese.rules.math import Math
from itn.chinese.rules.measure import Measure
from itn.chinese.rules.money import Money
from itn.chinese.rules.whitelist import Whitelist
from itn.chinese.rules.time import Time
from itn.chinese.rules.postprocessor import PostProcessor
from itn.chinese.rules.license_plate import LicensePlate
from itn.chinese.rules.flighttraincode import FlightTrainCode

from pynini.lib.pynutil import add_weight, delete
from importlib_resources import files


class InverseNormalizer(Processor):
    """Chinese ITN grammar assembled from the individual rule classes.

    The tagger is a weighted union of every rule's tagger, including the
    custom ``FlightTrainCode`` rule, repeated with ``star``. The verbalizer
    is the union of every rule's verbalizer composed with ``PostProcessor``,
    also repeated with ``star``. Compilation (and presumably caching under
    ``cache_dir``) is delegated to ``build_fst`` from ``tn.processor``.
    """

    def __init__(self,
                 cache_dir=None,
                 overwrite_cache=False,
                 enable_standalone_number=True,
                 enable_0_to_9=False,
                 enable_million=False):
        super().__init__(name='inverse_normalizer', ordertype='itn')
        # Rule options read later by build_tagger() / build_verbalizer().
        self.convert_number = enable_standalone_number
        self.enable_0_to_9 = enable_0_to_9
        self.enable_million = enable_million
        # Without an explicit cache_dir, use the installed `itn` package dir.
        fst_dir = files("itn") if cache_dir is None else cache_dir
        self.build_fst('zh_itn', fst_dir, overwrite_cache)

    def build_tagger(self):
        # (rule tagger, weight) pairs, kept in the exact union order.
        # Presumably the lowest-cost path wins, so Char (100) acts as the
        # fallback when no other rule applies.
        weighted_rules = (
            (Date().tagger, 1.02),
            (Whitelist().tagger, 1.01),
            (Fraction().tagger, 1.05),
            (Measure(enable_0_to_9=self.enable_0_to_9).tagger, 1.05),
            (Money(enable_0_to_9=self.enable_0_to_9).tagger, 1.04),
            (Time().tagger, 1.05),
            (Cardinal(self.convert_number, self.enable_0_to_9,
                      self.enable_million).tagger, 1.06),
            (Math().tagger, 1.10),
            (LicensePlate().tagger, 1.0),
            (FlightTrainCode().tagger, 1.07),
            (Char().tagger, 100),
        )
        combined = None
        for rule_tagger, weight in weighted_rules:
            branch = add_weight(rule_tagger, weight)
            combined = branch if combined is None else combined | branch

        token_sequence = combined.optimize().star
        # Drop the trailing space left after the final token.
        self.tagger = token_sequence @ self.build_rule(delete(' '), '', '[EOS]')

    def build_verbalizer(self):
        # Every rule's verbalizer, kept in the exact union order.
        rule_verbalizers = (
            Cardinal(self.convert_number, self.enable_0_to_9,
                     self.enable_million).verbalizer,
            Char().verbalizer,
            Date().verbalizer,
            Fraction().verbalizer,
            Math().verbalizer,
            Measure(enable_0_to_9=self.enable_0_to_9).verbalizer,
            Money(enable_0_to_9=self.enable_0_to_9).verbalizer,
            Time().verbalizer,
            LicensePlate().verbalizer,
            FlightTrainCode().verbalizer,
            Whitelist().verbalizer,
        )
        any_token = rule_verbalizers[0]
        for rule_verbalizer in rule_verbalizers[1:]:
            any_token = any_token | rule_verbalizer
        any_token = any_token.optimize()

        cleanup = PostProcessor(remove_interjections=True).processor
        # Verbalize one token, post-process it, and repeat for each token.
        self.verbalizer = (any_token @ cleanup).star

识别结果的顺序出现错乱。输入 `python -m itn --text "明天的动二" --overwrite_cache`,输出如下:

" number: "2" }明" } char { value: "天" } char { value: "的" } flighttraincode { head: "D 
2 天的D 

jeeveenn avatar Mar 27 '24 02:03 jeeveenn