spacy-transformers

Transformer.predict: do not broadcast to listeners

Open danieldk opened this issue 3 years ago • 0 comments

The output of the transformer reaches downstream components in two different ways:

  • Prediction: the data is passed through the Doc._.trf_data attribute.
  • Training: the data is broadcast directly to the transformer's listeners.
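The intended separation between the two channels can be sketched with toy, dependency-free stand-ins. `ToyDoc`, `ToyListener` and `ToyTransformer` are hypothetical classes, not the spacy-transformers API. `ToyDoc.trf_data` only mimics `Doc._.trf_data`, and the "model output" is just each text's length. In this sketch, `update` is the training channel (broadcast plus a backprop callback) and `set_annotations` is the prediction channel (output stored on the docs).

```python
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class ToyDoc:
    """Hypothetical stand-in for a spaCy Doc; `trf_data` mimics Doc._.trf_data."""
    text: str
    trf_data: Optional[float] = None


class ToyListener:
    """Hypothetical listener that stores whatever the transformer broadcasts."""

    def __init__(self) -> None:
        self.outputs: Optional[List[float]] = None
        self.backprop: Optional[Callable[[List[float]], List[float]]] = None

    def receive(self, outputs: List[float], backprop) -> None:
        self.outputs = outputs
        self.backprop = backprop


class ToyTransformer:
    """Hypothetical transformer showing the two output channels."""

    def __init__(self) -> None:
        self.listeners: List[ToyListener] = []

    def _forward(self, docs: List[ToyDoc]) -> List[float]:
        # Fake "embedding": one number per doc, its text length.
        return [float(len(doc.text)) for doc in docs]

    def update(self, docs: List[ToyDoc]) -> None:
        # Training channel: broadcast outputs together with a backprop callback.
        outputs = self._forward(docs)

        def backprop(d_outputs: List[float]) -> List[float]:
            return d_outputs  # a real model would compute gradients here

        for listener in self.listeners:
            listener.receive(outputs, backprop)

    def predict(self, docs: List[ToyDoc]) -> List[float]:
        # Prediction only computes outputs; it does not touch the listeners.
        return self._forward(docs)

    def set_annotations(self, docs: List[ToyDoc], outputs: List[float]) -> None:
        # Prediction channel: store the outputs on the docs themselves.
        for doc, out in zip(docs, outputs):
            doc.trf_data = out


trf = ToyTransformer()
listener = ToyListener()
trf.listeners.append(listener)
docs = [ToyDoc("hello"), ToyDoc("hi")]

trf.update(docs)                              # training: listener receives the batch
trf.set_annotations(docs, trf.predict(docs))  # prediction: docs carry the output
```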

However, the Transformer.predict method breaks the strict separation between training and prediction by also broadcasting transformer outputs to its listeners.

This causes a problem when training a pipeline with an unfrozen transformer that is also listed in annotating_components. As part of its update step, the transformer first broadcasts the tensors and backprop function to its listeners. Then, when it runs as an annotating component, its predict method immediately overwrites that output and clears the backprop function. As a result, gradients do not flow into the transformer.
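The failure can be reproduced with a hypothetical toy model of the old behavior. `OldToyTransformer` and `ToyListener` are illustrative stand-ins, not the real classes, and `grads_received` is a made-up counter that records whether any gradient reached the transformer. The sketch assumes, as described above, that the broadcast from `predict` carries no backprop function.

```python
from typing import Callable, List, Optional


class ToyListener:
    """Hypothetical listener: keeps the last outputs and backprop it was sent."""

    def __init__(self) -> None:
        self.outputs: Optional[List[float]] = None
        self.backprop: Optional[Callable[[List[float]], List[float]]] = None

    def receive(self, outputs: List[float], backprop) -> None:
        self.outputs = outputs
        self.backprop = backprop


class OldToyTransformer:
    """Hypothetical transformer with the pre-fix behavior: predict also broadcasts."""

    def __init__(self) -> None:
        self.listeners: List[ToyListener] = []
        self.grads_received: List[List[float]] = []

    def _forward(self, texts: List[str]) -> List[float]:
        return [float(len(t)) for t in texts]

    def update(self, texts: List[str]) -> None:
        outputs = self._forward(texts)

        def backprop(d_outputs: List[float]) -> List[float]:
            # Record that gradients actually reached the transformer.
            self.grads_received.append(d_outputs)
            return d_outputs

        for listener in self.listeners:
            listener.receive(outputs, backprop)

    def predict(self, texts: List[str]) -> List[float]:
        outputs = self._forward(texts)
        # Pre-fix behavior: broadcast again, this time without a backprop function.
        for listener in self.listeners:
            listener.receive(outputs, None)
        return outputs


trf = OldToyTransformer()
listener = ToyListener()
trf.listeners.append(listener)
batch = ["hello", "hi"]

trf.update(batch)   # 1. update step: listener gets outputs + backprop
trf.predict(batch)  # 2. annotating step: broadcast again, backprop cleared

# The listener can no longer send gradients back to the transformer.
gradient_lost = listener.backprop is None
```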

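This change removes the broadcast from the predict method. If a listener has not received a batch, it now attempts to get the transformer output from the Doc instances instead (through Doc._.trf_data). This fallback makes it possible to train a pipeline with a frozen transformer, which never broadcasts during training because its update step is not run.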
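Under the same toy assumptions, the fixed behavior can be sketched as follows. All class and attribute names are hypothetical, `ToyDoc.trf_data` stands in for `Doc._.trf_data`, and `get_outputs` is an invented helper showing the fallback. The real listeners also keep track of which batch they received; this sketch omits that, so a stale batch from an earlier step would not be detected here.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class ToyDoc:
    """Hypothetical Doc stand-in; `trf_data` plays the role of Doc._.trf_data."""
    text: str
    trf_data: Optional[float] = None


class ToyListener:
    """Hypothetical listener with a fallback to the annotations on the docs."""

    def __init__(self) -> None:
        self.batch: Optional[List[float]] = None
        self.backprop: Optional[Callable[[List[float]], List[float]]] = None

    def receive(self, outputs: List[float], backprop) -> None:
        self.batch = outputs
        self.backprop = backprop

    def get_outputs(self, docs: List[ToyDoc]) -> List[float]:
        if self.batch is not None:
            # Normal training path: use the broadcast batch.
            return self.batch
        # Fallback: nothing was broadcast (e.g. a frozen transformer),
        # so read the output the transformer stored on the docs.
        if any(doc.trf_data is None for doc in docs):
            raise ValueError("no broadcast batch and no trf_data on the docs")
        return [doc.trf_data for doc in docs]


class FixedToyTransformer:
    """Hypothetical transformer with the post-fix behavior."""

    def __init__(self) -> None:
        self.listeners: List[ToyListener] = []
        self.grads_received: List[List[float]] = []

    def _forward(self, docs: List[ToyDoc]) -> List[float]:
        return [float(len(doc.text)) for doc in docs]

    def update(self, docs: List[ToyDoc]) -> None:
        outputs = self._forward(docs)

        def backprop(d_outputs: List[float]) -> List[float]:
            self.grads_received.append(d_outputs)
            return d_outputs

        for listener in self.listeners:
            listener.receive(outputs, backprop)

    def predict(self, docs: List[ToyDoc]) -> List[float]:
        # No broadcast any more: predict only computes outputs.
        return self._forward(docs)

    def set_annotations(self, docs: List[ToyDoc], outputs: List[float]) -> None:
        for doc, out in zip(docs, outputs):
            doc.trf_data = out


# Unfrozen transformer that is also an annotating component.
trf = FixedToyTransformer()
listener = ToyListener()
trf.listeners.append(listener)
docs = [ToyDoc("hello"), ToyDoc("hi")]
trf.update(docs)
trf.set_annotations(docs, trf.predict(docs))  # no longer clobbers the listener
listener.backprop([1.0, 1.0])                 # gradients reach the transformer

# Frozen transformer: update() is never called, only annotation happens.
frozen = FixedToyTransformer()
frozen_listener = ToyListener()
frozen.listeners.append(frozen_listener)
frozen_docs = [ToyDoc("hello"), ToyDoc("hi")]
frozen.set_annotations(frozen_docs, frozen.predict(frozen_docs))
frozen_outputs = frozen_listener.get_outputs(frozen_docs)
```

The design choice this illustrates: once predict stops broadcasting, the annotations on the docs become the only channel a listener can use when no training broadcast happened.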

This ports https://github.com/explosion/spaCy/pull/11385 to spacy-transformers. Alternative to #342.

danieldk · Aug 31 '22 12:08