
Add caching for transformer inference

Open Adamits opened this issue 1 year ago • 3 comments

Transformer inference (i.e., with no teacher forcing) is slow. In practice, I think people typically implement some kind of caching so that at each timestep we do not need to recompute the embeddings and attention over all previously decoded timesteps.
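For concreteness, here is a minimal sketch of the uncached decoding loop I have in mind, using a plain `torch.nn.TransformerDecoder` (the dimensions and start-symbol index are made up, and this is not yoyodyne's actual API): every step re-embeds and re-attends over the entire prefix decoded so far.

```python
import torch
import torch.nn as nn

d_model, vocab_size, max_len = 256, 1000, 32
embed = nn.Embedding(vocab_size, d_model)
decoder = nn.TransformerDecoder(
    nn.TransformerDecoderLayer(d_model, nhead=4, batch_first=True), num_layers=2
)
out_proj = nn.Linear(d_model, vocab_size)

memory = torch.randn(1, 10, d_model)  # encoder output: B x S x d_model.
ys = torch.full((1, 1), 2, dtype=torch.long)  # start-symbol index (made up).

for _ in range(max_len):
    # O(t) work redone every step: the whole prefix is re-embedded ...
    tgt = embed(ys)
    mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(1))
    # ... and self-attention is recomputed for all t positions, not just the last.
    h = decoder(tgt, memory, tgt_mask=mask)
    next_token = out_proj(h[:, -1]).argmax(-1, keepdim=True)
    ys = torch.cat([ys, next_token], dim=1)
```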

I have a quick-and-dirty implementation of this in an experimental fork, where I basically tell the decoder layer to compute attention only for the most recently decoded target, and the cached representations of all earlier targets are concatenated on. There are probably other tricks I can find by, e.g., inspecting some huggingface transformers inference code. The per-layer trick looks roughly like the sketch below.
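This is not the fork's actual code (the module name and shapes are illustrative), but the idea is: queries come only from the newest decoded position, while keys and values are taken from the cached inputs of all earlier positions concatenated with the new one, so no causal mask is needed.

```python
import torch
import torch.nn as nn

class CachedSelfAttention(nn.Module):
    """Self-attention where only the newest timestep is computed each call."""

    def __init__(self, d_model: int, nhead: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.cache = None  # B x t x d_model of past layer inputs.

    def reset(self) -> None:
        self.cache = None

    def forward(self, x_new: torch.Tensor) -> torch.Tensor:
        # x_new: B x 1 x d_model, the representation of the latest target only.
        if self.cache is None:
            self.cache = x_new
        else:
            self.cache = torch.cat([self.cache, x_new], dim=1)
        # Query from the new position only; keys/values over the full prefix.
        # No causal mask is needed since the single query cannot see the future.
        out, _ = self.attn(x_new, self.cache, self.cache, need_weights=False)
        return out  # B x 1 x d_model.
```

Note that this caches the layer inputs rather than projected keys/values, so the key/value projections are still recomputed each step; caching the projections themselves (as I believe huggingface does with past_key_values) would save a bit more.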

I propose adding an option to the transformer encoder-decoders to enable caching, wherein a CacheTransformerDecoder module is used.
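As a strawman, the interface could look something like the following (purely hypothetical; none of these names exist in yoyodyne yet, and CachedSelfAttention refers to the sketch above): the cached decoder exposes a step() that consumes only the newest target embedding, plus a reset() called before decoding each new batch.

```python
class CachedTransformerDecoderLayer(nn.Module):
    """One decoder layer whose self-attention keeps a cache of past inputs."""

    def __init__(self, d_model: int, nhead: int):
        super().__init__()
        self.self_attn = CachedSelfAttention(d_model, nhead)
        self.cross_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.ReLU(), nn.Linear(4 * d_model, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def step(self, x_new: torch.Tensor, memory: torch.Tensor) -> torch.Tensor:
        # Post-norm residual blocks, operating on the newest position only.
        x_new = self.norm1(x_new + self.self_attn(x_new))
        cross, _ = self.cross_attn(x_new, memory, memory, need_weights=False)
        x_new = self.norm2(x_new + cross)
        return self.norm3(x_new + self.ff(x_new))


class CacheTransformerDecoder(nn.Module):
    """Stack of cached layers; call step() once per decoded timestep."""

    def __init__(self, d_model: int, nhead: int, num_layers: int):
        super().__init__()
        self.layers = nn.ModuleList(
            [CachedTransformerDecoderLayer(d_model, nhead) for _ in range(num_layers)]
        )

    def reset(self) -> None:
        # Clear all per-layer caches before decoding a new batch.
        for layer in self.layers:
            layer.self_attn.reset()

    def step(self, x_new: torch.Tensor, memory: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x_new = layer.step(x_new, memory)
        return x_new
```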

This is a low-priority TODO since we do validation with accuracy rather than loss, and accuracy can be reliably estimated with teacher forcing when the targets are provided.

Adamits · Nov 16 '23 19:11