Need the ability to blacklist/whitelist characters
Discussed in https://github.com/mindee/doctr/discussions/888
Originally posted by Xargonus on April 7, 2022: Hello, is there currently a way to blacklist or whitelist characters used by the text recognition model?
Thanks for opening the issue, let's first keep the discussion focused in #888 to avoid duplicates :)
That's a great one too. Again, the callback idea might shine here, since you could boost/deboost some characters rather than completely blacklist them. For example, I often see 1 recognized as l, Q as 0/O, or : as 0. I'd like to avoid a complete blacklist, but if I can, for example, boost and prioritize digits over letters or : over ., that could solve the problem.
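To sketch what such a callback could look like (hypothetical names, not an existing doctr API): a hook that receives the raw logits and the vocab and returns adjusted logits, e.g. nudging digit classes up just enough to win near-ties like 1 vs l:

from typing import Callable

import torch

# Hypothetical callback type (not an existing doctr API): it would be invoked on the
# raw logits right before decoding and return adjusted logits.
LogitsCallback = Callable[[torch.Tensor, str], torch.Tensor]

def prefer_digits(logits: torch.Tensor, vocab: str) -> torch.Tensor:
    # Small additive bonus on digit classes: enough to break near-ties (1 vs l vs I)
    # without letting digits beat clearly non-digit characters.
    # Assumes the last logits dim equals len(vocab); extra tokens (eos/blank) would need a 0 appended.
    bonus = torch.tensor([1.0 if char.isdigit() else 0.0 for char in vocab])
    return logits + bonus  # broadcasts over (batch, seq_len, vocab_size)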
@odulcy-mindee @dchaplinsky
A short example (with MASTER) of how we could achieve this:
from typing import List, Tuple

import torch


# _MASTERPostProcessor is the existing base class in doctr's MASTER recognition module
class MASTERPostProcessor(_MASTERPostProcessor):
    """Post processor for MASTER architectures"""

    def __call__(
        self,
        logits: torch.Tensor,
        blacklist: List[str] = ["a", "r", "g"],  # DUMMY :)
    ) -> List[Tuple[str, float]]:
        batch_size, max_seq_length, vocab_size = logits.size()
        # Find the vocab indices of the blacklisted characters
        blacklist_indices = [self._embedding.index(char) for char in blacklist if char in self._embedding]
        # Set the logits of blacklisted characters to -inf so they can never be picked
        logits[:, :, blacklist_indices] = float("-inf")
        # Compute pred with argmax for attention models
        out_idxs = logits.argmax(-1)
        # N x L
        probs = torch.gather(torch.softmax(logits, -1), -1, out_idxs.unsqueeze(-1)).squeeze(-1)
        # Take the minimum confidence of the sequence
        probs = probs.min(dim=1).values.detach().cpu()
        # Manual decoding
        word_values = [
            "".join(self._embedding[idx] for idx in encoded_seq).split("<eos>")[0]
            for encoded_seq in out_idxs.cpu().numpy()
        ]
        return list(zip(word_values, probs.numpy().clip(0, 1).tolist()))
This would take the next char with the highest prob. @SlappyAUS is this what you had in mind? Or full removal?
For example:
blacklist = ["e", "l"]
normal_out = "hello"
blacklisted_out = "ho"
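To make the difference between the two behaviours concrete, here is a tiny self-contained sketch with a made-up vocab and made-up logits (nothing doctr-specific), contrasting the -inf masking above (next-best char) with full removal as in the "ho" example:

import torch

# Toy 5-character vocab and logits for the word "hello" (made-up values)
vocab = ["h", "e", "l", "o", "x"]
logits = torch.tensor([
    [5.0, 1.0, 0.5, 0.2, 3.0],  # "h" (runner-up: "x")
    [0.5, 5.0, 1.0, 0.2, 3.0],  # "e" (runner-up: "x")
    [0.5, 1.0, 5.0, 0.2, 3.0],  # "l" (runner-up: "x")
    [0.5, 1.0, 5.0, 0.2, 3.0],  # "l" (runner-up: "x")
    [0.5, 1.0, 0.2, 5.0, 3.0],  # "o" (runner-up: "x")
])
blacklist_indices = [vocab.index(char) for char in ("e", "l")]

# Option A: mask the logits to -inf -> argmax falls back to the next best char
masked = logits.clone()
masked[:, blacklist_indices] = float("-inf")
option_a = "".join(vocab[idx] for idx in masked.argmax(-1).tolist())  # "hxxxo"

# Option B: decode normally, then strip blacklisted chars from the string
decoded = "".join(vocab[idx] for idx in logits.argmax(-1).tolist())  # "hello"
option_b = "".join(char for char in decoded if char not in ("e", "l"))  # "ho"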
What if we make it a bit broader and instead provide an option to multiply the logits by some weight? For example, l, 1 and I can often be wrongly identified, and as soon as I expect mostly digits, I can upvote digits to break the almost-tie situations without killing all the rest of the characters.
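Roughly something like this, just as a sketch (made-up vocab and weights, not doctr code): multiply the per-character probabilities by a weight vector, so digits win the near-ties while every other character stays available:

import torch

# Made-up vocab and a weight vector that favours digits (weight > 1) without
# excluding anything else (weight == 1).
vocab = "0123456789abcdefghijklmnopqrstuvwxyz"
weights = torch.tensor([3.0 if char.isdigit() else 1.0 for char in vocab])

logits = torch.randn(2, 8, len(vocab))        # (batch, seq_len, vocab_size)
probs = torch.softmax(logits, -1) * weights   # boost digits, keep the rest
probs = probs / probs.sum(-1, keepdim=True)   # renormalize to proper probabilities
preds = probs.argmax(-1)  # near-ties like 1 vs l now resolve toward digits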
@dchaplinsky Mhhh, this would need some mapping to "close" characters? :thinking:
Hey everyone :wave:
The use case here is to help out the text recognition part when you have more info about a sub-vocab, so now we need to assess whether that's worth addressing (I think it would be useful), what the API would be, and at which step this should take place.
My two cents:
- This will be quite useful to better leverage wide-vocab model checkpoints without retraining them. We should work on this.
- For the API, the issue with passing a whitelist/blacklist of characters is that you would need both, and then an arbitrary order of priority between them. With a minimal snippet, we can easily convert that to a weight vector which will be used by the model.
- For simplicity, I think this should be passed to the model (or post processor) and made accessible as an instance attribute as a Tensor. At inference time, this should be used in the call of the post processor, either here https://github.com/mindee/doctr/blob/main/doctr/models/recognition/crnn/pytorch.py#L224 if at the model level, or here https://github.com/mindee/doctr/blob/main/doctr/models/recognition/crnn/pytorch.py#L75 if at the post-processor level (a rough sketch of that application step follows below).
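For illustration, a rough sketch of what applying such a mask/weight vector at that decoding step could look like (hypothetical helper, not the actual doctr post-processor; it assumes a CTC head with the blank as the extra last class):

import torch

def masked_ctc_decode(logits: torch.Tensor, vocab_mask: torch.Tensor) -> torch.Tensor:
    # logits: (N, T, len(vocab) + 1), vocab_mask: (len(vocab),) with 0 = excluded, 1 = allowed
    full_mask = torch.cat([vocab_mask, torch.ones(1)])  # never mask the blank token
    logits = logits.masked_fill(full_mask == 0, float("-inf"))
    return logits.argmax(-1)  # greedy path: excluded characters can never be picked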
Here is my suggested design for blacklist:
import torch

from doctr.datasets import VOCABS
from doctr.models import crnn_vgg16_bn

# The vocab string the checkpoint was trained on (assuming the French vocab here)
vocab = VOCABS["french"]
blacklisted_chars = {str(num) for num in range(10)}
# Set the mask: 0 for blacklisted characters, 1 for everything else
vocab_mask = torch.tensor([0 if char in blacklisted_chars else 1 for char in vocab], dtype=torch.float32)
model = crnn_vgg16_bn(pretrained=True, vocab_mask=vocab_mask)  # proposed kwarg
input_tensor = torch.rand(1, 3, 32, 128)
out = model(input_tensor)
and whitelist:
whitelisted_chars = {str(num) for num in range(10)}
vocab_mask = torch.tensor([1 if char in whitelisted_chars else 0 for char in vocab], dtype=torch.float32)
What do you think?