
What kind of layers are optimized by torchao on a RTX 4090?

(Open) naiveen opened this issue 9 months ago • 7 comments

I am trying to quantize a model and am running it on an RTX 4090. Since many of the available quantization benchmarks are run on higher-end GPUs, I am trying to establish a baseline for the performance gain I can expect from quantization.

I tried the torchao_demo tutorial on a GPU and it worked great. My model has similar transformer layers with q, k, v projections, but I don't see the same kind of speed-up, and a large share of the profile log is aten::_copy() operations.

To debug, I benchmarked a single linear layer, since most of the modified layers seem to be of this type. But I don't see any performance gain in this experiment. I would appreciate more context on which layers torchao actually optimizes.

'''
    Adapted from https://github.com/ethanshenley/PyTorch-Conference-Recipes/blob/main/torchao_demo.ipynb
'''
import gc
import time

import torch
import torch.nn as nn

from torchao.quantization import quantize_, int8_weight_only, float8_weight_only


device = "cuda:0"


def get_gpu_memory_mb():
    # GPU memory held by live tensors; process RSS (psutil) would only show host memory
    return torch.cuda.memory_allocated(device) / 1024 / 1024


def build_linear(sz, quant_config=None):
    # Create an sz x sz Linear layer on the GPU, optionally quantize it in place,
    # and return it with the GPU memory (MB) its parameters occupy
    before = get_gpu_memory_mb()
    model = nn.Linear(sz, sz).to(device)
    if quant_config is not None:
        quantize_(model, quant_config)
    return model, get_gpu_memory_mb() - before


def run_inference(model, inputs, num_runs=10, warmup=3):
    with torch.no_grad():
        # Untimed warm-up so one-time costs (lazy init, kernel selection) are excluded
        for i in range(warmup):
            model(inputs[i])
        torch.cuda.synchronize(device)
        start_time = time.perf_counter()
        for i in range(num_runs):
            model(inputs[i])
        # CUDA launches are asynchronous: wait for all kernels before stopping the clock
        torch.cuda.synchronize(device)
        end_time = time.perf_counter()
    return (end_time - start_time) / num_runs


# Benchmark configuration: batch size and number of timed forward passes per model
bsz = 16
n_runs = 100
for sz in range(1024, 20480, 1024):
    print('====================================================')
    print(f"Running with linear layer of size {sz}...")
    inputs = torch.randn(n_runs, bsz, sz, device=device)

    print("\nRunning baseline model...")
    model, baseline_memory = build_linear(sz)
    baseline_time = run_inference(model, inputs, n_runs)
    print(f"Baseline - Time: {baseline_time:.4f}s, Memory: {baseline_memory:.2f}MB")

    print("\nRunning int8 weight-only quantized model...")
    model_int8, int8_memory = build_linear(sz, int8_weight_only())
    int8_time = run_inference(model_int8, inputs, n_runs)
    print(f"Int8 Weight-Only - Time: {int8_time:.4f}s, Memory: {int8_memory:.2f}MB")

    print("\nRunning fp8 weight-only quantized model...")
    model_fp8, fp8_memory = build_linear(sz, float8_weight_only())
    # Time the fp8 model itself, not the fp32 baseline
    fp8_time = run_inference(model_fp8, inputs, n_runs)
    print(f"fp8 Weight-Only  - Time: {fp8_time:.4f}s, Memory: {fp8_memory:.2f}MB")

    print("\nPerformance Improvements:")
    print(f"Int8 weight-only speedup: {baseline_time / int8_time:.2f}x")
    print(f"Int8 weight-only memory reduction: {baseline_memory / int8_memory:.2f}x")
    print(f"fp8 weight-only speedup: {baseline_time / fp8_time:.2f}x")
    print(f"fp8 weight-only memory reduction: {baseline_memory / fp8_memory:.2f}x")

    # Free GPU memory before moving on to the next layer size
    del model, model_int8, model_fp8, inputs
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.synchronize(device)

naiveen, Mar 01 '25 00:03

torchao quantizes Linear layers. However, depending on batch size and layer shape, you may see different levels of performance improvement from different techniques. For example, weight-only quantization works best for bs=1, while dynamic quantization is preferred for bs=n scenarios.
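The batch-size rule of thumb can be illustrated with a simplified roofline model of one linear layer: the layer costs whichever is larger, the time to read its bytes from DRAM or the time to do its multiply-adds. The bandwidth and peak-throughput constants below are illustrative assumptions (roughly RTX 4090-class, not measured), and the model ignores dequantization cost, the extra pass that quantizes activations, and the fact that real kernels reach only part of peak. It is a sketch of the reasoning, not how torchao chooses kernels.

```python
# Simplified roofline model of one linear layer y = x @ W.T with x: (M, K), W: (N, K).
# All hardware numbers are assumptions for illustration, not measured 4090 figures.

BANDWIDTH = 1.0e12          # bytes/s, assumed ~1 TB/s of DRAM bandwidth
FP16_FLOPS = 165e12         # assumed dense fp16 tensor-core peak (FLOP/s)
INT8_OPS = 2 * FP16_FLOPS   # assumed: int8 tensor cores run at twice the fp16 rate


def linear_time_s(M, N, K, w_bytes, a_bytes, peak):
    """Time is bounded by whichever is slower: moving the bytes or doing the math."""
    flops = 2 * M * N * K
    # weights + input activations + fp16 output
    bytes_moved = N * K * w_bytes + M * K * a_bytes + M * N * 2
    return max(flops / peak, bytes_moved / BANDWIDTH)


def schemes(M, N=4096, K=4096):
    return {
        # fp16 weights and activations, fp16 matmul
        "fp16": linear_time_s(M, N, K, w_bytes=2, a_bytes=2, peak=FP16_FLOPS),
        # int8 weights dequantized on the fly; the matmul itself still runs in fp16
        "int8_weight_only": linear_time_s(M, N, K, w_bytes=1, a_bytes=2, peak=FP16_FLOPS),
        # activations quantized to int8 at runtime; the matmul runs on int8 tensor cores
        "int8_dynamic": linear_time_s(M, N, K, w_bytes=1, a_bytes=1, peak=INT8_OPS),
    }


for M in (1, 16, 256, 4096):
    t = schemes(M)
    # estimated speed-up of each scheme relative to fp16 at this batch size
    print(M, {name: round(t["fp16"] / v, 2) for name, v in t.items()})
```

At batch size 1 the layer is memory-bound, so both int8 schemes approach the 2x reduction in weight bytes. With 4096 rows it is compute-bound: weight-only quantization gives nothing, and only the int8 matmul of dynamic quantization helps.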

Most of our benchmarks are run on A100 or H100, but you can try the gemlite kernels (details: https://github.com/pytorch/ao/tree/main/torchao/quantization#gemlite-triton), which are expected to be optimized for the 4090. cc @mobicham

supriyar, Mar 01 '25 03:03

Thank you @supriyar for pointing me to the gemlite kernels, and thanks to @mobicham for the work on gemlite. I was able to optimize my model using both torchao int4 and the gemlite kernels.

The GPU cache-flushing technique from the gemlite repo helped me benchmark small modules and, in turn, optimize the overall model. Without it, my benchmarks were erroneous, and I couldn't find much about why.
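One likely reason small-module benchmarks look wrong without cache flushing: the RTX 4090 has a 72 MB L2 cache, so when the same small layer is called repeatedly its weights stay in L2 and are not re-read from DRAM. That overstates the speed of memory-bound kernels and hides the benefit of smaller weights. The sketch below is a generic timing helper in that spirit, similar to what `triton.testing.do_bench` does; it is not the gemlite repo's code, the names `make_cuda_l2_flush` and `bench_ms` are hypothetical, and the 256 MB buffer size is an assumption chosen to exceed the L2 size.

```python
import statistics
import time
from typing import Callable, Optional


def make_cuda_l2_flush(size_mb: int = 256, device: str = "cuda:0") -> Callable[[], object]:
    """Return a callable that overwrites a GPU buffer larger than L2 (size is an assumption)."""
    import torch  # imported lazily so the timing helper below works without torch

    buf = torch.empty(size_mb * 1024 * 1024, dtype=torch.int8, device=device)
    # Each call writes the whole buffer, evicting previously cached weights from L2
    return buf.zero_


def bench_ms(
    fn: Callable[[], object],
    flush: Optional[Callable[[], object]] = None,
    sync: Callable[[], object] = lambda: None,
    warmup: int = 5,
    runs: int = 50,
) -> float:
    """Median wall-clock time of `fn` in milliseconds, flushing caches before each run."""
    for _ in range(warmup):
        fn()
    sync()
    times = []
    for _ in range(runs):
        if flush is not None:
            flush()
        sync()  # make sure the flush has finished before starting the clock
        start = time.perf_counter()
        fn()
        sync()  # wait for queued GPU work before stopping the clock
        times.append((time.perf_counter() - start) * 1e3)
    return statistics.median(times)
```

With PyTorch on a GPU, a call might look like `bench_ms(lambda: layer(x), flush=make_cuda_l2_flush(), sync=torch.cuda.synchronize)` inside `torch.no_grad()`. Host-side `perf_counter` timing also includes kernel launch overhead, which can dominate for very small layers; CUDA events would isolate GPU time.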

The gain from torchao_int4 to gemlite_int4 is small compared to the gain from fp16 to torchao_int4. Oddly, I have seen varied results for different hyperparameters, but the results became more consistent as I evaluated linear layers with larger shapes. I am not entirely sure, but it seems that unpacking and packing smaller tensors is the bottleneck in this case. Is there any existing work that supports smaller linear layers? Or is there a real bottleneck that makes it infeasible?
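For context on the packing question: 4-bit schemes of this kind typically quantize weights group-wise (each group of `group_size` values gets its own scale and zero-point) and pack two 4-bit codes into each byte. Packing is usually a one-time cost paid when the model is quantized; during inference the kernel only unpacks and dequantizes on the fly. The pure-Python sketch below shows the idea only: the function names are hypothetical, and torchao_int4 and gemlite each use their own kernel-specific layouts rather than this one.

```python
# Toy 4-bit asymmetric group-wise quantization with two 4-bit codes packed per byte.
# Illustrative only: real kernels use their own memory layouts.


def quantize_groups(values, group_size):
    """Map each group of floats to 4-bit codes in [0, 15] plus a (scale, zero) pair."""
    codes, params = [], []
    for start in range(0, len(values), group_size):
        group = values[start:start + group_size]
        lo, hi = min(group), max(group)
        scale = (hi - lo) / 15 or 1.0  # avoid division by zero for constant groups
        zero = lo
        params.append((scale, zero))
        codes.extend(min(15, max(0, round((v - zero) / scale))) for v in group)
    return codes, params


def pack_int4(codes):
    """Pack pairs of 4-bit codes into bytes: low nibble first, then high nibble."""
    if len(codes) % 2:
        codes = codes + [0]  # pad to an even count
    return bytes(codes[i] | (codes[i + 1] << 4) for i in range(0, len(codes), 2))


def unpack_int4(packed, n):
    """Inverse of pack_int4: the per-element work a kernel repeats on the fly."""
    codes = []
    for b in packed:
        codes.extend((b & 0x0F, b >> 4))
    return codes[:n]


def dequantize_groups(codes, params, group_size):
    """Reconstruct approximate floats as code * scale + zero, per group."""
    out = []
    for i, c in enumerate(codes):
        scale, zero = params[i // group_size]
        out.append(c * scale + zero)
    return out


weights = [0.1 * i - 0.8 for i in range(16)]
codes, params = quantize_groups(weights, group_size=8)
packed = pack_int4(codes)  # 16 weights -> 8 bytes
restored = dequantize_groups(unpack_int4(packed, len(codes)), params, group_size=8)
```

Each dequantized value is within half a quantization step (`scale / 2`) of the original. Because the unpacking cost scales with the number of weights, it is not obviously worse for small layers; fixed per-kernel costs such as launch overhead are a more likely reason small layers show little benefit.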

naiveen, Apr 15 '25 10:04

@naiveen, what exactly are you trying to optimize? In practice, you need torch.compile / CUDA graphs end-to-end in your model to optimize inference, because launching the Triton kernels has overhead (especially the ones that use atomic addition). For example, you can test the end-to-end Llama 3 decoding speed-up with the torchao_int4 or gemlite backend using this script: https://github.com/mobiusml/hqq/blob/master/examples/hqq_lib_demo.py
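The launch-overhead point can be made concrete with a toy timing model. In eager mode the CPU launches kernels one at a time, so a kernel that finishes faster than the next launch leaves the GPU idle; replaying a CUDA graph (which `torch.compile(..., mode="reduce-overhead")` uses) issues the whole sequence with a single launch. All timings below are hypothetical, chosen only to show the effect; 224 corresponds to 32 decoder layers with 7 projections each, as in Llama-3-8B.

```python
# Toy model of why end-to-end CUDA graphs / torch.compile matter for small kernels.
# All timings are hypothetical; real overheads depend on the CPU, driver, and kernels.


def eager_step_ms(kernel_ms, launch_ms):
    """Eager mode: if a kernel finishes before the CPU can launch the next one,
    the GPU idles, so each kernel effectively costs at least launch_ms."""
    return sum(max(k, launch_ms) for k in kernel_ms)


def graph_step_ms(kernel_ms, launch_ms):
    """CUDA graph replay: one launch for the whole sequence, then kernels run back to back."""
    return launch_ms + sum(kernel_ms)


N_LINEAR = 224          # hypothetical: linear-layer kernels per decoding step
LAUNCH_MS = 0.02        # hypothetical host-side launch cost per kernel
FP16_KERNEL_MS = 0.035  # hypothetical fp16 matmul time per layer at batch size 1
INT4_KERNEL_MS = 0.010  # hypothetical 4-bit kernel, 3.5x faster than fp16

fp16 = [FP16_KERNEL_MS] * N_LINEAR
int4 = [INT4_KERNEL_MS] * N_LINEAR

eager_speedup = eager_step_ms(fp16, LAUNCH_MS) / eager_step_ms(int4, LAUNCH_MS)
graph_speedup = graph_step_ms(fp16, LAUNCH_MS) / graph_step_ms(int4, LAUNCH_MS)
print(f"eager speed-up: {eager_speedup:.2f}x, graph speed-up: {graph_speedup:.2f}x")
```

Under these assumptions a kernel that is 3.5x faster yields only about 1.75x per step in eager mode, because every kernel still pays the launch cost, while graph replay recovers close to the full 3.5x. This is why timing individual layers in eager mode can understate what the same kernels deliver end-to-end.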

mobicham, Apr 15 '25 10:04

My bad, I had only evaluated with a batch size of 1. With larger batch sizes, I see the performance gain from using the gemlite kernels.

naiveen, Apr 15 '25 12:04

You should actually see the highest performance gain with batch size 1: a 3-3.5x speed-up on the 4090 with 4-bit weights.

mobicham, Apr 15 '25 12:04

Hi, I ran the script at https://github.com/mobiusml/hqq/blob/master/examples/hqq_lib_demo.py to cross check for batch-size=1.

It seems the torchao backend runs faster, at 156 tok/s, than the gemlite backend at 116 tok/s. What am I missing? I am using torch==2.5.1.

naveen@chapterhouse:~/hqq/examples$ git diff
diff --git a/examples/hqq_lib_demo.py b/examples/hqq_lib_demo.py
index 89f7ad5..fd40cf5 100644
--- a/examples/hqq_lib_demo.py
+++ b/examples/hqq_lib_demo.py
@@ -5,7 +5,7 @@
 ########################################################################
 import torch
 device        = 'cuda:0'
-backend       = "torchao_int4" #'torchao_int4' #"torchao_int4" (4-bit only) or "bitblas" (4-bit + 2-bit) or "gemlite" (8-bit, 4-bit, 2-bit, 1-bit)
+backend       = "gemlite" #'torchao_int4' #"torchao_int4" (4-bit only) or "bitblas" (4-bit + 2-bit) or "gemlite" (8-bit, 4-bit, 2-bit, 1-bit)
 compute_dtype = torch.bfloat16 if backend=="torchao_int4" else torch.float16
 cache_dir     = None
 model_id      = 'meta-llama/Meta-Llama-3-8B-Instruct'
naveen@chapterhouse:~/hqq/examples$ python hqq_lib_demo.py
Loading checkpoint shards: 100%|██████████████████████████████| 4/4 [00:01<00:00,  2.65it/s]
100%|██████████████████████████████████████████████████████| 99/99 [00:00<00:00, 348.47it/s]
100%|█████████████████████████████████████████████████████| 225/225 [00:30<00:00,  7.46it/s]
The 'batch_size' argument of StaticCache is deprecated and will be removed in v4.49. Use the more precisely named 'max_batch_size' argument instead.
The 'batch_size' attribute of StaticCache is deprecated and will be removed in v4.49. Use the more precisely named 'self.max_batch_size' attribute instead.
  0%|          | 0/511 [00:00<?, ?it/s]
/home/naveen/miniconda3/envs/test/lib/python3.10/site-packages/torch/_inductor/compile_fx.py:167: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
100%|█████████████████████████████████████████████████████| 511/511 [00:48<00:00, 10.58it/s]
100%|████████████████████████████████████████████████████| 511/511 [00:04<00:00, 102.28it/s]
100%|████████████████████████████████████████████████████| 511/511 [00:04<00:00, 102.66it/s]
100%|████████████████████████████████████████████████████| 511/511 [00:04<00:00, 116.36it/s]
naveen@chapterhouse:~/hqq/examples$ git checkout .
Updated 1 path from the index
naveen@chapterhouse:~/hqq/examples$ git diff
naveen@chapterhouse:~/hqq/examples$ python hqq_lib_demo.py
Loading checkpoint shards: 100%|██████████████████████████████| 4/4 [00:00<00:00,  8.77it/s]
100%|██████████████████████████████████████████████████████| 99/99 [00:00<00:00, 260.26it/s]
100%|█████████████████████████████████████████████████████| 225/225 [00:29<00:00,  7.55it/s]
The 'batch_size' argument of StaticCache is deprecated and will be removed in v4.49. Use the more precisely named 'max_batch_size' argument instead.
The 'batch_size' attribute of StaticCache is deprecated and will be removed in v4.49. Use the more precisely named 'self.max_batch_size' attribute instead.
100%|█████████████████████████████████████████████████████| 511/511 [00:21<00:00, 24.30it/s]
100%|████████████████████████████████████████████████████| 511/511 [00:03<00:00, 144.85it/s]
100%|████████████████████████████████████████████████████| 511/511 [00:03<00:00, 146.59it/s]
100%|████████████████████████████████████████████████████| 511/511 [00:03<00:00, 156.26it/s]

naiveen, Apr 15 '25 12:04

torchao_int4 is the fastest for batch_size=1 with group_size=64. Gemlite is good for larger batch sizes. If you try with gpt-fast, you should get the following on the RTX 4090:

  • torchao_int4 + gs=64: 182 tokens/sec
  • gemlite + gs=64: 176 tokens/sec
  • gemlite + gs=None: 189 tokens/sec
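Part of the difference between group_size=64 and gs=None, and part of the 3-3.5x ceiling quoted earlier, may follow from how many bits each weight costs once per-group metadata is counted. The sketch below assumes one fp16 scale and one fp16 zero-point per group (32 bits of metadata) and that gs=None means per-channel scales with negligible overhead; actual kernel layouts differ, and kernel design matters as much as byte counts. The helper names are hypothetical.

```python
# Effective storage cost of group-wise 4-bit weights, assuming one fp16 scale and one
# fp16 zero-point per group (an assumption; real layouts may store these differently).


def effective_bits(weight_bits, group_size, meta_bits_per_group=32):
    """Average bits per weight including per-group metadata. group_size=None is treated
    as one scale/zero per output channel, whose overhead is negligible for large layers."""
    if group_size is None:
        return float(weight_bits)
    return weight_bits + meta_bits_per_group / group_size


def max_speedup_vs_fp16(weight_bits, group_size):
    """Upper bound for a purely memory-bound layer: ratio of weight bytes read."""
    return 16 / effective_bits(weight_bits, group_size)


for gs in (32, 64, 128, None):
    print(gs, effective_bits(4, gs), round(max_speedup_vs_fp16(4, gs), 2))
```

Under these assumptions the memory-bound ceiling for 4-bit weights is about 3.56x with group_size=64 and 4x with per-channel scales. End-to-end numbers land lower because attention, the KV cache, and any unquantized parts of the model do not shrink.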

mobicham, Apr 15 '25 13:04