text-generation-inference
Support for extended context for LLaMA-based models
Feature request
Came across this article https://kaiokendev.github.io/til#extending-context-to-8k which suggests that, by interpolating the position indices, we can potentially extend the model's context length.
# These two lines:
self.scale = 1 / 4
t *= self.scale
I think it could be added to server/text_generation_server/utils/layers.py, right after line 283:
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
self.scale = 1 / 4 # Add a variable scale
t *= self.scale
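For context, here is roughly what that looks like if you pull the rotary cache logic out on its own. This is just a sketch of the idea, not the actual TGI code; the class and method names below are illustrative.
# Minimal sketch of linear position interpolation in a rotary cache.
import torch

class ScaledPositionRotaryEmbedding:
    def __init__(self, dim, base=10000, scale=1 / 4, device="cpu"):
        # scale < 1 squeezes positions, e.g. 1/4 maps positions 0..8191
        # onto the 0..2047 range the model was trained on
        self.scale = scale
        self.inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, device=device).float() / dim)
        )

    def cos_sin(self, seqlen, device, dtype):
        t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
        t *= self.scale  # the proposed change: interpolate positions
        freqs = torch.outer(t, self.inv_freq.to(device))
        return torch.cos(freqs).to(dtype), torch.sin(freqs).to(dtype)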
Motivation
To extend the model context beyond 2048 tokens for LLaMA-based models.
Your contribution
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
self.scale = 1 / 4 # Add a variable scale
t *= self.scale
If I'm not mistaken, this change would require at least fine-tuning the model, right?
Remarkably, it works on LLaMA models even without fine-tuning, though you lose some accuracy with naive (non-finetuned) models. However, once you finetune with the dilated positional encoding, you can regain most of that accuracy. Even once finetuned, though, you'd need these changes present in your inference server, so this would be an awesome feature for TGI. It would be best added as an optional flag (setting the scaling factor) rather than being on by default. You'd also need to update some of the logic around the maximum token limits to properly handle the scaling.
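To make the flag idea concrete, here is a hypothetical sketch. The ROPE_SCALE_FACTOR variable and the validation helper are made-up names for illustration, not existing TGI options; the point is just that the token-limit checks need to grow along with the scaling.
# Hypothetical sketch: a scaling factor widens the max-token checks.
import os

ROPE_SCALE_FACTOR = float(os.getenv("ROPE_SCALE_FACTOR", "1.0"))  # e.g. 0.25 for 4x context

ORIGINAL_MAX_POSITIONS = 2048  # LLaMA's trained context length
# With a 1/4 scale the server could accept up to 4x more tokens.
max_total_tokens = int(ORIGINAL_MAX_POSITIONS / ROPE_SCALE_FACTOR)

def validate_request(input_length: int, max_new_tokens: int):
    if input_length + max_new_tokens > max_total_tokens:
        raise ValueError(
            f"Request of {input_length + max_new_tokens} tokens exceeds "
            f"the scaled limit of {max_total_tokens}"
        )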
Meta also just published a paper on this: https://arxiv.org/abs/2306.15595
Here is the suggested monkey patch by kaiokendev: https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/dope_llama_monkey_patch.py
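The patch essentially swaps out the rotary embedding class in transformers for a scaled one. The sketch below is a rough reconstruction of that pattern, not the actual patch; class internals differ between transformers versions, so check the linked file for the real thing.
# Rough reconstruction of the monkey-patch idea (details vary by version).
import torch
import transformers.models.llama.modeling_llama as llama

class ScaledLlamaRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scale=1 / 4):
        super().__init__()
        self.scale = scale
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, x, seq_len=None):
        t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        t *= self.scale  # dilate positions: 1/4 scale -> ~8k usable context
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos().to(x.dtype), emb.sin().to(x.dtype)

# Replace the class before the model is instantiated
llama.LlamaRotaryEmbedding = ScaledLlamaRotaryEmbedding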
Would the test code you provided work if I build the server locally and run it? Wondering, as I'd love to try it out.
In case this is interesting to anyone: I was able to replicate the results of https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ to extend the context for LLaMA-based models. (Beware of the hacky cargo-culting :D)
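For anyone comparing the two approaches: the NTK-aware trick from that post scales the rotary base instead of the positions. A minimal sketch, assuming the same rotary-cache setup as the snippet above (alpha is roughly the desired context multiplier):
# NTK-aware scaling: stretch the rotary base rather than the positions.
import torch

def ntk_scaled_inv_freq(dim, base=10000, alpha=4, device="cpu"):
    # alpha ~ desired context multiplier (e.g. 4 for ~8k on a 2k model)
    base = base * alpha ** (dim / (dim - 2))
    return 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))

# Positions are then used unscaled:
# t = torch.arange(seqlen, device=device, dtype=inv_freq.dtype)
# freqs = torch.outer(t, inv_freq)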
That's a different method again! At this stage it's hard to tell which one works better once the model has been finetuned. LMSys just successfully finetuned a model to 16k context length using kaiokendev's original method: https://lmsys.org/blog/2023-06-29-longchat/. It might be worth waiting a few days to see the results of a finetune with the NTK-aware version before implementing one or the other.
I know the method would work with SuperHOT models that were already finetuned for 8k tokens. @ssmi153, would the model that LMSys finetuned on 16k context work fine without any changes to the inference code?
I'm pretty sure that the LMSys model would require the inference code to use dilated rotary embeddings (i.e. changing those two lines of code / using the monkey patch). Note that the monkey patch linked above does a hard-coded 4x dilation => 8k tokens, so you'd need to extend this to an 8x dilation => 16k tokens to work with the LMSys model (this is just a matter of changing some of the numbers in the code). Ideally this dilation ratio would be configurable at runtime as a settings flag (2x, 4x, 8x), as sketched below.
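For example, the ratio is just the trained context divided by the target context; the helper name here is hypothetical, purely for illustration.
# 8192 -> 1/4 (SuperHOT-style 8k), 16384 -> 1/8 (LongChat-style 16k)
ORIGINAL_CONTEXT = 2048

def rope_scale(target_context: int) -> float:
    return ORIGINAL_CONTEXT / target_context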
See https://github.com/huggingface/text-generation-inference/issues/512 for the latest work on this.