pytorch-ie
[WIP] fix training seq2seq
This was broken because pytorch-lightning tries to move the output of `TransformerSeq2SeqTaskModule.collate` to a device via `pytorch_lightning.core.datamodule.LightningDataModule.transfer_batch_to_device`, which internally uses `pytorch_lightning.utilities.apply_func.apply_to_collection`. That method fails if any part of the (arbitrarily nested) input is a frozen dataclass, which is the case for our Annotations. To fix this, we simply remove the documents (and also the metadata) from the batch, since they are not used at all.
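For illustration, a minimal sketch of the failure (the `FrozenAnnotation` stand-in is hypothetical, and the exact exception raised depends on the pytorch-lightning version):

```python
from dataclasses import dataclass

import torch
from pytorch_lightning.utilities.apply_func import apply_to_collection


# Hypothetical stand-in for a pytorch-ie Annotation: a frozen dataclass.
@dataclass(frozen=True)
class FrozenAnnotation:
    start: int
    end: int


batch = {
    "input_ids": torch.zeros(2, 8, dtype=torch.long),
    "documents": [FrozenAnnotation(start=0, end=4)],  # nested frozen dataclass
}

# transfer_batch_to_device boils down to a call like this. apply_to_collection
# recurses into the dataclass and tries to write the transformed fields back,
# which fails on frozen instances (e.g. FrozenInstanceError or
# MisconfigurationException, depending on the pytorch-lightning version).
moved = apply_to_collection(batch, torch.Tensor, lambda t: t.to("cpu"))
```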
Notes:
- To see the described behavior, #197 is necessary because otherwise `apply_to_collection` fails even earlier, since it tries to iterate over the dataclass fields.
- `nox -p 3.9 --session=tests_no_local_datasets` passes, except for one test that fails for an external cause (this is fixed in #198).
I have to look into this in more detail. The problem is that `collate` may return arbitrary types, not just dicts of tensors. For instance, GENRE (the entity linking method) will even return a function from `collate` that is then passed on to the model.
Sounds reasonable. But functions should be fine in general, as they will be skipped by `apply_to_collection` (see the link to the docs and the source code above), if this is the intended behavior.
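For what it's worth, a quick illustrative sketch of that skipping behavior (the `prefix_allowed_tokens_fn` name is just an example of the kind of function a collate might return): `apply_to_collection` only transforms elements matching the given dtype and returns other leaves, including callables, unchanged.

```python
import torch
from pytorch_lightning.utilities.apply_func import apply_to_collection


def prefix_allowed_tokens_fn(batch_id, input_ids):
    # Placeholder for the kind of callable a collate might return.
    return [0]


batch = {
    "input_ids": torch.ones(2, 4),
    "prefix_allowed_tokens_fn": prefix_allowed_tokens_fn,
}

moved = apply_to_collection(batch, torch.Tensor, lambda t: t.to("cpu"))

# Only the tensor was transformed; the function passes through untouched.
assert moved["prefix_allowed_tokens_fn"] is prefix_allowed_tokens_fn
```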
Some further information: https://pytorch-lightning.readthedocs.io/en/stable/extensions/datamodules.html#transfer-batch-to-device
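The hook described there could also be overridden to keep non-tensor parts out of the device transfer instead of dropping them; a minimal sketch of that alternative, assuming the `documents` and `metadata` batch keys from this PR (the hook signature varies across pytorch-lightning versions):

```python
import pytorch_lightning as pl
from pytorch_lightning.utilities.apply_func import move_data_to_device


class TaskDataModule(pl.LightningDataModule):
    def transfer_batch_to_device(self, batch, device, dataloader_idx=0):
        # Split off parts that must not (and need not) be moved to the device.
        passthrough = {k: batch[k] for k in ("documents", "metadata") if k in batch}
        movable = {k: v for k, v in batch.items() if k not in passthrough}
        # Move only the tensor-bearing parts, then reattach the rest untouched.
        moved = move_data_to_device(movable, device)
        moved.update(passthrough)
        return moved
```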