yoyodyne
Add caching for transformer inference
Transformer inference (i.e., decoding without teacher forcing) is slow. In practice, people typically implement some kind of caching so that at each timestep we do not need to recompute the embeddings and attention over all previously decoded timesteps.
I have a quick-and-dirty implementation of this in an experimental fork: the decoder layer computes attention only for the most recently decoded target, and the cached representations of earlier targets are concatenated on. There are probably other tricks to be found by, e.g., inspecting the Hugging Face transformers inference code.
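For concreteness, here is a minimal PyTorch sketch of that trick (the class name and shapes are my own, not the fork's actual code): each step attends from just the newest target position, with earlier positions held in a cache that serves as keys and values.

```python
from typing import Optional

import torch
from torch import nn


class CachedSelfAttention(nn.Module):
    """Self-attention that only computes the query for the newest timestep.

    Hypothetical sketch: earlier decoder inputs are cached and reused as
    keys/values, so each decoding step attends from the latest position only.
    """

    def __init__(self, d_model: int, nhead: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.cache: Optional[torch.Tensor] = None  # B x T x d_model.

    def reset_cache(self) -> None:
        self.cache = None

    def forward(self, new_step: torch.Tensor) -> torch.Tensor:
        # new_step: B x 1 x d_model, the most recently decoded target.
        if self.cache is None:
            self.cache = new_step
        else:
            self.cache = torch.cat([self.cache, new_step], dim=1)
        # The single query is the newest position; keys and values are the
        # full cache, so no causal mask is needed.
        out, _ = self.attn(
            new_step, self.cache, self.cache, need_weights=False
        )
        return out  # B x 1 x d_model.
```

A more complete implementation would cache the projected keys and values per layer (as Hugging Face's `past_key_values` does) rather than re-projecting the concatenated inputs at every step.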
I propose adding an option to the transformer encoder-decoders to use caching, in which case a CacheTransformerDecoder module is used.
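One possible shape for that module, reusing the CachedSelfAttention sketch above; the class name comes from this proposal, but the constructor, wiring, and inference-only toggling are assumptions, not yoyodyne's actual API.

```python
import torch
from torch import nn


class CacheTransformerDecoder(nn.Module):
    """Hypothetical cached decoder stack: one cached attention block per layer.

    A real decoder layer would also need cross-attention over the encoder
    output and a feed-forward block; both are omitted here for brevity.
    """

    def __init__(self, d_model: int, nhead: int, num_layers: int):
        super().__init__()
        self.layers = nn.ModuleList(
            CachedSelfAttention(d_model, nhead) for _ in range(num_layers)
        )

    def reset_cache(self) -> None:
        # Call once per batch before autoregressive decoding begins.
        for layer in self.layers:
            layer.reset_cache()

    def forward(self, new_step: torch.Tensor) -> torch.Tensor:
        # new_step: B x 1 x d_model; each layer attends over its own cache.
        for layer in self.layers:
            new_step = layer(new_step)
        return new_step
```

Training and teacher-forced validation would keep the existing decoder; the cached module would only be swapped in for greedy or beam prediction, with `reset_cache()` called at the start of each batch.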
This is a low-priority TODO, since we do validation with accuracy rather than loss, and accuracy can be reliably estimated under teacher forcing when the targets are provided.