pytorch-ie

[WIP] fix training seq2seq

Open ArneBinder opened this issue 3 years ago • 2 comments

This was broken because pytorch-lightning tries to move the output of TransformerSeq2SeqTaskModule.collate to a device via pytorch_lightning.core.datamodule.LightningDataModule.transfer_batch_to_device, which internally uses pytorch_lightning.utilities.apply_func.apply_to_collection. That function fails if any part of the (arbitrarily nested) input is a frozen dataclass, which is the case for our Annotations. To fix this, we simply remove the documents (and also the metadata) from the batch, since they are not used at all.
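A minimal sketch of the failure mode and the fix, using only the standard library. Everything here is an assumption for illustration: `apply_to_nested` is a simplified stand-in that mimics the recursive behavior described above (it is not the pytorch-lightning implementation), `FakeTensor` replaces `torch.Tensor`, `Span` stands in for a frozen annotation dataclass, and the batch keys are hypothetical.

```python
import copy
import dataclasses
from typing import Any, Callable


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor (keeps the sketch dependency-free)."""

    def __init__(self, values: list, device: str = "cpu") -> None:
        self.values = values
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        return FakeTensor(self.values, device)


@dataclasses.dataclass(frozen=True)
class Span:
    """Hypothetical stand-in for a frozen annotation dataclass."""

    start: int
    end: int


def apply_to_nested(data: Any, dtype: type, fn: Callable[[Any], Any]) -> Any:
    """Simplified recursive apply; NOT the pytorch-lightning code."""
    if isinstance(data, dtype):
        return fn(data)
    if isinstance(data, dict):
        return {k: apply_to_nested(v, dtype, fn) for k, v in data.items()}
    if isinstance(data, (list, tuple)):
        return type(data)(apply_to_nested(v, dtype, fn) for v in data)
    if dataclasses.is_dataclass(data) and not isinstance(data, type):
        result = copy.copy(data)
        for field in dataclasses.fields(data):
            value = apply_to_nested(getattr(data, field.name), dtype, fn)
            # Writing the field back raises FrozenInstanceError for frozen dataclasses.
            setattr(result, field.name, value)
        return result
    return data


def to_device(batch: Any, device: str = "cuda:0") -> Any:
    return apply_to_nested(batch, FakeTensor, lambda t: t.to(device))


# Hypothetical collate output: model inputs plus documents and metadata.
batch = {
    "inputs": {"input_ids": FakeTensor([1, 2, 3])},
    "documents": [Span(0, 4)],
    "metadata": [{"doc_id": "a"}],
}

try:
    to_device(batch)
    failed = False
except dataclasses.FrozenInstanceError:
    failed = True

# The fix: drop the entries the model does not need before the transfer.
slim_batch = {k: v for k, v in batch.items() if k not in ("documents", "metadata")}
moved = to_device(slim_batch)
```

With the frozen dataclass in the batch the transfer raises; once the documents and metadata are dropped, only the tensor-like values remain and the transfer goes through.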

Notes:

  • To see the described behavior, #197 is necessary, because otherwise apply_to_collection fails even earlier, since it tries to iterate over the dataclass fields.
  • nox -p3.9 --session=tests_no_local_datasets passes except for one test that fails due to an external cause (this is fixed in #198).

ArneBinder avatar Jul 30 '22 18:07 ArneBinder

I have to look into this in more detail. The problem is that collate may return arbitrary types, not just dicts of tensors. For instance, GENRE (the entity linking method) will even return a function in collate that is then passed on to the model.

ChristophAlt avatar Jul 31 '22 08:07 ChristophAlt

Sounds reasonable. But functions should be fine in general, since apply_to_collection will skip them (see the link to the docs and source code above), if this is the intended behavior.

Some further information: https://pytorch-lightning.readthedocs.io/en/stable/extensions/datamodules.html#transfer-batch-to-device
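To illustrate the point about functions, here is a small sketch. It assumes the transfer only transforms leaves of the target type and returns everything else unchanged; `move_to_device`, `FakeTensor`, and `decode_fn` are hypothetical names for illustration, not pytorch-lightning or GENRE code.

```python
from typing import Any


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor (keeps the sketch dependency-free)."""

    def __init__(self, device: str = "cpu") -> None:
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        return FakeTensor(device)


def move_to_device(data: Any, device: str) -> Any:
    """Simplified recursive transfer; NOT the pytorch-lightning code."""
    if isinstance(data, FakeTensor):
        return data.to(device)
    if isinstance(data, dict):
        return {k: move_to_device(v, device) for k, v in data.items()}
    if isinstance(data, (list, tuple)):
        return type(data)(move_to_device(v, device) for v in data)
    # Functions and other unknown leaves are returned unchanged.
    return data


def decode_fn(ids: list) -> list:
    """Hypothetical callable, like one a collate function might return."""
    return [str(i) for i in ids]


batch = {"input_ids": FakeTensor(), "decode_fn": decode_fn}
moved = move_to_device(batch, "cuda:0")
```

Under that assumption, the callable reaches the model as the very same object, while the tensor-like value is moved.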

ArneBinder avatar Jul 31 '22 12:07 ArneBinder