
Need the ability to blacklist/whitelist characters

Open SlappyAUS opened this issue 2 years ago • 6 comments

Discussed in https://github.com/mindee/doctr/discussions/888

Originally posted by Xargonus April 7, 2022 Hello, is there currently a way to blacklist or whitelist characters used by the text recognition model?

SlappyAUS avatar Jul 16 '22 09:07 SlappyAUS

Thanks for opening the issue, let's first focus the discussions in #888 to avoid duplicates :)

frgfm avatar Jul 20 '22 09:07 frgfm

That's a great one too. Again, the callback idea might shine here, since you could boost/deboost some characters rather than completely blacklist them. For example, I often see 1 recognized as l, Q as 0/O, or : as 0. I'd like to avoid a complete blacklist, but if I could, for example, boost and prioritize digits over letters, or : over ., that would solve the problem.
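The boosting idea above could be sketched roughly like this (toy vocab and logits, purely illustrative; none of these names are part of the doctr API):

```python
import torch

# Illustrative sketch (not doctr API): add a bias to digit logits before
# the argmax decode so near-ties like "1" vs. "l" resolve to digits,
# without removing letters entirely.
vocab = "01l:."
# N=1 sequence of length 2: "l" narrowly beats "1" at the second position
logits = torch.tensor([[[0.1, 2.0, 1.9, 0.0, 0.0],
                        [0.0, 1.9, 2.0, 0.0, 0.0]]])

boost = torch.tensor([1.0 if c.isdigit() else 0.0 for c in vocab])

plain = "".join(vocab[i] for i in logits.argmax(-1)[0])
boosted = "".join(vocab[i] for i in (logits + boost).argmax(-1)[0])
# plain == "1l", boosted == "11": the near-tie flips, but letters remain available
```

A hard blacklist would remove "l" everywhere; the additive boost only changes positions where the model was already nearly undecided.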

dchaplinsky avatar Feb 02 '24 10:02 dchaplinsky

@odulcy-mindee @dchaplinsky

A short example (with MASTER) of how we could achieve this:

from typing import List, Tuple

import torch


class MASTERPostProcessor(_MASTERPostProcessor):
    """Post processor for MASTER architectures"""

    def __call__(
        self,
        logits: torch.Tensor,
        blacklist: List[str] = ["a", "r", "g"],  # DUMMY :)
    ) -> List[Tuple[str, float]]:
        batch_size, max_seq_length, vocab_size = logits.size()

        # Find vocab indices of blacklisted characters
        blacklist_indices = [self._embedding.index(char) for char in blacklist if char in self._embedding]

        # Set blacklisted logits to -inf so argmax can never select them
        if blacklist_indices:
            logits[..., blacklist_indices] = float("-inf")

        # Compute pred with argmax for attention models
        out_idxs = logits.argmax(-1)
        # N x L
        probs = torch.gather(torch.softmax(logits, -1), -1, out_idxs.unsqueeze(-1)).squeeze(-1)
        # Take the minimum confidence of the sequence
        probs = probs.min(dim=1).values.detach().cpu()

        # Manual decoding
        word_values = [
            "".join(self._embedding[idx] for idx in encoded_seq).split("<eos>")[0]
            for encoded_seq in out_idxs.cpu().numpy()
        ]

        return list(zip(word_values, probs.numpy().clip(0, 1).tolist()))

This would take the next "char" with the highest prob. @SlappyAUS is this what you had in mind? Or full removal?

For example:

blacklist = ["e", "l"]
normal_out = "hello"
blacklisted_out = "ho"
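The two behaviours in question can be contrasted with a toy example (illustrative vocab and logits, not the doctr API): masking logits to -inf makes argmax pick the next most probable character at each step, while full removal simply strips blacklisted characters from the decoded string.

```python
import torch

vocab = ["h", "e", "l", "o"]
# 5 timesteps whose argmax spells "hello"; "o" is always second-best
logits = torch.tensor([[3.0, 1.0, 1.0, 2.0],
                       [1.0, 3.0, 1.0, 2.0],
                       [1.0, 1.0, 3.0, 2.0],
                       [1.0, 1.0, 3.0, 2.0],
                       [1.0, 1.0, 1.0, 3.0]])

blacklist = {"e", "l"}
bad_idxs = [i for i, c in enumerate(vocab) if c in blacklist]

# Option 1: mask the logits -> next-best character is chosen instead
masked = logits.clone()
masked[:, bad_idxs] = float("-inf")
masked_out = "".join(vocab[i] for i in masked.argmax(-1))  # "hoooo"

# Option 2: decode normally, then strip blacklisted characters
plain_out = "".join(vocab[i] for i in logits.argmax(-1))   # "hello"
removed_out = "".join(c for c in plain_out if c not in blacklist)  # "ho"
```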

felixdittrich92 avatar Feb 13 '24 15:02 felixdittrich92

What if we make it a bit broader and instead provide an option to multiply the logits by some weight? For example, l, 1, and I are often misidentified, and when I expect mostly digits I could upvote "digits" to break the almost-tie situations without killing all the other characters.
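One possible reading of the weighting idea, as a toy sketch (illustrative names, not doctr API): since raw logits can be negative, multiplying them directly is sign-sensitive, so the sketch multiplies the softmax probabilities instead; a weight of 1.0 leaves a character untouched.

```python
import torch

vocab = "1lI"
# "l" narrowly wins on the raw probabilities
probs = torch.softmax(torch.tensor([[1.0, 1.1, 1.05]]), -1)

# Upvote digits with a multiplicative weight; no renormalization is
# needed when only argmax is taken afterwards
weights = torch.tensor([2.0 if c.isdigit() else 1.0 for c in vocab])
reweighted = probs * weights

plain = vocab[probs.argmax(-1).item()]        # "l"
upvoted = vocab[reweighted.argmax(-1).item()]  # "1"
```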

dchaplinsky avatar Feb 13 '24 15:02 dchaplinsky

@dchaplinsky Hmm, wouldn't this need some mapping to "close" characters? :thinking:

felixdittrich92 avatar Feb 13 '24 15:02 felixdittrich92

Hey everyone :wave:

The use case here is to help out the text recognition part when you have more info about a subvocab, so now we need to assess whether that's worth addressing (I think it would be useful), what the API would be, and at which step this should take place.

My two cents:

  1. This will be quite useful to better leverage wide vocab model checkpoints without retraining them. We should work on this
  2. For the API, the issue with passing a whitelist/blacklist of characters is that you'd need both, and then an arbitrary order of priority between them. With a minimal snippet, we can easily convert either into a weight vector that the model will use.
  3. For simplicity, I think this should be passed to the model (or post processor) and made accessible as an instance attribute stored as a Tensor. At inference time, it should be used in the call of the post processor: either here https://github.com/mindee/doctr/blob/main/doctr/models/recognition/crnn/pytorch.py#L224 if at the model level, or here https://github.com/mindee/doctr/blob/main/doctr/models/recognition/crnn/pytorch.py#L75 if at the post-processor level

Here is my suggested design for blacklist:

import torch
from doctr.models import crnn_vgg16_bn

import torch
from doctr.datasets import VOCABS
from doctr.models import crnn_vgg16_bn

vocab = VOCABS["french"]  # vocab of the pretrained checkpoint
blacklisted_chars = {str(num) for num in range(10)}
# Set the mask (1 keeps a character, 0 masks it out)
vocab_mask = torch.tensor([0 if char in blacklisted_chars else 1 for char in vocab], dtype=torch.float32)
model = crnn_vgg16_bn(pretrained=True, vocab_mask=vocab_mask)  # vocab_mask is the proposed API, not implemented yet

input_tensor = torch.rand(1, 3, 32, 128)
out = model(input_tensor)

and whitelist:

whitelisted_chars = {str(num) for num in range(10)}
vocab_mask = torch.tensor([1 if char in whitelisted_chars else 0 for char in vocab], dtype=torch.float32)
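A sketch of how a post processor could consume such a mask (toy vocab and logits, assumed names, not current doctr code): adding `log(mask)` to the logits sends masked characters to -inf while leaving allowed characters unchanged, so the existing argmax decode works as-is.

```python
import torch

vocab = "ab12"
whitelisted_chars = {"1", "2"}
vocab_mask = torch.tensor([1.0 if c in whitelisted_chars else 0.0 for c in vocab])

logits = torch.tensor([[2.0, 1.5, 1.0, 0.5]])
# log(0) == -inf masks "a" and "b"; log(1) == 0 leaves "1" and "2" intact
masked_logits = logits + vocab_mask.log()
best = vocab[masked_logits.argmax(-1).item()]  # "1"
```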

What do you think?

frgfm avatar Feb 23 '24 09:02 frgfm