Strange behavior with caching on 8K models
Using text-generation-webui with the Exllama loader gives me different results than with Exllama_HF. Specifically, Exllama_HF gives gibberish with SuperHOT 8K models past 2048 tokens, and even the logits produced by the two loaders are completely different. After looking at the code, it doesn't seem like ooba is doing anything wrong. I narrowed it down to this caching code:
https://github.com/turboderp/exllama/blob/master/generator.py#L173
if in_tokens.shape[-1] > 1:
    a = 0
    while a < self.sequence.shape[-1] - 1:
        b = min(a + max_chunk, self.sequence.shape[-1] - 1)
        self.model.forward(self.sequence[:, a:b], self.cache, preprocess_only = True, lora = self.lora)
        a = b
After setting max_chunk to 8192, the Exllama loader also gives gibberish results. I applied this caching code to Exllama_HF with a chunk size of 2048:
seq = kwargs["input_ids"][0].tolist()
cache = kwargs["past_key_values"] if "past_key_values" in kwargs else None
if cache is None:
    cache = ExLlamaCache(self.ex_model)
    nseq = seq[:-1]
    for seqs in [nseq[i : i + 2048] for i in range(0, len(nseq), 2048)]:
        self.ex_model.forward(
            torch.tensor([seqs], dtype=torch.long),
            cache,
            preprocess_only=True,
            lora=self.lora,
        )
With this, Exllama_HF gives coherent output past 2048 tokens, even out to 8K, as I would expect. Clearly, though, this is only a downstream fix: something is wrong with the caching strategy such that it cannot handle chunks larger than 2048 properly.
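To take the webui out of the equation entirely, the divergence should be reproducible against the model directly by prefilling the cache with two different chunk sizes and comparing the resulting logits. The helper below is a minimal sketch of that idea, not code from either project; it assumes model and ExLlamaCache behave exactly as in the snippets above (prefill chunks with preprocess_only=True, then a normal forward on the last token returns its logits), and prefill_logits is a made-up name.

import torch
from model import ExLlamaCache  # import path may differ depending on setup

def prefill_logits(model, ids, chunk_size, lora=None):
    # Prefill a fresh cache in chunks of `chunk_size`, mirroring the
    # generator loop quoted above, then run the final token to get logits.
    cache = ExLlamaCache(model)
    a = 0
    while a < ids.shape[-1] - 1:
        b = min(a + chunk_size, ids.shape[-1] - 1)
        model.forward(ids[:, a:b], cache, preprocess_only=True, lora=lora)
        a = b
    return model.forward(ids[:, -1:], cache, lora=lora)

# If the bug is chunk-size dependent, these should diverge once the prompt
# exceeds 2048 tokens:
# logits_a = prefill_logits(model, ids, 2048)
# logits_b = prefill_logits(model, ids, 8192)
# print(torch.max(torch.abs(logits_a.float() - logits_b.float())))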
I'm already working on optimizing the implementation to work better on longer contexts. One of the changes is to automatically prevent attention operations from scaling too wildly, by doing the same chunking in the base model that I was doing in the generator. I was going to do a little more testing, but I guess I can just push it now if there are issues with ExLlama_HF.
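For intuition on why the chunking bounds attention cost: with q new tokens attending to k total (cached plus new) positions, each head materialises a q x k score matrix, so a naive single pass over a whole 8K prompt costs roughly four times the per-layer score memory of the worst 2048-token chunk. The rough numbers below are only illustrative; the head count and fp16 element size are assumptions, not tied to any particular model or to ExLlama's kernels.

def attn_scores_bytes(q_len, k_len, n_heads=32, bytes_per_el=2):
    # Size of the per-layer attention score matrices: one q_len x k_len
    # matrix per head, fp16 elements (illustrative parameters).
    return q_len * k_len * n_heads * bytes_per_el

# Whole 8K prompt in a single pass: q = k = 8192
# attn_scores_bytes(8192, 8192) / 2**20  -> 4096 MiB per layer
# Same prompt in 2048-token chunks: worst chunk has q = 2048, k = 8192
# attn_scores_bytes(2048, 8192) / 2**20  -> 1024 MiB per layer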
The underlying issue is CUDA-related, I think. I'm still not sure exactly which kernel is failing when the input becomes too large, but this should keep it under control at least.
Thanks @kaiokendev for the fix in the meantime! I was getting much worse results with exllama_hf vs. exllama in the ooba webui and couldn't figure out why.