
Transformer: add update_listeners_in_predict option

danieldk opened this issue 3 years ago · 1 comment

Draft: still needs docs, but I first wanted to discuss this proposal.

The output of a transformer is passed through in two different ways:

  • Prediction: the data is passed through the Doc._.trf_data attribute.
  • Training: the data is broadcast directly to the transformer's listeners.
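
For reference, a minimal sketch of the prediction path (assuming a pretrained transformer pipeline such as en_core_web_trf is installed; the package name is just an example):

```python
import spacy

# Sketch of the prediction path: after nlp() runs, the transformer output is
# read from the custom attribute rather than being pushed to listeners.
nlp = spacy.load("en_core_web_trf")
doc = nlp("The transformer output ends up on the doc.")

# TransformerData holding the wordpieces, tensors and alignment.
trf_data = doc._.trf_data
print(type(trf_data).__name__)
```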

However, the Transformer.predict method breaks the strict separation between training and prediction by also broadcasting transformer outputs to its listeners. This was added (I think) to support training with a frozen transformer.

This breaks down, however, when we are training a model with an unfrozen transformer that is also listed in annotating_components. As part of its update step, the transformer first broadcasts the tensors and backprop function to its listeners. But when it then acts as an annotating component, it immediately overrides its own output and clears the backprop function. As a result, gradients will not flow into the transformer.

This change fixes this issue by adding the update_listeners_in_predict option, which is enabled by default. When this option is disabled, the tensors will not be broadcast to listeners in predict.
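
As a rough sketch of how the proposed option might be used (the exact config surface is still part of this draft, so the placement of the setting is an assumption):

```python
import spacy

# Hypothetical usage of the proposed option: disable broadcasting to
# listeners in predict() when adding the transformer component. The option
# name comes from this proposal; exposing it directly on the component
# config like this is an assumption.
nlp = spacy.blank("en")
nlp.add_pipe(
    "transformer",
    config={"update_listeners_in_predict": False},
)
```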


Alternatives considered:

  • Yanking the listener code from predict: breaks our current semantics, would make it harder to train with a frozen transformer.
  • Checking in the listener whether the tensors we are receiving have the same batch ID as the batch we already have. Don't update if we already have the same batch with a backprop function. I think this is a bit fragile, because it breaks when batching differs between training and prediction (?).

danieldk · Aug 04 '22 11:08

Let me close and reopen to check the CI.

adrianeboyd · Aug 08 '22 07:08

I agree, it looks like there are some inconsistencies in how the embeddings are passed through, and the clean separation between train/predict that we envisioned isn't actually reflected in the code.

> Alternatives considered:
>
> * Yanking the listener code from `predict`: breaks our current semantics, would make it harder to train with a frozen transformer.

I think I like this suggestion best. A tok2vec version of this is here: https://github.com/explosion/spaCy/pull/11385

> * Checking in the listener whether the tensors we are receiving have the same batch ID as the batch we already have. Don't update if we already have the same batch with a backprop function. I think this is a bit fragile, because it breaks when batching differs between training and prediction (?).

You're right that this might be a bit too fragile.

Another alternative would be to check against `None` (for transformers) or against `_empty_backprop` (for tok2vec) within the `receive` body, and not overwrite the backprop call when the new one is empty, but again that's a bit brittle.
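
Sketched out, that check would sit in the listener's `receive` callback; a hypothetical illustration (the attribute names mirror the listener internals, but this guard is not existing code):

```python
class ListenerSketch:
    """Hypothetical stand-in for a Transformer/Tok2Vec listener, showing only
    the "don't overwrite with an empty backprop" guard discussed above."""

    def __init__(self):
        self._batch_id = None
        self._outputs = None
        self._backprop = None

    def receive(self, batch_id, outputs, backprop):
        # A broadcast from predict() arrives without a gradient path: keep
        # the tensors and backprop already received during the update step.
        if backprop is None and self._backprop is not None:
            return
        self._batch_id = batch_id
        self._outputs = outputs
        self._backprop = backprop
```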

I have a final, more radical proposal:

For v4, remove the indirect communication with listeners entirely for the outputs (you still need it for the backprop), and always rely on `doc.tensor`, i.e. make it mandatory to put `tok2vec` and `transformer` in the `annotating_components`.
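
For context, the opt-in version of this already exists; a sketch of training with the transformer as an annotating component (assuming spaCy v3.1+, where `nlp.update` takes an `annotates` argument mirroring the `[training] annotating_components` setting; the pipeline and toy example are placeholders):

```python
import spacy
from spacy.training import Example

# Sketch of the existing opt-in mechanism: listing the transformer in
# `annotates` makes it set doc._.trf_data on the docs during the update,
# which is what the proposal above would make mandatory.
nlp = spacy.load("en_core_web_trf")
optimizer = nlp.resume_training()

example = Example.from_dict(
    nlp.make_doc("Berlin is a city."),
    {"entities": [(0, 6, "GPE")]},
)
losses = {}
nlp.update(
    [example],
    sgd=optimizer,
    losses=losses,
    annotates=["transformer"],
    # Keep the sketch focused on the transformer + ner listener.
    exclude=["tagger", "parser", "attribute_ruler", "lemmatizer"],
)
print(sorted(losses))
```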

svlandeg · Aug 26 '22 13:08

Closing in favour of https://github.com/explosion/spacy-transformers/pull/345

svlandeg · Sep 07 '22 15:09