
[Bug] The performance of the FA3 attention backend on Hopper is not up to expectations.

Open cao1zhg opened this issue 7 months ago • 13 comments

Checklist

  • [x] 1. I have searched related issues but cannot get the expected help.
  • [x] 2. The bug has not been fixed in the latest version.
  • [x] 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
  • [x] 4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/sgl-project/sglang/discussions/new/choose Otherwise, it will be closed.
  • [x] 5. Please use English, otherwise it will be closed.

Describe the bug

SGLang adopts FA3 as the default attention backend in its latest release. However, its performance falls short of expectations, particularly when compared to the FlashInfer attention backend. The performance benchmark was conducted with Qwen3 models on Hopper GPUs.
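
For reference, the backend under test can also be selected when launching the HTTP server instead of through the Engine API used in the reproduction below; a minimal sketch, assuming the --attention-backend flag of sglang.launch_server (model path and port here are placeholders, not taken from this report):

# Run one server at a time and point the same client benchmark at each:
python3 -m sglang.launch_server --model-path Qwen/Qwen3-235B-A22B --tp 8 --port 30000 --attention-backend flashinfer
python3 -m sglang.launch_server --model-path Qwen/Qwen3-235B-A22B --tp 8 --port 30000 --attention-backend fa3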

Reproduction

The benchmark script is:

modelscope download --model 'Qwen/Qwen3-235B-A22B' --local_dir '/cpfs01/user/caoyizhong.cyz/Qwen3-Plus++'

cd /cpfs01/user/caoyizhong.cyz/0506
export INPUT_LENGTHS=1024
export OUTPUT_LENGTH=20480
export NUM_RUNS=2
export MAX_MODEL_LEN=40960
export BATCH_SIZE=128

type="bf16"
mkdir -p ./type_${type}
mkdir -p ./type_${type}/result
mkdir -p ./type_${type}/log

export MEM_FRACTION_STATIC=0.85

for name in "Plus++";
do
    export ATTENTION_BACKEND="flashinfer"
    export TP_SIZE=8
    export MODEL_PATH=/cpfs01/user/caoyizhong.cyz/Qwen3-Plus++
    export QUANTIZATION=None
    export OUTPUT_PATH=./type_${type}/result/result_${name}_tp${TP_SIZE}_${MEM_FRACTION_STATIC}_$(date +%Y%m%d_%H%M%S)_${ATTENTION_BACKEND}.json
    export LOG_FILE=./type_${type}/log/log_${name}_tp${TP_SIZE}_${MEM_FRACTION_STATIC}_$(date +%Y%m%d_%H%M%S)_${ATTENTION_BACKEND}.log
    python ./benchmark.py --input-lengths $INPUT_LENGTHS --output-length $OUTPUT_LENGTH --num-runs $NUM_RUNS --model-paths $MODEL_PATH --max-model-len $MAX_MODEL_LEN --quantization $QUANTIZATION --tp-size $TP_SIZE --output-path $OUTPUT_PATH --mem-fraction-static $MEM_FRACTION_STATIC --batch-size $BATCH_SIZE --attention-backend $ATTENTION_BACKEND 2>&1 | tee $LOG_FILE
done

for name in "Plus++";
do
    export ATTENTION_BACKEND="fa3"
    export TP_SIZE=8
    export MODEL_PATH=/cpfs01/user/caoyizhong.cyz/Qwen3-Plus++
    export QUANTIZATION=None
    export OUTPUT_PATH=./type_${type}/result/result_${name}_tp${TP_SIZE}_${MEM_FRACTION_STATIC}_$(date +%Y%m%d_%H%M%S)_${ATTENTION_BACKEND}.json
    export LOG_FILE=./type_${type}/log/log_${name}_tp${TP_SIZE}_${MEM_FRACTION_STATIC}_$(date +%Y%m%d_%H%M%S)_${ATTENTION_BACKEND}.log
    python ./benchmark.py --input-lengths $INPUT_LENGTHS --output-length $OUTPUT_LENGTH --num-runs $NUM_RUNS --model-paths $MODEL_PATH --max-model-len $MAX_MODEL_LEN --quantization $QUANTIZATION --tp-size $TP_SIZE --output-path $OUTPUT_PATH --mem-fraction-static $MEM_FRACTION_STATIC --batch-size $BATCH_SIZE --attention-backend $ATTENTION_BACKEND 2>&1 | tee $LOG_FILE
done

The benchmark.py is:

import time
import json
import argparse
from sglang.srt.entrypoints.engine import Engine
import os
import fcntl
import errno
import random
from contextlib import contextmanager

@contextmanager
def file_lock(file_path):
    """文件锁上下文管理器"""
    lock_file = file_path + '.lock'
    while True:
        try:
            fd = os.open(lock_file, os.O_CREAT | os.O_EXCL | os.O_RDWR)
            break
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise
            time.sleep(0.1)
    try:
        yield
    finally:
        try:
            os.unlink(lock_file)
            os.close(fd)
        except OSError:
            pass

def generate_random_list(length):
    return [random.randint(0, 10000) for _ in range(length)]

def generate_random_batch(batch_size, length):
    return [generate_random_list(length) for _ in range(batch_size)]
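# Note: each "prompt" here is just a list of random token IDs in [0, 10000]; they are
# passed directly as input_ids below because the engine is created with
# skip_tokenizer_init=True, so no tokenizer or chat template is involved.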


def generation_result(engine, input_lengths, output_length, num_runs, kwargs, output_path, batch_size):
    results = []

    print(f"开始测试生成速度,输入长度: {input_lengths} 字符,输出长度: {output_length} tokens")
    print(f"将运行 {num_runs} 次...")

    for input_length in input_lengths:
        sampling_params = {
            "temperature": 0.7,
            "top_p": 0.8,
            "top_k": 20,
            "n": 1,
            "repetition_penalty": 1.0,
            "presence_penalty": 0.0,
            "frequency_penalty": 0.0,
            "max_new_tokens": output_length,
            "ignore_eos": True,
        }
        print(sampling_params)
        print(f"\n测试输入长度: {input_length}")
        print("=" * 50)

        for _ in range(num_runs):
            input_batch = generate_random_batch(batch_size, input_length)
            start_time = time.time()
            outputs = engine.generate(input_ids=input_batch, sampling_params=sampling_params)
            end_time = time.time()
            
            total_tokens = (input_length + output_length) * batch_size
            total_time = (end_time - start_time)
        
            avg_speed = total_tokens / total_time
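            # Note: total_tokens counts prompt + generated tokens, so avg_speed mixes
            # prefill and decode throughput; a decode-only figure would be
            # output_length * batch_size / total_time.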
            result = [kwargs, avg_speed]
            results.append(result)

            print(f"\n第 {_} 次统计结果:")
            print(f"输入prompt token数: {input_length} * {batch_size}")
            print(f"平均生成速度: {avg_speed:.2f} tokens/s")
            print(f"总token数: {total_tokens}")
            print(f"总耗时: {total_time:.2f} 秒")
            print("=" * 50)

            with file_lock(output_path):
                with open(output_path, "a") as f:
                    f.write(str(result) + "\n")

def main():
    # Create the command-line argument parser
    parser = argparse.ArgumentParser(description='Benchmark model generation speed')
    parser.add_argument('--input-lengths', type=int, nargs='+',
                       default=[63488, 129024],
                       help='List of input lengths, separated by spaces')
    parser.add_argument('--output-length', type=int, default=2048,
                       help='Output length (number of tokens)')
    parser.add_argument('--num-runs', type=int, default=1,
                       help='Number of runs per input length')
    parser.add_argument('--model-paths', type=str, nargs='+',
                       default=["./qwen-3-4B-awq"],
                       help='List of model paths, separated by spaces')
    parser.add_argument('--quantization', type=str, default="awq_marlin")
    parser.add_argument('--max-model-len', type=int, default=32768)
    parser.add_argument('--mem-fraction-static', type=float, default=0.85)
    parser.add_argument('--output-path', type=str, default="sgl_benchmark_results.txt",
                       help='Path of the results output file')
    parser.add_argument('--tp-size', type=int, default=1)
    parser.add_argument('--dtype', type=str, default="auto")
    parser.add_argument('--attention-backend', type=str, default="fa3")
    parser.add_argument('--batch-size', type=int, default=1)
    
    args = parser.parse_args()

    for model_path in args.model_paths:
        print(f"\n开始测试模型: {model_path}")
        print("=" * 50)

        # Initialize the engine
        kwargs = {
            "model_path": model_path,
            "context_length": args.max_model_len,
            "mem_fraction_static": args.mem_fraction_static,
            "log_level": "INFO",
            "tp_size": args.tp_size,
            # "enable_mixed_chunk": True,
            "skip_tokenizer_init": True,
            "dtype": args.dtype,
            "attention_backend": args.attention_backend
        }

        if args.quantization != "None":
            kwargs["quantization"] = args.quantization

        print(kwargs)
        print("=" * 50)

        engine = Engine(**kwargs)

        generation_result(engine, args.input_lengths, args.output_length, args.num_runs, kwargs, args.output_path, args.batch_size)
    
    print("\n所有模型测试完成!")

if __name__ == "__main__":
    main()

Environment

sglang == 0.4.6.post2
torch == 2.6.0+cu124
CUDA 12.6
NCCL_VERSION_CODE 22203

cao1zhg avatar May 07 '25 02:05 cao1zhg

Environment

sglang image == 0.4.6.post2-cu124
model == Qwen3/Qwen3-235B-A22B
FA3 attention backend

Image

Reproduction

sglang running on 8x H20 GPUs.

apiVersion: apps/v1
kind: Deployment
metadata:
  name: sglang
  labels:
    app: sglang
spec:
  selector:
    matchLabels:
      app: sglang
  replicas: 1
  template:
    metadata:
      labels:
        app: sglang
    spec:
      containers:
      - name: sglang
        image: lmsysorg/sglang:v0.4.6.post2-cu124
        command:
        - bash
        - -c
        - |
          set -x
          python3 -m sglang.launch_server \
            --host 0.0.0.0 \
            --port 50050 \
            --model-path /data/Qwen/Qwen3-235B-A22B \
            --served-model-name Qwen3-235B-A22B \
            --enable-torch-compile \
            --tp 8 \
            --reasoning-parser qwen3
        resources:
          limits:
            nvidia.com/gpu: "8"
        ports:
        - containerPort: 50050
        volumeMounts:
        - name: data
          mountPath: /data
        - name: shm
          mountPath: /dev/shm
      volumes:
      - name: data
        persistentVolumeClaim:
          claimName: models
      - name: shm
        emptyDir:
          medium: Memory
          sizeLimit: 64Gi
      restartPolicy: Always
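
Note that the launch command above does not pin the attention backend, so it picks up the release default (FA3 on Hopper, per the description above). To make the A/B comparison explicit, the backend can be pinned with one extra argument; a sketch, assuming the --attention-backend flag:

python3 -m sglang.launch_server \
  --host 0.0.0.0 \
  --port 50050 \
  --model-path /data/Qwen/Qwen3-235B-A22B \
  --served-model-name Qwen3-235B-A22B \
  --enable-torch-compile \
  --tp 8 \
  --reasoning-parser qwen3 \
  --attention-backend flashinfer   # or fa3; pins the backend instead of relying on the default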

Using the vllm benchmark script and the ShareGPT dataset.

python3 benchmarks/benchmark_serving.py \
  --backend openai-chat \
  --model /data/Qwen/Qwen3-235B-A22B \
  --served-model-name Qwen3-235B-A22B \
  --endpoint /v1/chat/completions \
  --dataset-name sharegpt \
  --dataset-path benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json \
  --num-prompts 2400 \
  --host xxx.xxx.xxx.xxx \
  --port 50050 \
  --request-rate 8 \
  --ignore-eos

Result

Image

coderwangke avatar May 07 '25 04:05 coderwangke

me too...

Environment

sglang image == 0.4.6.post2-cu124
model == Qwen3/Qwen3-235B-A22B
FA3 attention backend

Image

Result

sglang Mean TTFT: 230.81 ms
vllm Mean TTFT: 117.20 ms

Thanks for reporting, would you please share the repro scripts?

hebiao064 avatar May 07 '25 04:05 hebiao064

Are you measuring only TTFT? 230 and 117 milliseconds - I'm sorry, but using this metric at these ranges is pointless; a person will never notice this difference. You should be measuring not TTFT but overall throughput and token generation speed.

Swipe4057 avatar May 07 '25 06:05 Swipe4057

Are you measuring only TTFT? 230 and 117 milliseconds - I'm sorry, but using this metric at these ranges is pointless; a person will never notice this difference. You should be measuring not TTFT but overall throughput and token generation speed.

Not only TTFT.

coderwangke avatar May 07 '25 08:05 coderwangke

You have different total numbers of generated tokens, but they should be the same. You are running the benchmark incorrectly. Fix the number of output tokens.

Swipe4057 avatar May 07 '25 08:05 Swipe4057

You have different total numbers of generated tokens, but they should be the same. You are running the benchmark incorrectly. Fix the number of output tokens.

Neither sglang nor vllm benchmarks have this setting.

coderwangke avatar May 07 '25 08:05 coderwangke

--random-output

python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 500 --random-input 4096 --random-output 2048

Swipe4057 avatar May 07 '25 09:05 Swipe4057

Try --page-size 16; the default page size of 1 is not a good one in my tests.

I also found that the TTFT of sglang is slower than vllm's, but TPOT is much faster, and the total latency of sglang is shorter than vllm's.

liuteng avatar May 07 '25 09:05 liuteng

@coderwangke your benchmark is totally wrong

Image

zhyncs avatar May 07 '25 18:05 zhyncs

It seems that the H20 doesn't show significant optimization. Could it be that the developers only conducted sufficient testing on H200 or H100, without covering the H20?

icewool avatar May 08 '25 07:05 icewool

@coderwangke your benchmark is totally wrong Image

I also noticed. Guess 1: the vllm benchmark script's statistics are wrong.

Will the sglang benchmarks support chat completions?

coderwangke avatar May 08 '25 11:05 coderwangke

https://github.com/sgl-project/sglang/pull/6151

snippetzero avatar May 12 '25 08:05 snippetzero

Try --page-size 16; the default page size of 1 is not a good one in my tests.

I also found that the TTFT of sglang is slower than vllm's, but TPOT is much faster, and the total latency of sglang is shorter than vllm's.

Maybe --enable-mixed-chunk can be helpful.

TianQiLin666666 avatar May 12 '25 10:05 TianQiLin666666

This issue has been automatically closed due to inactivity. Please feel free to reopen it if needed.

github-actions[bot] avatar Jul 12 '25 00:07 github-actions[bot]