Interesting method to extend a model's max context length.

Open allenbenz opened this issue 1 year ago • 49 comments

https://kaiokendev.github.io/til#extending-context-to-8k Someone had the clever idea of scaling the positional embeddings inversely proportional to the extended context length.

Adding

            self.scale = 1 / 2
            t *= self.scale

after https://github.com/turboderp/exllama/blob/b29960fe8c97796d6363182c7ea302b735b409e4/model.py#L747

did seem to work for me. In a 3800 token context, with a 3615 token prompt where I had scattered some "secret phrases" at the very start, in the middle and near the end, guanaco-33b (w/ 2048 context) was able to extract the phrases from the prompt.
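For reference, a minimal sketch of what that two-line change does to the rotary position table. This is plain Python, not the actual exllama code: `rotary_angles`, `head_dim` and `base` are names assumed for illustration, and only the `t *= self.scale` step comes from the snippet above.

```python
def rotary_angles(seq_len, head_dim, base=10000.0, scale=1.0):
    """Return a seq_len x (head_dim // 2) table of rotary angles.

    scale < 1 compresses positions, e.g. scale=0.5 squeezes positions
    0..4095 into the range 0..2047.5.
    """
    # One inverse frequency per pair of dimensions (standard RoPE layout).
    inv_freq = [base ** (-2.0 * i / head_dim) for i in range(head_dim // 2)]
    angles = []
    for pos in range(seq_len):
        t = pos * scale  # the "t *= self.scale" step
        angles.append([t * f for f in inv_freq])
    return angles


table = rotary_angles(4096, head_dim=8, scale=0.5)
print(table[4095][0])  # angle for the last position, first frequency
```

With a scale of 1/2, position 4095 lands on 2047.5, so a 4096-token sequence stays inside the position range a 2048-context model was trained on, at the cost of tokens sitting at half-integer positions the model has not seen.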

Not suggesting this gets added but it appears to work, though I don't know how much this degrades the perplexity, especially without fine tuning.

allenbenz avatar Jun 21 '23 23:06 allenbenz

I commented on the reddit thread as well, and it does impact perplexity quite a bit.

1:1 - ppl = 6.31
1:2 - ppl = 7.42
1:4 - ppl = 15.84
1:8 - ppl = 105.28

But it remains to be seen what tuning could do to remedy that. I've already tried LoRA tuning to run longer contexts by just extending the position embeddings. It makes the model stop producing garbage after 2048 tokens, but it doesn't make it actually care about any part of the context more than 2048 tokens in the past.

Still, it suggests that just targeting query and value projections in a LoRA is enough to affect how it interprets positional embeddings, so maybe it wouldn't be hard to make it comfortable with compressed embeddings, and then maybe long-range attention just comes out of that for free. It's speculative, but definitely worth trying out.

turboderp avatar Jun 21 '23 23:06 turboderp

I've gotten coherent summaries of 3400 tokens worth of text, so it uses the whole context, or at least, if the model is being blinded, it's more subtle.

I ran your perplexity benchmark with the gptq-for-llama option on guanaco-33B-4bit-128g:

1:1 = 4.3951
1:2 = 4.9279
1:4 = 9.2616
1:8 = 47.4795

I know the absolute numbers can't be compared, but is the same true of the relative differences? If those can be compared, then this might be another thing where the larger parameter models are more robust against degradation.
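A quick way to look at the relative differences is to divide each figure by its own 1:1 baseline. The sketch below only reuses the numbers quoted in this thread; the labels `first_run` and `guanaco_33b` are made up for illustration, and the model behind the first set of numbers is not stated here.

```python
# Perplexity figures quoted in this thread, keyed by compression ratio.
first_run = {"1:1": 6.31, "1:2": 7.42, "1:4": 15.84, "1:8": 105.28}
guanaco_33b = {"1:1": 4.3951, "1:2": 4.9279, "1:4": 9.2616, "1:8": 47.4795}


def relative_to_baseline(ppl):
    """Divide each perplexity by the unscaled (1:1) value."""
    base = ppl["1:1"]
    return {ratio: value / base for ratio, value in ppl.items()}


for name, run in (("first run", first_run), ("guanaco-33B", guanaco_33b)):
    rel = relative_to_baseline(run)
    print(name, {ratio: round(x, 2) for ratio, x in rel.items()})
```

The first set degrades by roughly 1.18x, 2.51x and 16.7x, the guanaco-33B set by roughly 1.12x, 2.11x and 10.8x. Whether that gap comes from model size, quantization or test settings can't be told from these numbers alone.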

Anyways, this isn't really important until we see fine tunes like you've mentioned. It's just interesting that this change works at all with models that haven't been trained on the adjusted positional embeddings.

allenbenz avatar Jun 22 '23 02:06 allenbenz

Played around with it more. The context blindness is diffused over the entire context. Having the model echo large prompts shows it messing up capitalization and dropping/shuffling tokens.

So until someone makes a LoRA to try, this is no good for "rewrite the following." It is still useful for summaries, though.

allenbenz avatar Jun 22 '23 03:06 allenbenz

I have to add, I'm testing with 2x4090 on 65B and the newest NVIDIA drivers (which automatically use RAM when more VRAM is needed).

I'm not sure how to do the perplexity test, but I tested the code on a 65B model

self.scale = 1 / 2
t *= self.scale

Using ooba webui, and exllama

Output generated in 7.56 seconds (3.31 tokens/s, 25 tokens, context 3065, seed 100932143)
Output generated in 12.09 seconds (10.67 tokens/s, 129 tokens, context 3093, seed 1485169274)
Output generated in 9.82 seconds (15.38 tokens/s, 151 tokens, context 3231, seed 1739546440)
Output generated in 7.97 seconds (15.18 tokens/s, 121 tokens, context 3385, seed 706103506)
Output generated in 3.40 seconds (1.47 tokens/s, 5 tokens, context 3246, seed 1569916475)
Output generated in 4.00 seconds (3.50 tokens/s, 14 tokens, context 3254, seed 818945043)
Output generated in 1.99 seconds (10.06 tokens/s, 20 tokens, context 3277, seed 845093379)
Output generated in 2.39 seconds (0.42 tokens/s, 1 tokens, context 3300, seed 865878355)
Output generated in 1.05 seconds (7.60 tokens/s, 8 tokens, context 3310, seed 1002252356)
Output generated in 0.88 seconds (9.11 tokens/s, 8 tokens, context 3321, seed 1499653030)
Output generated in 0.74 seconds (6.77 tokens/s, 5 tokens, context 3338, seed 1709479689)
Output generated in 2.74 seconds (14.60 tokens/s, 40 tokens, context 3346, seed 1332302207)
Output generated in 0.95 seconds (7.38 tokens/s, 7 tokens, context 3395, seed 422863456)
Output generated in 9.93 seconds (14.81 tokens/s, 147 tokens, context 3405, seed 370547489)

Which IMO are pretty good speeds. I have 64GB of DDR5 6400MHz RAM.

It is using 48 GB of VRAM + ~3-4 GB of RAM. (Seems to be changing based on context size)

I also had to edit self.max_seq_len = 2048 to self.max_seq_len = 4096

Tomorrow I can test 8192 context, but it would surely use swap (I have 200GB for swap), so there it would be a good amount slower.
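The two edits go together: the scale has to shrink as max_seq_len grows, so that the largest position still falls inside the trained range. A rough sketch of that relationship, assuming a 2048-token trained context as discussed in this thread (`position_scale` is a hypothetical helper, not an exllama function):

```python
def position_scale(target_len, trained_len=2048):
    """Scale that maps positions 0..target_len onto 0..trained_len."""
    if target_len <= trained_len:
        return 1.0  # nothing to compress
    return trained_len / target_len


for target in (4096, 8192, 16384):
    print(target, position_scale(target))
```

This gives 1/2 for 4096, 1/4 for 8192 and 1/8 for 16384.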

Panchovix avatar Jun 22 '23 06:06 Panchovix

For perplexity you can run test_benchmark_inference.py with the -ppl option. Pretty sure turboderp himself tested this scaling yesterday and found perplexity to massively increase in exchange for context length.

jmoney7823956789378 avatar Jun 22 '23 11:06 jmoney7823956789378

Yep. I only tested how scaling affects the first part of the context, though. I didn't test perplexity over the full expanded context.

turboderp avatar Jun 22 '23 11:06 turboderp

@turboderp just a quick thought: for chat mode, could the user input be at 1 scale, the chat history just before it at 1/2 scale, and each older chunk at a smaller and smaller scale, so that all of the history is packed into the full 2048 tokens but every level further back becomes more blurry-eyed, just like us? Maybe even a smooth, non-linear curve, with each token further back getting an ever smaller identifier fraction, or an unsmooth curve that fills the 2048 with the most contextually related tokens getting close to a whole token identifier and the less related ones getting smaller fractions. When I picture it, it looks like gravity. The basic premise is that we pack the LLM's max context with data, with the most contextually important tokens appearing "closer" or in a "bigger" font to the LLM and the less important ones "further" or "smaller". This brings the best of vector searches and variable attention together in a single endpoint, all built in. Do you get me?

ghost avatar Jun 22 '23 15:06 ghost

I actually really like the idea of dynamically scaling context; it somewhat emulates human memory. I think if we could do dynamic compression of the conversation in a similar manner (i.e. use some LLM or whatever to compress conversation history so that recent context is less compressed and past context is more compressed), then that plus the dynamic context length could give us a much larger perceived context. Even if we cap context at around 6k, or wherever perplexity becomes too high for us, we could probably make it feel like 10-15k depending on how we compress previous conversation history, which I think is really cool. If I get some time I'll work on a prototype as a PoC.

That being said, @turboderp is it even possible to change scale etc. during model run-time/inference? I lack context on the codebase specifics so this might just be a pipedream

nikshepsvn avatar Jun 22 '23 17:06 nikshepsvn

@Jeduh I'm still running tests at the moment, but dynamically scaling the embeddings, I think, would work poorly. The tuned model has some tolerance, but ideally you would tune it to a new scale and then use that throughout. Keep in mind that you're not really getting "more resolution" at the 1:1 scale. It's not a compression of the actual value content of the tokens as in an RNN.

@nikshepsvn Currently I'm experimenting with scaling the embeddings by a fixed amount, but it could be scaled by some non-linear function instead. One problem with that would be building on a cached sequence. You would have to basically reevaluate the whole sequence every time the positional embeddings move around.

I think it's all very premature to start thinking about that. There's no reason to assume a constant scaling factor actually has any drawbacks that need to be overcome with an approach like that.
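The caching problem mentioned above can be sketched in a few lines. This assumes keys are cached after the rotary embedding has been applied; both helper functions and the particular length-dependent rule are hypothetical, chosen only to show the effect.

```python
def fixed_positions(n_tokens, scale=0.25):
    """Constant scale: token i always sits at i * scale."""
    return [i * scale for i in range(n_tokens)]


def length_dependent_positions(n_tokens, trained_len=2048):
    """Hypothetical dynamic scheme: squeeze harder as the sequence grows."""
    scale = min(1.0, trained_len / n_tokens)
    return [i * scale for i in range(n_tokens)]


# Constant scale: the first 3000 positions are unchanged after appending
# 1000 tokens, so keys cached for them remain valid.
assert fixed_positions(4000)[:3000] == fixed_positions(3000)

# Length-dependent scale: the same 3000 tokens move, so their cached
# (already rotated) keys would have to be recomputed.
assert length_dependent_positions(4000)[:3000] != length_dependent_positions(3000)
```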

turboderp avatar Jun 22 '23 18:06 turboderp

@turboderp it's definitely the smart option to train on 1/2 scale, 1/4 and so on, so that these models are familiar with the linearly scaled positions. But if it has already been trained on those scales, what's stopping us from using both, or all the scales it was fine-tuned for, in one prompt? For example:

Original "By breaking down complex sentences into smaller chunks, hierarchical self-attention allows the model to focus on local patterns while still capturing global contextual information. Essentially, it helps balance the tradeoff between efficiency and expressiveness inherent in many neural network designs."

With Hierarchical self-attention "By breaking down complex sentences into smaller chunks, hierarchical self-attention allows the model to focus on local patterns while still capturing global contextual information."

Leaving "Essentially, it helps balance the tradeoff between efficiency and expressiveness inherent in many neural network designs."

Could we not use the HSA prompt at 1:1 scale and the leftovers from the original at 1:2 scale, assuming the LLM has an incredibly small context size in this case lol?

ghost avatar Jun 22 '23 18:06 ghost

Well, like I said, you're not compressing the content of the context. It's not like it has a fuzzier recollection of tokens when their positional embeddings are closer together. It's just a question of how far apart it expects tokens to be in relation to each other. 10 tokens with a spacing of 1 still produce the same value output as 10 tokens with a spacing of 1/4, if the queries can find the right keys.

Teaching the model that tokens can have a spacing of either 1, 1/2, 1/4 or 1/8 seems needlessly difficult compared to just saying, "hey, the spacing is 1/8 now, deal with it."
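A toy illustration of the spacing point, using a single 2-D rotary pair instead of the full set of frequencies a real model uses. `rotate` and `score` are hypothetical helpers; the only claim is the standard rotary property that the query/key logit depends on the (scaled) distance between the two positions.

```python
import math


def rotate(vec, pos, freq=1.0):
    """Apply a 2-D rotary rotation for a (possibly fractional) position."""
    x, y = vec
    a = pos * freq
    return (x * math.cos(a) - y * math.sin(a),
            x * math.sin(a) + y * math.cos(a))


def score(q, k, q_pos, k_pos, scale=1.0):
    """Attention logit between a rotated query and a rotated key."""
    qx, qy = rotate(q, q_pos * scale)
    kx, ky = rotate(k, k_pos * scale)
    return qx * kx + qy * ky


q, k = (0.3, -1.2), (0.8, 0.5)
# Four tokens apart at scale 1/4 looks like one token apart at scale 1.
a = score(q, k, q_pos=8, k_pos=4, scale=0.25)
b = score(q, k, q_pos=1, k_pos=0, scale=1.0)
assert math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12)
```

The value vectors are never touched by any of this, which is why changing the spacing does not blur the content of the tokens.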

turboderp avatar Jun 22 '23 18:06 turboderp

To me, it sounds like quantisation vs sparse quantisation. But I understand how it acts as a float coordinate system instead of compressing tokens, and how we first need to know whether a single normal scale even works. Chatbort told me :) Got a little excited, that's all haha. Thank you for your time @turboderp

ghost avatar Jun 22 '23 18:06 ghost

I finished some more thorough tests now, and it's actually kind of promising. Perhaps @kaiokendev would be interested as well:

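[Image: superhot_test. Perplexity against sequence length for base Llama-13B (red), the same model with SuperHOT (blue), and a LoRA with extended positional embeddings (yellow).]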

This is running a perplexity test on a number of 8k sequences from WikiText2, truncated at varying lengths. The rationale is that perplexity should always be better on longer sequences, since as each sequence grows the model has more and more of a past to base its predictions on. But only if it's actually attending to that past correctly.
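In outline, the measurement looks something like the sketch below. It is not the script used for the graph: `token_logprobs_for` is a hypothetical callback standing in for a real model call, and the toy model at the end exists only to make the example runnable.

```python
import math


def perplexity(token_logprobs):
    """exp of the mean negative log-likelihood over the scored tokens."""
    return math.exp(-sum(token_logprobs) / len(token_logprobs))


def perplexity_by_length(token_logprobs_for, lengths):
    """Score the same sequence truncated to each length.

    token_logprobs_for(n) is a hypothetical callback that runs the model on
    the first n tokens and returns their natural-log probabilities.
    """
    return {n: perplexity(token_logprobs_for(n)) for n in lengths}


def toy_model(n):
    """Stand-in for a model: every token gets probability 0.5."""
    return [math.log(0.5)] * n


print(perplexity_by_length(toy_model, [1024, 2048, 4096, 8192]))
```

The toy model gives a perplexity of about 2 at every length. A model that really uses the longer past should show the number falling as the truncation length grows.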

"Base" in red is the base Llama-13B with no LoRA applied, and as expected it goes mental as soon as it is asked to process sequences longer than 2048 tokens.

The blue line is the same model with SuperHOT applied, positional embeddings compressed by a factor of 4, but all other parameters kept the same. It has slightly worse perplexity overall but this is not unusual for finetunings. It usually comes down to the choice of dataset and how closely it aligns with the base model's pretraining.

The remarkable thing, of course, is that perplexity keeps improving all the way to 8k tokens, meaning the model is attending to the full sequence. The test dataset doesn't necessarily have a whole lot of 8k-token passages where there is enough long-range content to utilize, so further testing should probably be on a different dataset, something like literature or long news articles. Either way, the evidence is clear that this works, at least to some extent.

And as to that, the LoRA is trained with a very low rank, which might be holding it back. It was also trained on only 1200 examples, of which only 400 are longer than 2048 tokens. So there could be room for improvement there. It would also be interesting to try larger models, since 13B is substantially worse at dealing with even a 2k context than 33B.

I have also included, in yellow, the results from a LoRA I trained with examples up to 6k tokens, merely extending the positional embeddings. It avoids the total collapse after 2048 tokens, but it doesn't do terribly well overall. I have speculated that what it needed was more training, and that may still be correct, but it might need an unrealistic amount to overcome the pretraining. And the interpolation looks much more promising anyway.

turboderp avatar Jun 22 '23 19:06 turboderp

Kinda blows my mind that something this simple could potentially have tremendous impact in the open source LLM community, goes to show how many "low hanging fruit" there are for us to catch and optimize/build based on. Dope work on the testing @turboderp, the graph and ppl decrease with context increase is extremely promising even given the fact that there is an initial hit to ppl

nikshepsvn avatar Jun 22 '23 19:06 nikshepsvn

@turboderp that's a pretty cool graph. i had a thought though: rn the scaling factor is "linear", do you think there could be some gains from using other scaling functions? for example, have the spacing between tokens be 0.5 at the most recent context and let's say 0.25 at the beginning/older context. if it was a distribution instead of a linear scale, maybe we could teach our models to put more emphasis on more recent information? or maybe it's a really bad idea and will just always make things worse, i just wondered if it could have a desirable effect.

alkeryn avatar Jun 22 '23 19:06 alkeryn

@alkeryn there is a discussion that is almost exactly this above, feel free to read

nikshepsvn avatar Jun 22 '23 19:06 nikshepsvn

@nikshepsvn oh yea i missed that, either way it is kind of crazy that such a simple thing could have such an impact on the technology and that we only figured it out now. at that point does native context length even matter if you can fine tune it to any ratio? though, it still eats vram like candy, but in my testing the usage seemed more linear than quadratic.

alkeryn avatar Jun 22 '23 19:06 alkeryn

also @turboderp couldn't we fine tune models to work with arbitrary spacing (for example constantly changing the scale during training)? if not, just tuning them for 4k then 8k then 16k etc might be enough to get where we want. vram will probably be the limit anyway; if we figure out a way to also optimize that down a lot, we may be able to reach a point where context becomes irrelevant, which would be cool.

alkeryn avatar Jun 22 '23 20:06 alkeryn

@turboderp Thank you for doing the test, it is good now to have hard numbers to back up the visible effect. I will be training 30B today and also test with 16K with a scale of 0.125. I was also asked to test MPT, but I cannot fit 30B since it's not quantized, so I will perform the scaling for MPT-7B. The rationale, I believe, is that this may not solely apply to rotational encodings, but to any relative encoding, including ALiBi, which notoriously increases in perplexity after 2x the training length.

kaiokendev avatar Jun 22 '23 20:06 kaiokendev

@turboderp in that graph, "Base" is the base model, right? So the perplexity gap between RoPE and Base is likely due to finetuning on a different dataset. kaiokendev also released a plain 2k model without the compression, which is better for apples-to-apples comparisons.

I think the key question here is: how much of the original model's performance at <2K is recovered from finetuning? If there is significant loss, then this trick is a sidegrade; much better for long sequences, but could be worse for those who can't fit long sequences in vram anyway. If there's no significant loss after finetuning, then this is a miracle, and pretty much every finetune might as well support 8K+ from now on. In fact, even pre-existing finetunes could be backported by slapping an 8K+ lora on top.

Either way, we're left with a harder question: how the hell to fit the KV cache in VRAM? Is CPU offloading going to become critical soon, or will there be some way to quantize/compress it? I'm hopeful that non-uniform quantization might help, something like "keep latest 256 tokens at fp16, quantize the remainder".
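For a sense of scale, here is a back-of-the-envelope estimate of the fp16 KV cache for a single sequence. The layer counts and hidden sizes are the published configurations of the original LLaMA models; `kv_cache_bytes` is a hypothetical helper, and the memory an actual backend allocates may differ.

```python
def kv_cache_bytes(n_layers, hidden_size, seq_len, bytes_per_value=2):
    """Keys plus values, one hidden_size vector each per layer per token."""
    return 2 * n_layers * hidden_size * seq_len * bytes_per_value


# Layer counts and hidden sizes as published for the original LLaMA models.
models = {"13B": (40, 5120), "33B": (60, 6656), "65B": (80, 8192)}

for name, (layers, hidden) in models.items():
    for seq_len in (2048, 8192):
        gib = kv_cache_bytes(layers, hidden, seq_len) / 2**30
        print(f"{name} @ {seq_len}: {gib:.2f} GiB")
```

At 8192 tokens this comes to about 6.25 GiB for 13B and about 12.2 GiB for 33B, on top of the weights. The cache itself grows linearly with sequence length; lowering `bytes_per_value` is where a quantized cache would help.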

QM60 avatar Jun 22 '23 20:06 QM60

@kaiokendev man all this stuff makes me want to be able to work full time on it as my job. anyway thanks for the time you put into this work.

alkeryn avatar Jun 22 '23 20:06 alkeryn

Are we essentially directing the model's perception towards having multiple tokens per positional embed? Would that mean we can fine tune models to almost think in terms of sentences per positional embed token?

ghost avatar Jun 22 '23 20:06 ghost

@alkeryn I think it's a little premature to start demanding that the model understand multiple scales, before there's anything to suggest it needs more than one scale.

@kaiokendev I noticed when applying the LoRA that, even though you're targeting bias, all the bias tensors in the adapter are empty. I'm also wondering if rank 2 isn't maybe needlessly low? The LoRA would still be tiny at a rank of 4 or 8. Is it just to accelerate the finetuning?

turboderp avatar Jun 22 '23 20:06 turboderp

What if we make positional embedding token 2049 or more the identifying key for "hey, the spacing is 1/8 now, deal with it." with just "8" or something.

ghost avatar Jun 22 '23 20:06 ghost

@QM60 I'm not really having trouble running 8k contexts for 13B. But for 33B, yes, it's going to be trickier. I do have a second 24 GB GPU, luckily. So that helps.

Offloading the KV-cache to system RAM would kill performance for sure. The keys at least need to be present for attention. Perhaps you could offload values and just swap in values with an attention weight above some threshold... but idk. I'll be happy if this gives even 4k tokens reliably.

turboderp avatar Jun 22 '23 20:06 turboderp

@turboderp Heh, actually I noticed it too, seems exporting the bias does nothing! I'm not sure if it's a bug in peft, because I didn't see anything in the paper that implied LoRA should not apply to the bias, but that is a different problem for a different day.

For the rank, I kept it as 2 because the original paper demonstrates that q, k, v, o at rank 1 performs better/on-par with q,v at rank 8, but the memory size is increased significantly, so I just kept it at 2, since it should perform as well as q,v at rank 8

kaiokendev avatar Jun 22 '23 20:06 kaiokendev

@Jeduh You're still teaching the model two different behaviors that have to coexist. Much harder than just modifying one existing behavior. And you need some kind of rationale anyway. What would be the benefit?

turboderp avatar Jun 22 '23 20:06 turboderp

What if we make positional embedding token 2049 or more the identifying key for "hey, the spacing is 1/8 now, deal with it."

The interpolation effect is intended as an alternate way to achieve length extrapolation. The yellow line in the graph is what happens if you try to teach the model 2049+x. Interpolation means teaching the model within [0, 2048], since it is more overfit on that range. Technically, both achieve the same thing, but interpolation requires far, far less compute.

kaiokendev avatar Jun 22 '23 20:06 kaiokendev

Thank you @turboderp @kaiokendev , I just think interpolating token codes past 2048 adds a form of trainable and trigger-able cheat codes to whatever other methods/mods we come across down the line. E.g. selecting 1/8 scale with [2049]=8 or 1/16 with [2049]=g. Or maybe even going as far as coded personalities. But I am basing this on intuition alone rather than logical reasoning, and I really respect you, @turboderp , for being rational and bringing us all here :)

ghost avatar Jun 22 '23 21:06 ghost

@turboderp You think you can copy your test results here? https://github.com/ggerganov/llama.cpp/discussions/1965

kaiokendev avatar Jun 22 '23 21:06 kaiokendev