Extractors that generate transformers.tokenization_utils_base.BatchEncoding will cause error before training
Hi there.
I had an Extractor that was more or less a copy of passagebert and textbert, and I thought that instead of unpacking the tokenizer's result and embedding it into a new dictionary like:
{
'positive_ids': tensor([1,2,3,...]),
'positive_mask': tensor([1,1,1,...]),
'positive_segments': tensor([1,1,1,...]),
}
it would be much better to pass the tokenizer's results along without any unpacking and reshaping. The BERT tokenizer yields a transformers.tokenization_utils_base.BatchEncoding object, which is a dictionary-like structure and can be passed to the model as bert_model(**tokens), as you already know.
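For context, here is a minimal standalone sketch (plain transformers, not capreolus code; bert-base-uncased is just an example checkpoint) of what I mean by passing the BatchEncoding straight through:

from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")

tokens = tokenizer("This is a test sentence", return_tensors="pt")
print(type(tokens))        # <class 'transformers.tokenization_utils_base.BatchEncoding'>
tokens = tokens.to("cpu")  # BatchEncoding has a .to() method, unlike a plain dict
output = model(**tokens)   # the dict-like object unpacks straight into the model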
I assumed that I could just pass this object type and the code would run with no problem, something like this:
{
'positive_ids_and_mask': self.my_tokenizer('This is a test sentence'),
}
But that was not the case. In the pytorch trainer, line 93, an error is raised: https://github.com/capreolus-ir/capreolus/blob/0121f6e7efa3c1f19cc4704ac6f69747e1baa028/capreolus/trainer/pytorch.py#L93

AttributeError: 'dict' object has no attribute 'to'

Here v has become a dict and is no longer a transformers.tokenization_utils_base.BatchEncoding, so there is no to attribute.
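If I read the trainer right, the failing step is roughly equivalent to this (a paraphrased sketch, not the exact trainer code; the field name comes from my extractor):

device = "cpu"  # stand-in for the trainer's device
batch = {"positive_ids_and_mask": {"input_ids": [[1, 2, 3]]}}  # what the DataLoader yields in my case
batch = {k: v.to(device) for k, v in batch.items()}
# AttributeError: 'dict' object has no attribute 'to'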
I investigated a little and I'm pretty sure the problem is caused by this line:
https://github.com/capreolus-ir/capreolus/blob/0121f6e7efa3c1f19cc4704ac6f69747e1baa028/capreolus/trainer/pytorch.py#L223
PyTorch's DataLoader will accept a transformers.tokenization_utils_base.BatchEncoding but will yield a plain dictionary. Here is a demonstration:
>>> import torch
>>> import transformers
>>> data = transformers.tokenization_utils_base.BatchEncoding({"test": [1, 2, 3]})
>>> type(data)
<class 'transformers.tokenization_utils_base.BatchEncoding'>
>>> for x in torch.utils.data.DataLoader([data]):
...     print(x)
...     print(type(x))
...
{'test': [tensor([1]), tensor([2]), tensor([3])]}
<class 'dict'>
I manually changed the pytorch trainer code so that it converts the dict back to a transformers.tokenization_utils_base.BatchEncoding, but this is just a solution for my task and would cause problems for other non-bert models.
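Roughly, my local change looks like this (a sketch of my hack rather than a proposed patch; it assumes every dict value coming out of the DataLoader is meant to be a BatchEncoding):

from transformers.tokenization_utils_base import BatchEncoding

# wrap plain dicts produced by the DataLoader back into BatchEncoding
# so that the later v.to(device) call in the trainer works again
batch = {k: BatchEncoding(v) if isinstance(v, dict) else v for k, v in batch.items()}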
Thanks for pointing this out. I don't remember why we avoid using the dict from hgf's tokenizer class directly, but this is something we should look into in the future when upgrading the version of transformers. It may not be necessary to call to on the tensors directly if this is happening inside the DataLoader.
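If it helps, one possible direction (just a sketch, not a committed design; to_device is a hypothetical helper) would be to move tensors to the device recursively, so the trainer no longer depends on every top-level batch value having a to method:

import torch

def to_device(obj, device):
    # recursively move tensors inside (possibly nested) dicts/lists/tuples to the device
    if torch.is_tensor(obj):
        return obj.to(device)
    if isinstance(obj, dict):
        return {k: to_device(v, device) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(to_device(v, device) for v in obj)
    return obj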