
`int4_weight_only` Slows Down `torch.nn.Linear` for Llama2 7B Shapes

Open mostafaelhoushi opened this issue 10 months ago • 13 comments

I have created a small script to benchmark int4 quantization on A100 GPUs, with inputs that have batch size 1 and seqlen 1.

When I test weight shapes that exist in Llama2 7B, I actually get a slowdown:

# input_dim, output_dim = 4096, 4096
Baseline:       0.023313920497894287 ms
Quantized:      0.08300095558166504 ms
# input_dim, output_dim = 4096, 11008
Baseline:       0.06082496166229248 ms
Quantized:      0.08460960388183594 ms
# input_dim, output_dim = 11008, 4096
Baseline:       0.059748477935791015 ms
Quantized:      0.09495231628417969 ms

When I use a really large shape that doesn't exist in Llama2 7B, I do get some speedup:

# input_dim, output_dim = 11008, 11008
Baseline:       0.14746272087097168 ms
Quantized:      0.09298111915588379 ms

This is strange because gpt-fast uses a similar int4 quantization and gets 2x speedup on Llama2 7B.
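For context, the measurement boils down to roughly the following (a minimal sketch, not the exact script; the timer, warmup, and iteration counts here are assumptions):

```python
import torch
from torchao.quantization import quantize_, int4_weight_only

def bench_ms(fn, warmup=10, iters=100):
    # CUDA-event timing with warmup; returns average ms per call
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

input_dim, output_dim = 4096, 4096
x = torch.randn(1, 1, input_dim, dtype=torch.bfloat16, device="cuda")  # batch 1, seqlen 1
linear = torch.nn.Linear(input_dim, output_dim, bias=False,
                         dtype=torch.bfloat16, device="cuda")

baseline = bench_ms(lambda: linear(x))
quantize_(linear, int4_weight_only())  # convert the weight to int4 weight-only
quantized = bench_ms(lambda: linear(x))
print(f"Baseline:\t{baseline} ms")
print(f"Quantized:\t{quantized} ms")
```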

mostafaelhoushi avatar Jan 23 '25 16:01 mostafaelhoushi

cc @HDCharles as @jerryzh168 is OOO

vkuzo avatar Jan 23 '25 16:01 vkuzo

cc @mostafaelhoushi What hardware and which pytorch and ao versions are you using?

On my H100 with the nightlies, I see the following for (4096, 11008), which is a speedup:

Baseline:       0.04709856033325195 ms
Quantized:      0.041566081047058105 ms

jcaip avatar Jan 23 '25 19:01 jcaip

@jcaip I am running on NVIDIA A100-SXM4-80GB and using the following libraries:

pytorch-triton==3.0.0+dedb7bdf33
torch==2.5.1+cu121

Let me draft another script that benchmarks a Hugging Face model. I am worried that even with the speedup you see for a single kernel on the H100 nightly, there may be no speedup for the full Llama2 7B model.

mostafaelhoushi avatar Jan 23 '25 19:01 mostafaelhoushi

I would definitely recommend using the latest nightlies to test

jcaip avatar Jan 23 '25 20:01 jcaip

Thanks @jcaip. The script to benchmark Llama2 7B is here.

I installed the torch and torchao nightlies. When benchmarking a single `(4096, 11008)` linear layer on the same A100, I get:

Baseline:       0.06053599834442139 ms
Quantized:      0.09977215766906739 ms

and when benchmarking the Llama2 7B model I get:

Baseline:       Model: 40.05844970703125 ms, MLP Layer: 0.18576160430908203 ms
Quantized:      Model: 60.4076806640625 ms, MLP Layer: 0.38929054260253904 ms
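The model-level numbers come from something like the sketch below (a rough reconstruction, not the exact script; the model id and the use of torch.utils.benchmark are assumptions):

```python
import torch
from torch.utils.benchmark import Timer
from transformers import AutoModelForCausalLM
from torchao.quantization import quantize_, int4_weight_only

# model id is an assumption; the actual script may load the checkpoint differently
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16
).to("cuda").eval()
input_ids = torch.randint(0, 32000, (1, 1), device="cuda")  # batch 1, seqlen 1

def model_forward_ms():
    # mean wall-clock time per forward pass, in ms
    t = Timer(stmt="with torch.no_grad(): model(input_ids)",
              globals={"torch": torch, "model": model, "input_ids": input_ids})
    return t.timeit(50).mean * 1e3

baseline_ms = model_forward_ms()
quantize_(model, int4_weight_only())  # quantize all linear layers to int4 weight-only
quantized_ms = model_forward_ms()
print(f"Baseline:\tModel: {baseline_ms} ms")
print(f"Quantized:\tModel: {quantized_ms} ms")
```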

mostafaelhoushi avatar Jan 23 '25 20:01 mostafaelhoushi

I also see slowdowns on my A100; not sure of the exact cause. Maybe there were some changes to the int4 kernel in core? I also see you're running without compile, but I don't think that should make a difference for a single linear layer ...

jcaip avatar Jan 23 '25 21:01 jcaip

Just to add another datapoint: I am quite new to the library; I just came across it yesterday while trying to quantize a large model for inference, and I also noticed quite a slowdown after quantization.

I tried the benchmark script shared above, and these are the numbers I get on an H100 as well.

# bs, input_dim, output_dim = 1, 4096, 11008
Baseline:       0.036133759021759033 ms
Quantized:      0.0388044810295105 ms

# bs, input_dim, output_dim = 16, 4096, 11008
Baseline:       0.03751264095306397 ms
Quantized:      0.05798175811767578 ms

joanPlepi avatar Jan 24 '25 13:01 joanPlepi

On H100 I am seeing, on the 2.6 release:

Baseline:       58.58537037037032 us
Quantized:      21.989801687764025 us

On Nightly

Baseline:       58.74621212121216 us
Quantized:      21.948585227273114 us

drisspg avatar Jan 25 '25 00:01 drisspg

Thanks @drisspg. Did you get those speedups without adding any torch.compile() statements? I am interested in getting that speedup in eager mode. I'm not sure how you got a ~3x speedup on H100 while @jcaip didn't.

mostafaelhoushi avatar Jan 25 '25 17:01 mostafaelhoushi

No torch.compile; the only change to your script was using this timer func: https://github.com/drisspg/transformer_nuggets/blob/46127c65fa72c338fb600dd0373cb7fc36bd9613/transformer_nuggets/utils/benchmark.py#L55

which is the closest I have found to what NCU would report for kernel time, i.e. it ignores any CPU overhead.
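For reference, a measurement in the same spirit can be done with triton.testing.do_bench (a sketch, not the exact function linked above):

```python
import torch
from triton.testing import do_bench
from torchao.quantization import quantize_, int4_weight_only

linear = torch.nn.Linear(4096, 11008, bias=False,
                         dtype=torch.bfloat16, device="cuda")
x = torch.randn(1, 1, 4096, dtype=torch.bfloat16, device="cuda")

# do_bench flushes the L2 cache between runs and reports GPU time in ms,
# which is much closer to what NCU shows than host-side timing
baseline_us = do_bench(lambda: linear(x)) * 1e3
quantize_(linear, int4_weight_only())
quantized_us = do_bench(lambda: linear(x)) * 1e3
print(f"Baseline:\t{baseline_us} us")
print(f"Quantized:\t{quantized_us} us")
```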

drisspg avatar Jan 25 '25 17:01 drisspg

Thanks @drisspg. Indeed, when I tried that timing function I got big speedups, both for the individual kernel and when I measured a whole model.

I want to verify: if it ignores CPU overhead, does that mean it doesn't measure end-to-end execution speedup? Would it be fair to use it to measure speedup for whole models rather than individual kernels?

mostafaelhoushi avatar Jan 27 '25 01:01 mostafaelhoushi

@mostafaelhoushi That is a really good question, and like all great questions the answer kind of depends on what you care about. Often we are looking at changes to individual kernels, so the above function is very helpful in that case. But if what you really care about is wall-clock time, I use https://github.com/drisspg/transformer_nuggets/blob/8b0a671b7b30cc7e186edd654f9c1565251b9b97/transformer_nuggets/utils/benchmark.py#L44

which captures the CPU overhead and manually adds the CUDA syncs. Obviously the best thing to measure is the thing you actually care about, but these act as proxies. I find that just looking at the pytorch traces is also really helpful.
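A rough stand-in for that wall-clock style of measurement (a sketch, not the exact function linked above) looks like:

```python
import time
import torch

def bench_wall_clock_ms(fn, warmup=10, iters=100):
    # host-side timing with explicit syncs: includes CPU launch overhead,
    # but waits on the GPU so queued kernels are fully accounted for
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters * 1e3
```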

drisspg avatar Jan 27 '25 21:01 drisspg

I think it's likely due to the L2 cache. It has happened to everyone doing microbenchmarks at some point 😅

https://github.com/pytorch/pytorch/blob/0144613e6ff6e018ca41085d1509dcceb80987f7/torch/_inductor/utils.py#L150-L158

(it seems to have been inspired by triton.testing.do_bench, or maybe it's the other way round)
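To make the effect concrete, the trick in the linked helper (and in do_bench) is roughly to zero out a buffer larger than the L2 cache between timed runs so the weights can't be served from cache; the buffer size and event-based timing below are illustrative:

```python
import torch

# buffer comfortably larger than the L2 cache on A100/H100 (40/50 MB)
l2_flush = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device="cuda")

def timed_run_ms(fn):
    l2_flush.zero_()  # evict cached weights/activations before measuring
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end)
```

For example, a 4096x4096 bf16 weight is about 32 MB and largely fits in a 40 MB A100 L2, so without a flush the un-quantized baseline can look faster than it would in a real decode loop where many layers compete for the cache.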

gau-nernst avatar Feb 02 '25 13:02 gau-nernst