
Wrong alignments after calling `NormalizedString.replace()`

Open t-yamamura opened this issue 3 years ago • 10 comments

Hi,

I tried to change the normalized property of a NormalizedString with NormalizedString.replace() to implement my own normalization in a custom pre-tokenizer (https://github.com/huggingface/tokenizers/blob/master/bindings/python/examples/custom_components.py). However, I found that the alignments of the NormalizedString were wrong after the replacement.

This bug seems to occur when the first character of a sentence is replaced.

// Note: `alignments` is not a public field, so this only works from within the tokenizers crate (e.g. its unit tests).
let mut s = NormalizedString::from("abc");
s.replace("abc", "xxx").expect("should work"); // replace() returns a Result
assert_eq!(s.get(), "xxx");
eprint!("{:?}", s.alignments);
---
[(2, 3), (2, 3), (2, 3)]
let mut s = NormalizedString::from("abc");
s.replace("b", "xxx").expect("should work");
assert_eq!(s.get(), "axxxc");
eprint!("{:?}", s.alignments);
---
[(0, 1), (1, 2), (1, 2), (1, 2), (2, 3)]

t-yamamura · Jan 26 '22

Hi @t-yamamura,

When you say "wrong", do you mind sharing what you expect? There are at least 3 versions, all debatable IMO:

  • [(0, 3), (0, 3), (0, 3)]. Each character "comes" from the whole original string. This is the "pure" solution I think, but it's also a bit odd that many single characters come from a potentially large block.
  • [(0, 1), (1, 2), (2, 3)]. Each character is "replaced" by the character at the same position in the previous string. But can we really say that the first "x" comes from "a"? And since everything is treated as chars, this is likely to become odd with Unicode: combining characters are printed as one "thing" but might get assigned to different spots.
  • [(2, 3), (2, 3), (2, 3)]. We start by dropping all the original characters ("abc") from the string, and then create entirely new characters (the current behavior). It works, but it definitely doesn't track the whole original match, which does seem odd sometimes.
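To make the three options concrete, here is a small pure-Python model of how each one would align the characters produced by a replacement. It is a hypothetical sketch, not the library's Rust implementation: `replace_alignments` and the strategy names are made up for illustration, alignments are modeled as one `(start, end)` pair per normalized character (as in the Rust output above), and option 2's handling of surplus characters (reusing the last replaced character) is an assumption, since the thread does not specify it.

```python
from typing import List, Tuple

Alignment = Tuple[int, int]


def replace_alignments(
    alignments: List[Alignment], start: int, end: int, n_new: int, strategy: str
) -> List[Alignment]:
    """Hypothetical model: replace normalized chars [start, end) with n_new chars.

    `alignments` holds one (orig_start, orig_end) pair per normalized char.
    """
    if not 0 <= start < end <= len(alignments):
        raise ValueError("expected a non-empty span inside the string")
    removed = alignments[start:end]
    if strategy == "whole_span":  # option 1: every new char covers the whole match
        new = [(removed[0][0], removed[-1][1])] * n_new
    elif strategy == "positional":  # option 2: i-th new char <- i-th replaced char
        # Assumption: surplus new chars reuse the last replaced char.
        new = [removed[min(i, len(removed) - 1)] for i in range(n_new)]
    elif strategy == "last_char":  # option 3: all new chars hang off the last replaced char
        new = [removed[-1]] * n_new
    else:
        raise ValueError(f"unknown strategy: {strategy}")
    return alignments[:start] + new + alignments[end:]


identity = [(i, i + 1) for i in range(3)]  # alignments of the untouched "abc"
for strategy in ("whole_span", "positional", "last_char"):
    # Replace all of "abc" with "xxx".
    print(strategy, replace_alignments(identity, 0, 3, 3, strategy))
```

With `last_char`, replacing only the "b" of "abc" (`replace_alignments(identity, 1, 2, 3, "last_char")`) gives `[(0, 1), (1, 2), (1, 2), (1, 2), (2, 3)]`, the same as the second Rust snippet; the three options only diverge once the match covers more than one character.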

Narsil · Jan 26 '22

I'm sorry I didn't explain what I expected clearly enough.

First, I would like to make sure: an alignment maps the normalized string back to the original string, doesn't it? If so, I think it is better to preserve the alignment with the original string no matter how the original string is replaced (even if a single character is replaced by multiple characters). So I agree with the first solution.

> [(0, 3), (0, 3), (0, 3)]. Each character "comes" from the original string. This is the "pure" solution I think, but it's a bit odd too that many single characters are coming from a potentially large block.

t-yamamura · Jan 26 '22

The current behavior means that only the last character is replaced by all the new characters at once, while the other characters vanish into oblivion, leaving a hole in the original string. That is not what a replace operation should mean, if I understand the reasoning behind the alignments correctly. While the second behavior is debatable, the first one would be more natural in my opinion.

Anyway, to provide more context: we are trying to implement a pre-tokenizer that can replace a token's contents with another string (without a strict character-to-character mapping), and that breaks the token offset calculation, which we traced to this issue.

eiennohito · Jan 26 '22

I see.

I guess it makes sense then.

I think this is pretty core, so this change might take more time than we expect. I'll try to push a PR soon so we can run some tests and get some eyes on it to make sure we're not breaking other things.

fyi, dealing with chars/bytes/unicode stuff is... error-prone at least :D

Narsil · Jan 26 '22

On another note, would option 2 be viable too?

Maybe option 1 is actually impossible or very tricky if it breaks some assumptions made elsewhere (I don't think it does, but it might).

Narsil · Jan 26 '22

It's certainly error-prone! Also, we could probably have worked around this issue if it were possible to make DLL-style Rust <-> Rust calls through low-level APIs. Right now you either need to link the Rust tokenizers crate fully into your binary (which is OK for Rust-only programs) or go through the bindings.

For our use case, option 2 would be viable as well (and may be the best choice when taking possible further sub-segmentation into account); it's your call which seems more sound here. From the Python bindings' point of view, per-character alignments are not accessible (are they accessible only from Rust?), and only tokenization offsets are visible from the outside.

eiennohito · Jan 26 '22

Actually, alignments are not public in Rust either; they are only exposed within the crate.

AFAIK that's entirely on purpose: alignments are just an internal way of keeping track of various things, so if we came up with a much better algorithm or data structure we would still be free to use it. Offsets are what users should rely on.

Is there any reason you cannot use offsets directly in your use case?

Re-reading your use case:

> Anyway, to provide more context: we are trying to implement a pre-tokenizer that can replace a token's contents with another string (without a strict character-to-character mapping), and that breaks the token offset calculation, which we traced to this issue.

Do you mind sharing a dummy example of what you were trying to do and what failed? It would really help build the case for the change. Again, option 3 (the current behavior) is also defensible from a pure alignments standpoint, so it would be reassuring to make the change because we found something blatantly odd/uneasy/wrong in the bindings.

Narsil · Jan 27 '22

I will share dummy code that causes a mismatch with the offsets of the original string after NormalizedString.replace(). What we want to do is rewrite the NormalizedString after tokenization to normalize some tokens, e.g. by stemming or lemmatization. (Since we are targeting Japanese, normalization has to happen after tokenization.)

The following dummy code reproduces the error, although it is not the tokenization and normalization algorithm we actually want to implement.

import re
from typing import List, Tuple

from tokenizers import NormalizedString, PreTokenizedString, Regex, Tokenizer
from tokenizers.models import WordPiece
from tokenizers.pre_tokenizers import PreTokenizer


class CustomPreTokenizer:
    def custom_split(self, i: int, normalized_string: NormalizedString) -> List[NormalizedString]:
        normalized_strings = []
        for begin_index, end_index in self.whitespace_split(str(normalized_string)):
            ns = normalized_string[begin_index:end_index]
            ns.replace(Regex("^.*$"), str(ns.normalized).upper())  # uppercase the whole token
            normalized_strings.append(ns)
        return normalized_strings
            
    @staticmethod
    def whitespace_split(sentence: str) -> List[Tuple[int, int]]:
        return [match.span() for match in re.finditer(r'\S+', sentence)]  # spans of non-whitespace runs
    
    def pre_tokenize(self, pretok: PreTokenizedString):
        pretok.split(self.custom_split)

tok = Tokenizer(WordPiece(vocab={"[UNK]": 0}))  # vocab won't affect this issue.
tok.pre_tokenizer = PreTokenizer.custom(CustomPreTokenizer())
pass_sentence = "a b c"
print("Original:", CustomPreTokenizer.whitespace_split(pass_sentence))
print("Encoded:", tok.encode(pass_sentence).offsets)
---
Original: [(0, 1), (2, 3), (4, 5)]
Encoded: [(0, 1), (2, 3), (4, 5)]  # ok
fail_sentence1 = "a b cc"
print("Original:", CustomPreTokenizer.whitespace_split(fail_sentence1))
print("Encoded:", tok.encode(fail_sentence1).offsets)
---
Original: [(0, 1), (2, 3), (4, 6)]
Encoded: [(0, 1), (2, 3), (5, 6)]  # leaving a hole in 4
fail_sentence2 = "a b cc ddd"
print("Original:", CustomPreTokenizer.whitespace_split(fail_sentence2))
print("Encoded:", tok.encode(fail_sentence2).offsets)
---
Original: [(0, 1), (2, 3), (4, 6), (7, 10)]
Encoded: [(0, 1), (2, 3), (5, 6), (9, 10)]  # leaving a hole in 4, 7, and 8

> The current behavior means that only the last character is replaced by all the new characters at once, while the other characters vanish into oblivion, leaving a hole in the original string.

As mentioned above, the alignments change as if only the last character had been replaced, so calling NormalizedString.replace() on multiple characters creates a hole in the original string.
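The encoded offsets above can be reproduced with a toy model. It assumes, hypothetically (this is not the library's code), that a token's offset in the original string is taken from the alignments of its first and last normalized characters, and that under the current replace behavior every replacement character is aligned to the last character of the replaced word. The helper names `word_alignments_after_replace`, `token_offsets` and `encoded_offsets` are illustrative only.

```python
import re
from typing import List, Tuple

Alignment = Tuple[int, int]


def word_alignments_after_replace(start: int, end: int, new_len: int) -> List[Alignment]:
    # Assumed current behavior: each new char is aligned to the last original char.
    return [(end - 1, end)] * new_len


def token_offsets(alignments: List[Alignment]) -> Alignment:
    # Offset of a token = start of its first char's alignment, end of its last char's.
    return (alignments[0][0], alignments[-1][1])


def encoded_offsets(sentence: str) -> List[Alignment]:
    offsets = []
    for match in re.finditer(r"\S+", sentence):
        start, end = match.span()
        upper = match.group().upper()  # same-length replacement, as in the example
        offsets.append(token_offsets(word_alignments_after_replace(start, end, len(upper))))
    return offsets


print(encoded_offsets("a b cc"))      # [(0, 1), (2, 3), (5, 6)]
print(encoded_offsets("a b cc ddd"))  # [(0, 1), (2, 3), (5, 6), (9, 10)]
```

In this model, aligning the new characters to the whole word instead (option 1) would give `(4, 6)` and `(7, 10)` for "cc" and "ddd", matching the `Original` spans.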

Interestingly, this error does not occur when the same replacement is done in a custom normalizer (CustomNormalizer) instead.

import re
from typing import List, Tuple

from tokenizers import NormalizedString, PreTokenizedString, Regex, Tokenizer
from tokenizers.models import WordPiece
from tokenizers.pre_tokenizers import PreTokenizer
from tokenizers.normalizers import Normalizer


class CustomPreTokenizer:
    def custom_split(self, i: int, normalized_string: NormalizedString) -> List[NormalizedString]:
        normalized_strings = []
        for begin_index, end_index in self.whitespace_split(str(normalized_string)):
            ns = normalized_string[begin_index:end_index]
            normalized_strings.append(ns)
        return normalized_strings
            
    @staticmethod
    def whitespace_split(sentence: str) -> List[Tuple[int, int]]:
        return [match.span() for match in re.finditer(r'\S+', sentence)]  # spans of non-whitespace runs
    
    def pre_tokenize(self, pretok: PreTokenizedString):
        pretok.split(self.custom_split)

class CustomNormalizer:
    def normalize(self, normalized: NormalizedString):
        normalized.replace(Regex("^.*$"), normalized.normalized.upper())  # uppercase the whole string (was `ns`, an undefined name)

        
tok = Tokenizer(WordPiece(vocab={"[UNK]": 0}))  # vocab won't affect this issue.
tok.pre_tokenizer = PreTokenizer.custom(CustomPreTokenizer())
tok.normalizer = Normalizer.custom(CustomNormalizer())
pass_sentence = "a b c"
print("Original:", CustomPreTokenizer.whitespace_split(pass_sentence))
print("Encoded:", tok.encode(pass_sentence).offsets)
---
Original: [(0, 1), (2, 3), (4, 5)]
Encoded: [(0, 1), (2, 3), (4, 5)]  # ok
fail_sentence1 = "a b cc"
print("Original:", CustomPreTokenizer.whitespace_split(fail_sentence1))
print("Encoded:", tok.encode(fail_sentence1).offsets)
---
Original: [(0, 1), (2, 3), (4, 6)]
Encoded: [(0, 1), (2, 3), (4, 6)]  # ok
fail_sentence2 = "a b cc ddd"
print("Original:", CustomPreTokenizer.whitespace_split(fail_sentence2))
print("Encoded:", tok.encode(fail_sentence2).offsets)
---
Original: [(0, 1), (2, 3), (4, 6), (7, 10)]
Encoded: [(0, 1), (2, 3), (4, 6), (7, 10)]  # ok

t-yamamura · Jan 27 '22

Could you check whether this PR solves your issue: https://github.com/huggingface/tokenizers/pull/894?

Narsil · Jan 28 '22

@Narsil Thank you for your PR (#894). With #894, I confirmed that the original alignments correspond to the encoded (normalized) alignments in the code above (https://github.com/huggingface/tokenizers/issues/892#issuecomment-1023122286).

However, I found that some normalizers produce odd alignments. Here is my test code.

import re
from typing import List, Tuple

from tokenizers import NormalizedString, PreTokenizedString, Regex, Tokenizer
from tokenizers.models import WordPiece
from tokenizers.pre_tokenizers import PreTokenizer
from tokenizers.normalizers import Normalizer


class CustomPreTokenizer:
    def __init__(self, normalizer=None):
        self.normalizer = normalizer
    
    def custom_split(self, i: int, normalized_string: NormalizedString) -> List[NormalizedString]:
        normalized_strings = []
        for begin_index, end_index in self.whitespace_split(str(normalized_string)):
            ns = normalized_string[begin_index:end_index]
            if self.normalizer is None:
                pass
            elif self.normalizer == 'upper_by_replace':
                ns.replace(Regex("^.*$"), str(ns.normalized).upper())  # uppercase the whole token
            elif self.normalizer == 'upper_by_append':
                normalized = str(ns.normalized).upper()  # uppercase the token another way
                ns.map(lambda x: ' ')
                ns.append(normalized)
                ns.lstrip()
            elif self.normalizer == 'append_suffix':
                ns.replace(Regex("^.*$"), str(ns.normalized) + 'able')
            elif self.normalizer == 'remove_suffix':
                if len(ns.normalized) >= 2:
                    ns.replace(Regex("^.*$"), str(ns.normalized)[:-1])
            else:
                raise ValueError('Invalid normalizer.')
            normalized_strings.append(ns)
        return normalized_strings
            
    @staticmethod
    def whitespace_split(sentence: str) -> List[Tuple[int, int]]:
        return [match.span() for match in re.finditer(r'\S+', sentence)]  # spans of non-whitespace runs
    
    def pre_tokenize(self, pretok: PreTokenizedString):
        pretok.split(self.custom_split)

def compare_alignments(tokenizer, sentence):
    print("Original:", CustomPreTokenizer.whitespace_split(sentence))
    print("Encoded :", tokenizer.encode(sentence).offsets)
# Case1: without normalizer
tok = Tokenizer(WordPiece(vocab={"[UNK]": 0}))  # vocab won't affect this issue.
tok.pre_tokenizer = PreTokenizer.custom(CustomPreTokenizer(normalizer=None))
compare_alignments(tok, 'a bb ccc dddd')
---
Original: [(0, 1), (2, 4), (5, 8), (9, 13)]
Encoded : [(0, 1), (2, 4), (5, 8), (9, 13)]  # OK.
# Case2: uppercase normalizer by `replace()`
tok = Tokenizer(WordPiece(vocab={"[UNK]": 0}))
tok.pre_tokenizer = PreTokenizer.custom(CustomPreTokenizer(normalizer='upper_by_replace'))
compare_alignments(tok, 'a bb ccc dddd')
---
Original: [(0, 1), (2, 4), (5, 8), (9, 13)]
Encoded : [(0, 1), (2, 4), (5, 8), (9, 13)]  # OK.
# Case3: uppercase normalizer by `map()`, `append()`, and `lstrip()`
tok = Tokenizer(WordPiece(vocab={"[UNK]": 0}))
tok.pre_tokenizer = PreTokenizer.custom(CustomPreTokenizer(normalizer='upper_by_append'))
compare_alignments(tok, 'a bb ccc dddd')
---
Original: [(0, 1), (2, 4), (5, 8), (9, 13)]
Encoded : [(0, 1), (3, 4), (7, 8), (12, 13)]  # leaving a hole in 2, 5, 6, 9, 10, and 11.
Comparison of Case1 and Case2

As you can see, your PR solves the problem that the original alignments did not correspond to the encoded alignments after normalization with NormalizedString.replace().

Comparison of Case2 and Case3

However, Case2 and Case3 output different results even though they perform the same normalization.
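One way to see why Case3 differs is to trace the alignments through the `map()`, `append()` and `lstrip()` calls of the `upper_by_append` branch. The sketch below is a hypothetical pure-Python model, not the library's code; it assumes that `map()` keeps each character's alignment, that appended characters are all aligned to the last existing character, and that `lstrip()` drops the leading spaces together with their alignments. The function name `upper_by_append` is reused here for illustration only.

```python
import re
from typing import List, Tuple

Alignment = Tuple[int, int]


def upper_by_append(word: str, start: int) -> Tuple[str, List[Alignment]]:
    """Hypothetical trace of Case3 for one word that starts at `start` in the original."""
    # Before normalization every char is aligned to itself in the original string.
    chars = list(word)
    aligns = [(start + i, start + i + 1) for i in range(len(word))]
    upper = word.upper()
    # map(lambda x: ' '): one char in, one char out, so alignments are unchanged.
    chars = [" " for _ in chars]
    # append(upper): assumed to align every appended char with the last existing char.
    chars += list(upper)
    aligns += [aligns[-1]] * len(upper)
    # lstrip(): drop the leading spaces together with their alignments.
    text = "".join(chars)
    n_stripped = len(text) - len(text.lstrip())
    return text[n_stripped:], aligns[n_stripped:]


offsets = []
for m in re.finditer(r"\S+", "a bb ccc dddd"):
    text, aligns = upper_by_append(m.group(), m.start())
    # Token offset = start of the first char's alignment, end of the last char's.
    offsets.append((aligns[0][0], aligns[-1][1]))
print(offsets)  # [(0, 1), (3, 4), (7, 8), (12, 13)]
```

Under these assumptions the model reproduces the Case3 output: the surviving uppercase characters inherit only the last original character's alignment, while Case2 with #894 yields whole-word offsets.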

# Case4: Append suffix to each token
tok = Tokenizer(WordPiece(vocab={"[UNK]": 0}))
tok.pre_tokenizer = PreTokenizer.custom(CustomPreTokenizer(normalizer='append_suffix'))
compare_alignments(tok, 'a bb ccc dddd')
---
Original: [(0, 1), (2, 4), (5, 8), (9, 13)]
Encoded : [(0, 1), (2, 4), (5, 8), (9, 13)]  # OK.
# Case5: Remove the last character of each token (if it has at least 2 characters)
tok = Tokenizer(WordPiece(vocab={"[UNK]": 0}))
tok.pre_tokenizer = PreTokenizer.custom(CustomPreTokenizer(normalizer='remove_suffix'))
compare_alignments(tok, 'a bb ccc dddd')
---
Original: [(0, 1), (2, 4), (5, 8), (9, 13)]
Encoded : [(0, 1), (3, 4), (6, 8), (10, 13)]  # leaving a hole in 2, 5, and 9.
Comparison of Case4 and Case5

I also checked two normalizers that increase or decrease the number of characters. For the lengthening normalizer (Case4), the encoded alignments are appropriate. However, for the shortening normalizer (Case5), the encoded alignments leave a hole at the start of each multi-character token, unlike Case4. It seems odd that the results for Case4 and Case5 are not consistent.
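The "hole" annotations in the comments above can be computed mechanically by comparing each original span with the encoded offsets. Here is a small helper for that; it is hypothetical (not part of `tokenizers`) and assumes both lists come from the same sentence, one encoded offset per original token.

```python
from typing import List, Tuple

Span = Tuple[int, int]


def find_holes(original: List[Span], encoded: List[Span]) -> List[int]:
    """Original character positions inside a token but covered by no encoded offset."""
    if len(original) != len(encoded):
        raise ValueError("expected one encoded offset per original token")
    covered = {i for start, end in encoded for i in range(start, end)}
    return [i for start, end in original for i in range(start, end) if i not in covered]


original = [(0, 1), (2, 4), (5, 8), (9, 13)]
print(find_holes(original, [(0, 1), (3, 4), (7, 8), (12, 13)]))  # Case3: [2, 5, 6, 9, 10, 11]
print(find_holes(original, [(0, 1), (3, 4), (6, 8), (10, 13)]))  # Case5: [2, 5, 9]
print(find_holes(original, [(0, 1), (2, 4), (5, 8), (9, 13)]))   # Case4: []
```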

t-yamamura · Jan 31 '22

This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.

github-actions[bot] · Mar 01 '24