Support KV8 (FP8) with torch_native attention backend
This patch fixes the issue where KV8 could not run when the attention backend was set to torch_native.
## Motivation
Currently, when using `--attention-backend torch_native`, the `--kv-cache-dtype fp8_e4m3` option is not supported, so running with an FP8 KV cache fails. This patch fixes the issue by ensuring that the query, key, and value tensors are cast to the same dtype before calling `scaled_dot_product_attention`.
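For context, a minimal sketch of the failure mode (shapes are illustrative; it relies on PyTorch's validation that query, key, and value share a dtype):

```python
import torch
import torch.nn.functional as F

# The query arrives in the model's activation dtype (e.g. bfloat16), while
# the KV cache holds FP8 values when --kv-cache-dtype fp8_e4m3 is set.
q = torch.randn(1, 8, 16, 64, dtype=torch.bfloat16)
k = torch.randn(1, 8, 16, 64).to(torch.float8_e4m3fn)
v = torch.randn(1, 8, 16, 64).to(torch.float8_e4m3fn)

try:
    F.scaled_dot_product_attention(q, k, v)  # mixed dtypes are rejected
except RuntimeError as e:
    print(f"SDPA failed: {e}")

# Casting K/V up to the query dtype makes the call succeed.
out = F.scaled_dot_product_attention(q, k.to(q.dtype), v.to(q.dtype))
print(out.dtype)  # torch.bfloat16
```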
## Modifications
- Modified `TorchNativeAttnBackend` in `torch_native_backend.py`
- Added dtype casting for `per_req_key` and `per_req_value` to match `per_req_query` (see the sketch below)
- Ensures `scaled_dot_product_attention` works correctly with an FP8 KV cache
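A simplified sketch of the shape of the change; the helper name `_run_sdpa` and the condensed structure are illustrative, not the exact diff:

```python
import torch
from torch.nn.functional import scaled_dot_product_attention

def _run_sdpa(per_req_query: torch.Tensor,
              per_req_key: torch.Tensor,
              per_req_value: torch.Tensor) -> torch.Tensor:
    # With --kv-cache-dtype fp8_e4m3, the cached K/V are float8_e4m3fn while
    # the query is bf16/fp16. SDPA requires all three dtypes to match, and
    # FP8 matmuls are not supported by the math backend anyway, so cast K/V
    # up to the query dtype before the call.
    if per_req_key.dtype != per_req_query.dtype:
        per_req_key = per_req_key.to(per_req_query.dtype)
    if per_req_value.dtype != per_req_query.dtype:
        per_req_value = per_req_value.to(per_req_query.dtype)
    return scaled_dot_product_attention(
        per_req_query, per_req_key, per_req_value
    )
```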
## Accuracy Tests
Tested in another PR: #12612
## Benchmarking and Profiling
Tested in another PR: #12612
## Checklist
- [x] Format your code with pre-commit.
- [ ] Run and add unit tests.
- [ ] Write/update documentation.
- [x] Provide accuracy and speed benchmark results.
@JackChuang Please update this doc https://github.com/sgl-project/sglang/blob/main/docs/advanced_features/attention_backend.md?plain=1#L22
@Fridge003 Thanks for your review and approval. Could someone help merge this PR? Thanks~
@JackChuang Please fix the conflict.
@Fridge003 Could you please help merge this PR when you have free cycles? Thank you.
@JackChuang Do you have any example of accuracy benchmarking when enabling the FP8 KV cache with the torch_native backend?
I didn't test accuracy, only performance. I'll run the accuracy tests and then update here.
@Fridge003 Using torch_native with KV8, the precision is essentially lossless. Experiments were run on a B200.

KV8 (`--kv-cache-dtype fp8_e4m3`):

```bash
$ CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m sglang.launch_server --model-path /data02/models/Qwen3-235B-A22B --tp 4 --trust-remote-code --port 8041 --kv-cache-dtype fp8_e4m3 --disable-radix-cache --enable-torch-compile --attention-backend torch_native
$ python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1319 --parallel 1319 --port 8041
```

[KV8] Accuracy: 0.949 Invalid: 0.001 Latency: 2984.291 s Output throughput: 68.772 token/s

KV16 (default KV cache dtype):

```bash
$ CUDA_VISIBLE_DEVICES=4,5,6,7 python3 -m sglang.launch_server --model-path /data02/models/Qwen3-235B-A22B --tp 4 --trust-remote-code --port 8042 --attention-backend torch_native --disable-radix-cache --enable-torch-compile
$ python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1319 --parallel 1319 --port 8042
```

[KV16] Accuracy: 0.947 Invalid: 0.000 Latency: 2783.572 s Output throughput: 73.740 token/s