
Progressively decreasing attention windows

Vorlent opened this issue 1 month ago · 0 comments

In the spirit of the paper "The Unreasonable Ineffectiveness of the Deeper Layers" (https://arxiv.org/abs/2403.17887v1), it should be possible to use progressively decreasing attention windows with little to no loss in model quality.

Llama 3 70B has 80 layers in total and a context window of 8k tokens. The idea is that each layer has access to half the context of the previous layer, with the final layers using some minimum context window size:

| Layer | Context |
|-------|---------|
| 0     | 8k      |
| 1     | 4k      |
| 2     | 2k      |
| 3     | 1k      |
| 4     | 512     |
| 5     | 256     |
| 6     | 128     |
| 7–79  | 64      |

This adds up to 8k^2 + 4k^2 + 2k^2 + 1k^2 + 512^2 + 256^2 + 128^2 + 73*64^2, versus 80 * 8k^2 for full attention in every layer. The computational cost of prompt processing would drop by roughly 98%.
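A quick back-of-the-envelope check of that arithmetic (a sketch only: it assumes 80 layers, an 8192-token prompt, a window that halves per layer down to a 64-token floor, and attention cost proportional to the square of the attended window, ignoring linear terms and constant factors):

```python
NUM_LAYERS = 80
CONTEXT = 8192
MIN_WINDOW = 64

# Window schedule: halve the window each layer until it hits the floor.
windows = []
w = CONTEXT
for _ in range(NUM_LAYERS):
    windows.append(w)
    w = max(w // 2, MIN_WINDOW)

shrinking_cost = sum(w * w for w in windows)   # decreasing windows
full_cost = NUM_LAYERS * CONTEXT * CONTEXT     # full attention everywhere

print(windows[:10])               # [8192, 4096, 2048, 1024, 512, 256, 128, 64, 64, 64]
print(shrinking_cost / full_cost) # ~0.017, i.e. a ~98% reduction
```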

I came up with this idea by looking at these diagrams in the paper:

[Two screenshots of figures from the paper]

The initial layers have already arranged all the information in such a way as to make it accessible "locally" within the sliding window. In other words, the model already implements some form of hierarchical attention, with the initial layers responsible for the heavy lifting of global attention. If the optimization described above is feasible, the need for linear attention mechanisms vanishes, since the quadratic attention in the initial layers is unavoidable anyway for good LLM answering performance.

Choosing whether a token is worth looking at or not requires some initial sweeping pass. If you have a token and want to know which tokens in the preceding context "resonate" with it, you have to do a linear pass. Repeating this linear pass for every generated token results in quadratic attention, and there is no obvious way to avoid it. One could plausibly group every k tokens of the context into a block, but then it is still necessary to perform block-wise quadratic attention: you merely go from O(n^2) to O((n/k)^2). This is not a big victory compared to a decreasing window, where the effective reduction factor grows with each layer, to the point where the first three layers account for 80%+ of the computation time and the remaining layers contribute almost nothing. See the sketch below for what a per-layer window could look like.
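To make the per-layer shrinking window concrete, here is a minimal sketch (not taken from any existing implementation; the halving schedule, 64-token floor, and boolean-mask convention are assumptions for illustration) that builds a banded causal attention mask whose width halves with depth:

```python
import numpy as np

def layer_window(layer: int, context: int = 8192, min_window: int = 64) -> int:
    """Window size for a given layer: halve per layer, floored at min_window."""
    return max(context >> layer, min_window)

def banded_causal_mask(seq_len: int, window: int) -> np.ndarray:
    """Boolean mask: query i may attend to keys j with i - window < j <= i."""
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    return (j <= i) & (j > i - window)

# Small demo sequence so the dense masks stay cheap to materialize.
seq_len = 1024
mask_l0 = banded_causal_mask(seq_len, layer_window(0))  # effectively full causal here
mask_l6 = banded_causal_mask(seq_len, layer_window(6))  # 128-token sliding window
print(mask_l0.sum(), mask_l6.sum())  # attended (query, key) pairs per layer
```

In a real decoder these masks would be applied to the attention logits of each layer before the softmax; the point of the sketch is only that the amount of attended context, and hence the per-layer cost, can be made a simple function of depth.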

Vorlent · May 18 '24 17:05