3-bit and 2-bit GPTQ support

Open • TechnotechGit opened this issue 2 years ago • 23 comments

Hi! While 3-bit and 2-bit quantisations are obviously less popular than 4-bit quantisations, I'm looking into the possibility of loading 13B models with 8 GB of VRAM. So far, loading a 3-bit 13B model is possible with AutoGPTQ, but it OOMs around 1k tokens. With ExLlama's speed and memory efficiency, I would imagine that a 3-bit 13B model (or 2-bit if really needed) could be quite viable for those of us with less VRAM.
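To get a rough sense of whether a 13B model fits in 8 GB, the arithmetic below estimates the size of the quantized weights alone. This is a back-of-the-envelope sketch. It assumes a GPTQ-style layout in which each group of 128 weights also stores one FP16 scale and one `bits`-wide zero point. The function name is hypothetical. The KV cache, activations, embeddings and CUDA context all come on top of this figure.

```python
from typing import Optional


def quantized_weight_gib(n_params: float, bits: int, group_size: Optional[int] = 128) -> float:
    """Approximate size of GPTQ-style quantized weights in GiB (weights only).

    Assumes each group of `group_size` weights also stores one FP16 scale and
    one `bits`-wide zero point. Pass group_size=None to ignore grouping overhead.
    """
    bits_per_weight = float(bits)
    if group_size:
        bits_per_weight += (16 + bits) / group_size
    return n_params * bits_per_weight / 8 / 2**30


for bits in (4, 3, 2):
    print(f"13B at {bits}-bit: ~{quantized_weight_gib(13e9, bits):.2f} GiB of weights")
```

Under these assumptions, the weights of a 3-bit 13B model come to roughly 4.8 GiB. On an 8 GB card (about 7.45 GiB), that leaves less than 3 GiB for the cache, activations and everything else.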

TechnotechGit • Jun 22 '23 08:06

How about 8-bit support? It would let us run a 30B model across multiple GPUs with lower perplexity, and still fast.

Ph0rk0z • Jun 23 '23 14:06

8-bit wouldn't be faster than 4-bit. And the perplexity wouldn't be much better, either, at least for large models. You'd be better off with one of the sparse methods.

turboderp • Jun 23 '23 15:06

But the options available nowadays for running 8-bit are GPTQ-for-LLaMa and AutoGPTQ, which are a good deal slower than ExLlama (at least for multi-GPU).

There's also bitsandbytes, but it's painfully slow.

Perplexity isn't much better, but it is a little better, which I feel may be considerable.

Though if you feel it wouldn't be used enough, that's fine. I can imagine it would involve a good amount of work.

Besides 8-bit, I think 3-bit may be pretty feasible. I'm not sure about 2-bit: it works "fine" on llama.cpp, but I would guess that's not the case for GPTQ.

Panchovix • Jun 23 '23 18:06

4 bits is about the sweet spot for Llama where you get decent enough perplexity and also room for larger models, which has a much greater impact than quantization size. Going to 3 bits alone is a big drop in quality, but it gets interesting when you only quantize select, small parts of the model with more precision. It turns out there are hotspots, like a few rows in each matrix (out of thousands) that have a disproportionate impact. So if you add those in at full precision they can more than make up for the difference between 3 and 4 bits.
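The idea of keeping a few hotspot rows at full precision can be illustrated with a toy sketch. It is a hypothetical illustration, not turboderp's implementation. It applies plain round-to-nearest quantization per row and treats each row's squared reconstruction error as a stand-in for its "impact". Real methods judge impact differently, for example by using activation statistics.

```python
import random


def quantize_row(row, bits):
    """Round-to-nearest asymmetric quantization of one row; returns dequantized values."""
    lo, hi = min(row), max(row)
    scale = (hi - lo) / (2**bits - 1) or 1.0  # avoid dividing by zero for constant rows
    return [lo + round((x - lo) / scale) * scale for x in row]


def sq_error(row, approx):
    return sum((a - b) ** 2 for a, b in zip(row, approx))


def quantize_with_hotspots(matrix, bits, keep_rows):
    """Quantize every row, then restore the `keep_rows` worst-quantized rows at full precision."""
    approx = [quantize_row(r, bits) for r in matrix]
    errors = [sq_error(r, q) for r, q in zip(matrix, approx)]
    worst = sorted(range(len(matrix)), key=lambda i: errors[i], reverse=True)[:keep_rows]
    for i in worst:
        approx[i] = list(matrix[i])
    return approx, sorted(worst)


def total_error(matrix, approx):
    return sum(sq_error(r, q) for r, q in zip(matrix, approx))


random.seed(0)
W = [[random.gauss(0.0, 1.0) for _ in range(64)] for _ in range(32)]
for i in (3, 17):  # plant two hypothetical hotspot rows containing a large outlier
    W[i][5] = 40.0

plain_3bit = [quantize_row(r, 3) for r in W]
mixed_3bit, kept = quantize_with_hotspots(W, 3, keep_rows=2)
print("rows kept at full precision:", kept)
print("error, plain 3-bit:", total_error(W, plain_3bit))
print("error, 3-bit + 2 full-precision rows:", total_error(W, mixed_3bit))
```

A single outlier stretches the quantization range of its row, so nearly all the other values in that row land far from a quantization level. Restoring just those rows removes most of the total error. In real models, the fraction of rows kept would be far smaller than the 2 of 32 used here.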

I am working on something similar, although it's a low priority at the moment. In the meantime if you want to run slower, bigger and more precise models, you could also consider 32g act-order.

turboderp • Jun 23 '23 20:06

Thanks! Just tested some 32g models and it works.

But I just tried LLaMA-65b-32g (https://huggingface.co/Neko-Institute-of-Science/LLaMA-65B-4bit-32g) and I got:

(venv) PS F:\ChatIAs\exllama> python webui/app.py -d 'F:\ChatIAs\oobabooga\text-generation-webui\models\Neko-Institute-of-Science_LLaMA-65B-4bit-32g' -gs 16,21
 -- Tokenizer: F:\ChatIAs\oobabooga\text-generation-webui\models\Neko-Institute-of-Science_LLaMA-65B-4bit-32g\tokenizer.model
 -- Model config: F:\ChatIAs\oobabooga\text-generation-webui\models\Neko-Institute-of-Science_LLaMA-65B-4bit-32g\config.json
 -- Model: F:\ChatIAs\oobabooga\text-generation-webui\models\Neko-Institute-of-Science_LLaMA-65B-4bit-32g\4bit-32g.safetensors
 -- Sequence length: 2048
 -- Tuning:
 -- --matmul_recons_thd: 8
 -- --fused_mlp_thd: 2
 -- --sdp_thd: 8
 -- Options: ['gpu_split: 16,21']
 -- Loading model...
Traceback (most recent call last):
  File "F:\ChatIAs\exllama\webui\app.py", line 142, in <module>
    model = ExLlama(config)
  File "F:\ChatIAs\exllama\model.py", line 709, in __init__
    device = self.config.device_map.map(key, loading = True)
  File "F:\ChatIAs\exllama\model.py", line 585, in map
    return self.layers[num]
IndexError: list index out of range

I'm gonna do a requant of this model myself and try again.

Panchovix • Jun 23 '23 20:06

I've tested that particular model, and it should work. I run it with -gs 17.2,24 though. The error might be because it gives up on loading the entire model with 16+21 GB. The file is only 35.9 GB, but the layers are huge and you won't necessarily get that exact split.
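A requested split such as 16,21 is not met exactly because layers are placed whole. Each device usually ends up a little under its budget, and leftover space on one GPU cannot be used by layers on the next. The sketch below shows a simple greedy placement. The function name, the uniform 0.45 GB layer size (roughly 35.9 GB over 80 layers) and the placement rule are all assumptions for illustration, not ExLlama's actual allocator. The sketch also ignores embeddings, the cache and other per-device overhead.

```python
def split_layers(layer_gb, budgets_gb):
    """Greedily place whole layers on devices in order, moving on when a device is full.

    Returns one list of layer indices per device. Raises RuntimeError if the
    layers do not fit within the given per-device budgets.
    """
    assignment = [[] for _ in budgets_gb]
    device, used = 0, 0.0
    for i, size in enumerate(layer_gb):
        while used + size > budgets_gb[device]:
            device, used = device + 1, 0.0
            if device == len(budgets_gb):
                raise RuntimeError(f"layer {i} does not fit in split {budgets_gb}")
        assignment[device].append(i)
        used += size
    return assignment


layers = [0.45] * 80  # hypothetical: 80 equal-sized 65B decoder layers
split = split_layers(layers, [16, 21])
print([len(d) for d in split], [round(len(d) * 0.45, 2) for d in split])
```

In this toy case, the first GPU receives only 15.75 GB of layers out of its 16 GB budget. With real, unequal layer sizes and per-device overhead, the shortfall can be larger.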

turboderp • Jun 23 '23 20:06

I tried with different size variations and I get the same issue :/ but I'm gonna requant myself and see how it goes.

Panchovix • Jun 23 '23 20:06

Wait a minute, I know what this is, cause I had the same issue come to think of it. The config.json file on HF is wrong for that model. They accidentally included a 33B config file in that folder.

Try replacing it with this one from the 128g version. Should work just as well.
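One quick sanity check for this kind of mismatch is to compare the layer count in config.json with the layers actually present in the weights file. Below is a minimal sketch, assuming HF-style tensor names such as `model.layers.79.self_attn.q_proj.qweight`; the helper names are hypothetical. LLaMA 33B has 60 decoder layers and 65B has 80. A device map built from a 33B config therefore has no entry for layers 60 to 79, which is consistent with the `IndexError` in `self.layers[num]` above.

```python
import re

LAYER_KEY = re.compile(r"^model\.layers\.(\d+)\.")


def layers_in_checkpoint(keys):
    """Number of decoder layers implied by tensor names like 'model.layers.N.*'."""
    nums = [int(m.group(1)) for m in map(LAYER_KEY.match, keys) if m]
    return max(nums) + 1 if nums else 0


def check_layer_count(num_hidden_layers, keys):
    found = layers_in_checkpoint(keys)
    if found != num_hidden_layers:
        raise ValueError(
            f"config.json says {num_hidden_layers} layers but the weights contain {found}; "
            "is this the right config.json?"
        )


keys_65b = [f"model.layers.{i}.self_attn.q_proj.qweight" for i in range(80)]
check_layer_count(80, keys_65b)      # matching 65B config: passes
try:
    check_layer_count(60, keys_65b)  # 33B config next to 65B weights
except ValueError as e:
    print(e)
```

With a real file, the tensor names could come from the `safetensors` package, e.g. `safe_open(path, framework="pt").keys()`.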

turboderp • Jun 23 '23 20:06

~~You're right, it works now, thanks! I'm gonna do a quant now then of 32g and act order = True, and upload it to get better perplex.~~

Already has act order, so it's not needed. Just a complete model, nice!

Panchovix • Jun 23 '23 21:06

Going to 3 bits alone is a big drop in quality

If a 13B model at 3-bit could fit into VRAM, it'd probably still be better than a 7B model, right?

TechnotechGit • Jun 24 '23 04:06

When I ran perplexity tests for 8-bit vs 4-bit on 13B models, the difference was fairly sizable. Granted, I used bnb instead of GPTQ at the time. It was almost a whole point higher.

Ph0rk0z • Jun 24 '23 13:06

@Ph0rk0z I'm not sure what quantization bnb uses, but if it's just RTN then yeah, there's going to be a big difference between 4-bit and 8-bit. GPTQ is a bit more sophisticated and even at 4 bits already gets pretty close to the perplexity of the original FP16 weights. Good comparison here

@TechnotechYT 13B at 3-bit will probably perform better than 7B at 4-bit, yes. I think it would still be a tight squeeze in 8 GB, especially if something else is using VRAM at the same time, like your OS. I'm also hesitant to spend too much effort on very specific use cases when there are so many other things waiting to get done.

turboderp • Jun 24 '23 19:06

FWIW, I'm able to run a 3-bit 65B LLaMA on a single 32 GB GPU using AutoGPTQ, which is kinda neat, and it seems close to 65B q4 in terms of quality (I haven't run benchmarks). The speedup from ExLlama would be very nice for that use case. Just chiming in to say that there are cases where this would be nice, but I very much respect that there are probably not many of us.

thot-experiment • Jun 25 '23 00:06

Appreciate the work, ExLlama is great to use on 7B. 3-bit is definitely used less than 4-bit, so I can appreciate that there are definitely more important things.

TechnotechGit • Jun 25 '23 04:06

I was testing 13B in 8-bit bnb vs 4-bit GPTQ. I see he's put up newer benchmarks, but only at 128g (with no mention of whether act-order was used). I mostly haven't used group size, since previously it made 33B OOM at full context within 24 GB. With act-order too, it would kill performance in CUDA. I simply haven't tested whether that's different for ExLlama now, and most of my models lack it. Cool if it is, but then this will be the only good way to run inference.

More or less, that page tells me that bnb-4bit is a lost cause, and why it was never compared to GPTQ in the paper. A GPTQ 4-bit LoRA targeted at all layers, instead of just the 2, would completely remove the need for it. Still, he never tested 8-bit or FP16 for those larger models, so we have to assume the spread is the same as for 7B.

Ph0rk0z • Jun 25 '23 11:06

Groupsize has a negligible impact on performance, and the extra file size doesn't prevent 33B models from using full context on 24 GB. Act-order has a small impact on speed but not much:

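| Model | Size | Group size | Act-order | Seq. len. | VRAM | Prompt | Best | Worst | Ppl |
|---|---|---|---|---|---|---|---|---|---|
| Llama | 33B | 128 | no | 2,048 t | 20,795 MB | 2,959 t/s | 47 t/s | 40 t/s | 4.60 |
| Llama | 33B | 128 | yes | 2,048 t | 20,795 MB | 2,784 t/s | 45 t/s | 37 t/s | 4.55 |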
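The small calculation below puts the two rows above in relative terms, using only the figures from the benchmark (Llama 33B, group size 128).

```python
no_act = {"prompt": 2959, "best": 47, "worst": 40}  # t/s, act-order off
act = {"prompt": 2784, "best": 45, "worst": 37}     # t/s, act-order on
ppl_no_act, ppl_act = 4.60, 4.55


def slowdowns(base, other):
    """Fractional throughput loss of `other` relative to `base` for each metric."""
    return {k: 1 - other[k] / base[k] for k in base}


for metric, loss in slowdowns(no_act, act).items():
    print(f"{metric}: {loss:.1%} slower with act-order")
print(f"perplexity: {1 - ppl_act / ppl_no_act:.1%} lower with act-order")
```

That works out to roughly 4 to 8% lower throughput in exchange for about 1% lower perplexity, at identical VRAM use.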

Act-order kills performance on GPTQ-for-LLaMa and AutoGPTQ because it assigns a quantization group to every row of each matrix, out-of-order. So instead of reading the quantization parameters once every 128 rows (for group size 128), they have to be read once per row instead. These memory accesses also coalesce poorly.
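A toy count makes this difference concrete. The sketch walks the rows of one matrix in storage order and counts how often the group, and therefore the scale and zero point, changes. The random shuffle is only a stand-in for act-order, which actually orders rows by activation statistics. The counting model is also an assumption and does not reflect how any particular kernel is written.

```python
import random


def qparam_switches(g_idx):
    """How many times a row-by-row walk must switch to a different group's scale/zero."""
    switches, current = 0, None
    for g in g_idx:
        if g != current:
            switches += 1
            current = g
    return switches


rows, group_size = 4096, 128
sequential = [r // group_size for r in range(rows)]  # no act-order: contiguous groups
shuffled = sequential[:]
random.seed(0)
random.shuffle(shuffled)  # stand-in for act-order: same groups, rows out of order

print("contiguous groups:", qparam_switches(sequential))
print("shuffled groups:  ", qparam_switches(shuffled))
```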

ExLlama gets around it by turning act-order matrices into regular groupsize matrices when loading the weights (code) and does the reordering on the other side of the matrix multiplication (code) to get the same result anyway.
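The load-time half of that trick can be sketched in a few lines. A stable sort of the rows by group index turns an act-order matrix into one whose groups are contiguous. The same permutation is kept so that it can later be applied to the activations. This is a simplified, hypothetical Python version using plain lists, not the actual ExLlama code, which operates on packed CUDA tensors.

```python
import random


def sort_rows_by_group(weight_rows, g_idx):
    """Reorder the rows of an act-order matrix so that each group's rows are contiguous.

    Returns (reordered_rows, reordered_g_idx, perm). `perm` must also be applied to
    the columns of the input activations so that the matmul result is unchanged.
    """
    perm = sorted(range(len(g_idx)), key=lambda r: g_idx[r])  # Python's sort is stable
    return [weight_rows[r] for r in perm], [g_idx[r] for r in perm], perm


group_size, n_groups = 4, 3
g_idx = [g for g in range(n_groups) for _ in range(group_size)]
random.seed(2)
random.shuffle(g_idx)                              # hypothetical act-order group index
weights = [[float(r)] for r in range(len(g_idx))]  # one-element rows tagged with their original index

rows, new_g_idx, perm = sort_rows_by_group(weights, g_idx)
print(g_idx)
print(new_g_idx)
```

As noted further down in the thread, every group has exactly `group_size` rows in the files produced by GPTQ-for-LLaMa and AutoGPTQ. The sorted index is therefore always `[0]*g + [1]*g + ...` and can be replaced by a plain group size.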

I don't think this would be too difficult to port over to AutoGPTQ either. It's really just those two functions, like 100 lines of code in total. For 4-bit, anyway. 2, 3 and 8 bit would all need their own versions.
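One reason each bit width needs its own version is packing. 2, 4 and 8 bits divide a 32-bit word evenly, but 3 bits does not, so some 3-bit values straddle two words. The sketch below packs values into a generic little-endian bitstream of 32-bit words. It is an assumption-level illustration and not necessarily the exact layout that GPTQ-for-LLaMa or AutoGPTQ use.

```python
def pack(values, bits):
    """Pack unsigned `bits`-wide ints into 32-bit words, least significant bits first."""
    acc, nbits, words = 0, 0, []
    for v in values:
        if not 0 <= v < (1 << bits):
            raise ValueError(f"{v} does not fit in {bits} bits")
        acc |= v << nbits
        nbits += bits
        while nbits >= 32:
            words.append(acc & 0xFFFFFFFF)
            acc >>= 32
            nbits -= 32
    if nbits:
        words.append(acc)  # leftover bits in a final partial word
    return words


def unpack(words, bits, count):
    acc = 0
    for i, w in enumerate(words):
        acc |= w << (32 * i)
    mask = (1 << bits) - 1
    return [(acc >> (bits * i)) & mask for i in range(count)]


def straddling(bits, count=32):
    """Indices of values whose bits span two 32-bit words."""
    return [i for i in range(count) if (i * bits) // 32 != ((i + 1) * bits - 1) // 32]


for bits in (2, 3, 4, 8):
    print(f"{bits}-bit: values straddling a word boundary: {straddling(bits)}")
```

With 3 bits, 32 values fill exactly three words, and values 10 and 21 are split across word boundaries. A 4-bit kernel that assumes each value lives inside a single word cannot handle that case.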

turboderp • Jun 25 '23 18:06

AutoGPTQ/GPTQ have to maintain compatibility with hardware like the P40; I'm not sure whether or how this affects it. In essence, you're saying that you have fixed both group size and act-order in terms of memory use and speed. So it's nice that we can enjoy the perplexity bump on faster cards. I will have to try some models like this and compare, because I've been avoiding them.

Here we can lose t/s and be way ahead. Over there it's a way different story.

Ph0rk0z • Jun 26 '23 12:06

@turboderp One issue I see with the ExLlama implementation is that it can't work with row tensor parallelism, can it? Since in that case the activation is split over the K dimension, the reordering column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map); is bound to fail, as part of the data is on another GPU. Am I correct? Do you see a way around it?

fxmarty • Jul 05 '23 08:07

If you mean tensor parallelism as in torch.distributed.tensor.parallel then no, but I don't see any other GPTQ implementations using that specific API either. But in principle reordering the rows doesn't prevent you from splitting the matmul across GPUs. You've still got the same bandwidth constraints regardless.

Considering the matmul A @ B, what ExLlama performs for act-order models is instead s_col(A) @ s_row(B). s_row is applied at load time and sorts the rows in the quantized matrix by their group index, which in GPTQ terms turns it into a "no-act-order" matrix. s_col performs the corresponding reordering of columns in A such that s_col(A) @ s_row(B) = A @ B.
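The identity s_col(A) @ s_row(B) = A @ B is easy to check numerically. Here is a minimal sketch using plain Python lists and small integer matrices, so the comparison is exact. The group index and the helper names are made up for illustration.

```python
import random


def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def s_row(b, perm):
    """Reorder the rows of the weight matrix B (done once, at load time)."""
    return [b[p] for p in perm]


def s_col(a, perm):
    """Apply the same permutation to the columns of the activation A (every forward pass)."""
    return [[row[p] for p in perm] for row in a]


random.seed(1)
k, n = 8, 3
g_idx = [3, 0, 1, 0, 2, 1, 3, 2]  # hypothetical act-order group index, group size 2
perm = sorted(range(k), key=lambda r: g_idx[r])
A = [[random.randint(-5, 5) for _ in range(k)]]                    # shape [1, k]
B = [[random.randint(-5, 5) for _ in range(n)] for _ in range(k)]  # shape [k, n]

print(matmul(s_col(A, perm), s_row(B, perm)) == matmul(A, B))  # True
```

The check holds for any permutation, because each output element is the same sum of products with its terms visited in a different order.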

The matmul itself is still just a matmul, so whatever method you would use to split it across devices would still work. The s_col transformation needs to be applied to the left-hand operand first, but this is just an added step and shouldn't matter for parallelizing the matmul afterwards. Depending on how you split the operation, it may need to be performed on multiple devices, but it's still incredibly cheap compared to the GPTQ act-order approach of updating quantization parameters for every single row. Remember, A is really small most of the time: [1, hidden_dim], and B is huge: [hidden_dim, hidden_dim] or [hidden_dim, intermediate_dim]
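The sketch below illustrates that the reordered product parallelizes like any other matmul, simulating two "devices" with list slices. Once s_col has been applied to A (here, on a copy of A available to every device), B can be split either by columns, with the outputs concatenated, or by rows, with the partial products summed. All names are hypothetical, and no real multi-GPU API is involved.

```python
import random


def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def bounds(size, n_dev):
    return [size * d // n_dev for d in range(n_dev + 1)]


def column_parallel(a, b, n_dev):
    """Each device holds a slice of B's columns; outputs are concatenated (a gather)."""
    cut = bounds(len(b[0]), n_dev)
    pieces = [matmul(a, [row[cut[d]:cut[d + 1]] for row in b]) for d in range(n_dev)]
    return [sum((p[i] for p in pieces), []) for i in range(len(a))]


def row_parallel(a, b, n_dev):
    """Each device holds a slice of B's rows and the matching columns of A; partials are summed (an all-reduce)."""
    cut = bounds(len(b), n_dev)
    partials = [matmul([r[cut[d]:cut[d + 1]] for r in a], b[cut[d]:cut[d + 1]]) for d in range(n_dev)]
    return [[sum(p[i][j] for p in partials) for j in range(len(b[0]))] for i in range(len(a))]


random.seed(3)
k, n = 8, 4
g_idx = [3, 0, 1, 0, 2, 1, 3, 2]  # hypothetical act-order group index
perm = sorted(range(k), key=lambda r: g_idx[r])
A = [[random.randint(-5, 5) for _ in range(k)]]
B = [[random.randint(-5, 5) for _ in range(n)] for _ in range(k)]
A_r = [[row[p] for p in perm] for row in A]  # s_col(A)
B_r = [B[p] for p in perm]                   # s_row(B), done at load time

reference = matmul(A, B)
print(column_parallel(A_r, B_r, 2) == reference, row_parallel(A_r, B_r, 2) == reference)
```

The row-parallel variant only works here because every device could see all of A before the reordering. That assumption is exactly what the next comment questions.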

Parallelism in general is kind of problematic, though. The overhead is considerable and the matmuls are small and numerous, so much so that Python itself easily ends up becoming the bottleneck, as the CUDA queue keeps running out while the CPU is busy doing "free" operations like view(), or building dictionaries of function arguments and whatnot.

turboderp • Jul 05 '23 11:07

Thank you! Yes, s_row(B) is fine, as it is done ahead of time (on the weights). But in the row tensor parallelism case (which typically follows a column tensor-parallel operation), the activation A is split over different GPUs along its columns, so a reordering s_col(A) would require an AllGather operation. But it's true that the activation is usually small (at least for small batch sizes), so the overhead may not be that large.
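Here is a small sketch of this point. If A is already split by columns into contiguous shards, the reordering generally pulls some columns from another shard, so data has to move between devices first. The permutation and shard layout below are hypothetical.

```python
def cross_shard_positions(perm, n_dev):
    """Positions of s_col(A) whose source column lives on a different shard.

    Assumes the K dimension is split into n_dev equal contiguous shards
    (len(perm) divisible by n_dev). A non-empty result means the reorder
    cannot be done locally on each device.
    """
    k = len(perm)

    def shard(col):
        return col * n_dev // k

    return [pos for pos, src in enumerate(perm) if shard(pos) != shard(src)]


g_idx = [3, 0, 1, 0, 2, 1, 3, 2]  # hypothetical act-order group index
perm = sorted(range(len(g_idx)), key=lambda r: g_idx[r])
print(perm, cross_shard_positions(perm, 2))
```

Positions 3 and 6 need a column held by the other device. Without act-order (an identity permutation), the list is empty and no extra communication is needed.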

fxmarty • Jul 05 '23 13:07

But how do you avoid gathering in any case? ~~Isn't the fundamental problem still the same, that if you split A in rows you need to split B in columns, and vice versa?~~ Nope, I'm tired, that's not it. But either way, the product you end up with is split across devices in such a way that it can't be forwarded to another matmul (or normalization or whatever the next operation is) without recombining those pieces.

I admit I haven't looked closely at the implementation in Torch, though. And I'd like to be wrong about this, I'm just not seeing any way you could keep the hidden state split throughout the inference. Maybe in places, like a matmul leading into a summation or activation function, but you'd still be synchronizing most of the time. Unless I'm just not up to speed.

turboderp • Jul 05 '23 13:07

@turboderp You can usually avoid a gather of the activation between a column tensor-parallel linear and a row tensor-parallel linear; see the shapes in the figure here: https://huggingface.co/docs/transformers/v4.30.0/en/perf_train_gpu_many#tensor-parallelism

fxmarty • Jul 05 '23 13:07

I see. I will have to read that a little more closely, but going by the MLP example for instance, they aren't splitting the state (X here), they're cloning it to two devices. Then they perform X @ A1 and X @ A2, where A1 and A2 both have the same number of rows as A. Using the reordering approach from ExLlama, Y1 and Y2 would both be valid.

Multiplying by B is the point where you would need another gathering operation, I guess, if you don't just keep the group index for B and use the AutoGPTQ approach there.
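The column-then-row pattern from the linked page can be checked on a toy MLP. With A split by columns and B split by rows, each device computes its own act(X @ A_i) @ B_i without communicating, and only the final sum needs an all-reduce. This is a simulated sketch. ReLU stands in for LLaMA's gated SiLU MLP (any elementwise activation works the same way), and all names are hypothetical.

```python
import random


def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def relu(m):
    return [[max(0, v) for v in row] for row in m]


def mlp(x, a, b):
    return matmul(relu(matmul(x, a)), b)


def mlp_two_devices(x, a, b):
    h = len(a[0]) // 2
    a1, a2 = [r[:h] for r in a], [r[h:] for r in a]  # column split of A
    b1, b2 = b[:h], b[h:]                            # matching row split of B
    z1 = matmul(relu(matmul(x, a1)), b1)             # device 0, no communication
    z2 = matmul(relu(matmul(x, a2)), b2)             # device 1, no communication
    return [[p + q for p, q in zip(r1, r2)] for r1, r2 in zip(z1, z2)]  # all-reduce (sum)


random.seed(4)
hidden, inter = 4, 6
X = [[random.randint(-3, 3) for _ in range(hidden)]]
A = [[random.randint(-3, 3) for _ in range(inter)] for _ in range(hidden)]
B = [[random.randint(-3, 3) for _ in range(hidden)] for _ in range(inter)]
print(mlp_two_devices(X, A, B) == mlp(X, A, B))
```

In this arrangement, X is cloned to both devices, so s_col(X) can be applied locally before X @ A_i. The B_i side is where per-device handling of the group index would be needed.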

An earlier version of ExLlama did maintain the group index after reordering, which I figured at the time was necessary because there was no guarantee the reordered matrix would work like a groupsize matrix with a constant group size. As it turned out, the group size was always constant, due to some implementation choices in GPTQ-for-LLaMa and AutoGPTQ. The reordered index was always just [0] * 128 + [1] * 128 + [2] * 128... etc. So I optimized it away.

If necessary that sequentialized group index could be brought back, I suppose, in which case it would still be possible to use the reordering approach on Y1 and Y2 individually. B1 and B2 would end up with grouped quantization parameters, only the groups would be irregular (unless you re-quantized for a given split) so there would be a little bit of extra work to do in the matmul kernel. Not too much though. Whether it's worth it to avoid the gather... idk.
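To see why the groups become irregular, the sketch below splits the K rows of a hypothetical act-order matrix into two contiguous shards and sorts each shard's rows by group independently. Globally, every group has 2 rows, but within each shard the groups have 1 or 2 rows. Each shard would therefore need its own sequentialized group index rather than a single constant group size.

```python
from collections import Counter


def per_shard_groups(g_idx, n_dev):
    """For each contiguous shard of rows, return (group, rows_in_shard) after sorting by group."""
    k = len(g_idx)
    result = []
    for d in range(n_dev):
        shard = g_idx[k * d // n_dev : k * (d + 1) // n_dev]
        counts = Counter(shard)
        result.append([(g, counts[g]) for g in sorted(counts)])
    return result


g_idx = [3, 0, 1, 0, 2, 1, 3, 2]  # hypothetical act-order group index, 2 rows per group
for d, groups in enumerate(per_shard_groups(g_idx, 2)):
    print(f"device {d}: {groups}")
```

Groups 1 and 3 also appear on both devices, so their quantization parameters would need to be present on both.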

At any rate, to implement something like this in ExLlama I'd probably just skip supporting act-order models. With extensive changes to the whole pipeline like that, I think there are probably more worthwhile ways to improve perplexity anyway, and I'm exploring a number of them for ExLlamaV2.

I'd love it if there were some comparative benchmarks to show the real-life impact of parallelism during generation. Depending on the model size, most of these operations end up being cheap individually, to the point that a lot of the optimizations I found myself doing in ExLlama were just eliminating Python code and doing the exact same thing in C++ because everything was happening too quickly on the CUDA side.

turboderp • Jul 05 '23 14:07