Encode object type in Donut tokens
What does this PR do?
This encodes the object type of each value in the token sequences generated by Donut, which fixes a few issues:
- keys of the same name appearing at different levels of the JSON are no longer confused
- no more ambiguity between a dict and a list of length 1 containing a dict
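To illustrate with the token scheme from `json2token` below (the example values here are mine, not from the PR):

```
{"menu": {"name": "burger"}}    ->  <s_menu-dict><s_name-str>burger</s_name-str></s_menu-dict>
{"menu": [{"name": "burger"}]}  ->  <s_menu-list><s_name-str>burger</s_name-str></s_menu-list>
```

Previously both inputs serialized to `<s_menu><s_name>burger</s_name></s_menu>`, so `token2json` could not tell them apart, and same-named keys at different levels produced identical open/close tokens.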
Additionally, encoding the types lets us keep track of which keys have been opened and closed so far, so we can look ahead to find the token that closes the current element. This makes much deeper nesting (beyond just 2 levels) work without breaking.
There is some fault tolerance included in the look-ahead. If a closing token cannot be found or a new opening token is encountered unexpectedly, ambiguous parts of the text will be discarded and processing continues with the next part of the text that can be converted to JSON without any ambiguity.
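For instance (an illustrative sketch; the input and expected output are my assumption of the recovery behavior, not taken verbatim from the added tests):

```python
from transformers import DonutProcessor

processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")

# The closing </s_price-str> token is missing, so the ambiguous fragment
# is discarded and only the part that parses cleanly is kept.
sequence = "<s_name-str>burger</s_name-str><s_price-str>5.00"
print(processor.token2json(sequence))
# expected with this PR's fault tolerance: {'name': 'burger'}
```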
This requires matching changes in `json2token`. I wasn't quite sure where to put those. At the moment, I think the Dataset code containing that method is only part of the tutorials. Would it make sense to add it here as well? Essentially, all that's needed is something like:
```python
import typing as t
from abc import ABC, abstractmethod

import numpy as np


class DonutDatasetMixin(ABC):
    # tokens added so far, e.g. categorical special tokens like <yes/>
    added_tokens: list

    @abstractmethod
    def add_tokens(self, list_of_tokens: t.List[str]):
        pass

    def json2token(
        self,
        obj: t.Any,
        update_special_tokens_for_json_key: bool = True,
        sort_json_key: bool = True,
    ):
        """
        Convert an ordered JSON object recursively into a token sequence.

        Args:
            obj: Object to convert
            update_special_tokens_for_json_key (bool):
                Add encountered keys as special tokens to the processor's tokenizer
            sort_json_key (bool): Whether to sort JSON keys in an object alphabetically
        """
        if (obj_type := self.get_object_type(obj)) == "dict":
            if len(obj) == 1 and "text_sequence" in obj:
                return obj["text_sequence"]
            else:
                output = ""
                if sort_json_key:
                    keys = sorted(obj.keys(), reverse=True)
                else:
                    keys = obj.keys()
                for k in keys:
                    v = obj[k]
                    # encode the value's type in the key's open/close tokens
                    v_obj_type = self.get_object_type(v)
                    if update_special_tokens_for_json_key:
                        self.add_tokens([rf"<s_{k}-{v_obj_type}>", rf"</s_{k}-{v_obj_type}>"])
                    output += (
                        rf"<s_{k}-{v_obj_type}>"
                        + self.json2token(v, update_special_tokens_for_json_key, sort_json_key)
                        + rf"</s_{k}-{v_obj_type}>"
                    )
                return output
        elif obj_type == "list":
            return r"<sep/>".join(
                [
                    self.json2token(item, update_special_tokens_for_json_key, sort_json_key)
                    for item in obj
                ]
            )
        else:
            obj = str(obj)
            if f"<{obj}/>" in self.added_tokens:
                obj = f"<{obj}/>"  # for categorical special tokens
            return obj

    @staticmethod
    def get_object_type(obj: t.Any) -> t.Literal["list", "dict", "str"]:
        if isinstance(obj, (list, np.ndarray)):
            return "list"
        if isinstance(obj, dict):
            return "dict"
        return "str"
```
Then the dataset can be constructed much like it's already done in the tutorial; it only needs to implement the abstract `add_tokens`:

```python
class DonutDataset(Dataset, DonutDatasetMixin):
    def add_tokens(self, list_of_tokens: t.List[str]):
        self.processor.tokenizer.add_tokens(list_of_tokens)
        self.added_tokens.extend(list_of_tokens)
```
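With that in place, `json2token` produces the training target sequences. A quick illustration (the input values are invented; note the output order follows the reverse key sort above):

```python
dataset = DonutDataset(...)  # constructed as in the tutorial
target = dataset.json2token({"menu": [{"name": "burger", "price": "5.00"}]})
# '<s_menu-list><s_price-str>5.00</s_price-str><s_name-str>burger</s_name-str></s_menu-list>'
```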
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Did you read the contributor guideline, Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
- [x] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- [x] Did you write any new necessary tests?
Who can review?
@NielsRogge As promised, the improvements that we made to Donut's `token2json`. It works well with more complex JSON data structures, as demonstrated in the added tests.
Not sure why black is failing; `make fixup` doesn't change anything for me.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Hi @ts2095, thanks for your contribution and sorry for the late reply.
Could you rebase your branch on main to make the CI green? Also, can you confirm this update is 100% backwards compatible?
Hi @ts2095, sorry for the late reply here!
Would you be able to rebase your branch on the main branch of Transformers?
cc'ing @amyeroberts here for a review
@ts2095 There was a recent update on main, updating our CI images to run on Python 3.8, which I believe should resolve the import issue with `from typing import Literal`. Could you rebase to include these changes?
@amyeroberts We still support Python 3.7, so we cannot accept type hints using `Literal`.
@ts2095 Can you confirm that this is backwards compatible and that previous token sequences result in the same JSON output?