
Support for extended context for LLaMA-based models

Open keelezibel opened this issue 2 years ago • 7 comments

Feature request

Came across this article https://kaiokendev.github.io/til#extending-context-to-8k which suggests that, by interpolating the positional encodings, we can potentially extend the model's context length.

 # These two lines:
self.scale = 1 / 4
t *= self.scale

I think it could be added to server/text_generation_server/utils/layers.py after line 283:

t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
self.scale = 1 / 4    # Add a variable scale
t *= self.scale
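
For illustration, here is a minimal, self-contained sketch of what this linear position interpolation could look like as a standalone module. The class name ScaledRotaryEmbedding and its structure are illustrative, not TGI's actual rotary embedding class; the real change would only touch the two lines above.

import torch


class ScaledRotaryEmbedding(torch.nn.Module):
    """Illustrative rotary embedding cache with linear position interpolation.

    With scale = 1/4, positions 0..8191 are compressed into the 0..2047 range
    the model was trained on, which is what extends the usable context ~4x.
    """

    def __init__(self, dim: int, base: float = 10000.0, scale: float = 1.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.scale = scale
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None

    def get_cos_sin(self, seqlen: int, device, dtype):
        if seqlen > self._seq_len_cached:
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            t *= self.scale  # the interpolation step: compress position indices
            freqs = torch.outer(t, self.inv_freq.to(device))
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)
        return self._cos_cached[:seqlen], self._sin_cached[:seqlen]


# Usage sketch: a 4x-dilated cache covering 8k positions.
rope = ScaledRotaryEmbedding(dim=128, scale=1 / 4)
cos, sin = rope.get_cos_sin(8192, device="cpu", dtype=torch.float32)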

Motivation

To extend the model context beyond 2048 tokens for LLaMA-based models.

Your contribution

t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
self.scale = 1 / 4    # Add a variable scale
t *= self.scale

keelezibel avatar Jun 28 '23 02:06 keelezibel

If I'm not mistaken, this change requires at least fine-tuning the model, right?

OlivierDehaene avatar Jun 28 '23 09:06 OlivierDehaene

Remarkably, it works on LLaMA models even without fine-tuning, though you lose some accuracy with naive (non-fine-tuned) models. However, once you fine-tune with the dilated positional encoding, you can regain most of that accuracy. Even once fine-tuned, though, you'd need these changes present in your inference server, so this would be an awesome feature for TGI. It would best be added as an optional flag (setting the scaling factor) rather than being on by default. You'd also need to update some of the logic around the maximum token limits to properly handle the scaling.

Meta also just published a paper on this: https://arxiv.org/abs/2306.15595
Here is the suggested monkey patch by kaiokendev: https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/dope_llama_monkey_patch.py
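
For what it's worth, a rough sketch of the optional-flag idea mentioned above. The ROPE_SCALING_FACTOR environment variable and the helper below are hypothetical, not an existing TGI option; they only illustrate keeping the feature off by default and adjusting the token limits.

import os

# Hypothetical flag: defaults to 1.0, so existing behaviour is unchanged.
ROPE_SCALING_FACTOR = float(os.environ.get("ROPE_SCALING_FACTOR", "1.0"))


def scaled_max_total_tokens(trained_max_positions: int, scaling_factor: float) -> int:
    """If positions are compressed by 1/scaling_factor inside the rotary cache,
    the server can accept proportionally more tokens."""
    return int(trained_max_positions * scaling_factor)


# e.g. a 2048-token LLaMA with 4x dilation (scale = 1/4 inside the RoPE cache):
assert scaled_max_total_tokens(2048, 4.0) == 8192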

ssmi153 avatar Jun 29 '23 07:06 ssmi153

Would the test code you provided work if I build the server locally and run it? Wondering, as I'd love to try it out.

Ichigo3766 avatar Jun 29 '23 21:06 Ichigo3766

In case this is interesting to anyone: I was able to replicate the results of https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ to extend the context for LLaMA-based models. (Beware of the hacky cargo-culting :D)
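
For reference, the NTK-aware variant from that thread scales the RoPE base instead of compressing the positions. A rough sketch following the formula in the reddit post; alpha = 8 is just an example value:

import torch


def ntk_scaled_inv_freq(dim: int, base: float = 10000.0, alpha: float = 8.0) -> torch.Tensor:
    """NTK-aware scaling: stretch the RoPE base so the low-frequency bands cover
    a longer context while the high-frequency bands stay close to the trained ones."""
    base = base * alpha ** (dim / (dim - 2))
    return 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))


# Drop-in replacement for the usual inv_freq; note that positions are *not* scaled here.
inv_freq = ntk_scaled_inv_freq(dim=128, alpha=8.0)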

evq avatar Jun 29 '23 21:06 evq

> In case this is interesting to anyone: I was able to replicate the results of https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ to extend the context for LLaMA-based models. (Beware of the hacky cargo-culting :D)

That's a different method again! At this stage it's hard to tell which one works better once the model has been fine-tuned. LMSys just successfully fine-tuned a model to 16k context length using kaiokendev's original method: https://lmsys.org/blog/2023-06-29-longchat/ . It might be worth waiting a few days to see the results of a fine-tune with the NTK-aware version before implementing one or the other.

ssmi153 avatar Jun 30 '23 01:06 ssmi153

I know the method would work with SuperHOT models that were already fine-tuned for 8k tokens. @ssmi153, would the model that LMSys fine-tuned on 16k context work fine without any changes to the inference code?

keelezibel avatar Jun 30 '23 01:06 keelezibel

I'm pretty sure the LMSys model would require the inference code to use dilated rotary embeddings (i.e. changing those two lines of code / using the monkey patch). Note that the monkey patch linked above hard-codes a 4x dilation => 8k tokens, so you'd need to change it to an 8x dilation => 16k tokens to work with the LMSys model (this is just a matter of changing some numbers in the code). Ideally this dilation ratio would be configurable at runtime as a settings flag (2x, 4x, 8x).
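
A tiny sketch of deriving that ratio instead of hard-coding it; the function name is illustrative:

def dilation_factor(target_context: int, trained_context: int = 2048) -> float:
    """8k target -> 4x dilation (scale 1/4), 16k target -> 8x dilation (scale 1/8)."""
    return target_context / trained_context


assert dilation_factor(8192) == 4.0   # SuperHOT-style 8k models
assert dilation_factor(16384) == 8.0  # LMSys 16k LongChat model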

ssmi153 avatar Jun 30 '23 02:06 ssmi153

See https://github.com/huggingface/text-generation-inference/issues/512 for the latest work.

arnocandel avatar Jul 20 '23 21:07 arnocandel