[Bug]: `triton_scaled_mm` never used on ROCm
Your current environment
The output of `python collect_env.py`
INFO 03-07 02:02:58 [__init__.py:207] Automatically detected platform rocm.
Collecting environment information...
PyTorch version: 2.5.1+rocm6.2
Is debug build: False
CUDA used to build PyTorch: N/A
ROCM used to build PyTorch: 6.2.41133-dd7f95766
OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 10.5.0-1ubuntu1~22.04) 10.5.0
Clang version: Could not collect
CMake version: version 3.31.6
Libc version: glibc-2.35
Python version: 3.10.12 (main, Feb 4 2025, 14:57:36) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-131-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: AMD Instinct MI300X (gfx942:sramecc+:xnack-)
Nvidia driver version: Could not collect
cuDNN version: Could not collect
HIP runtime version: 6.2.41133
MIOpen runtime version: 3.2.0
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 104
On-line CPU(s) list: 0-103
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Platinum 8470
CPU family: 6
Model: 143
Thread(s) per core: 1
Core(s) per socket: 52
Socket(s): 2
Stepping: 8
BogoMIPS: 4000.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect avx_vnni avx512_bf16 wbnoinvd dtherm ida arat pln pts avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b enqcmd fsrm md_clear serialize tsxldtrk pconfig arch_lbr amx_bf16 avx512_fp16 amx_tile amx_int8 flush_l1d arch_capabilities
L1d cache: 4.9 MiB (104 instances)
L1i cache: 3.3 MiB (104 instances)
L2 cache: 208 MiB (104 instances)
L3 cache: 210 MiB (2 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-51
NUMA node1 CPU(s): 52-103
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI BHI_DIS_S
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] pytorch-triton-rocm==3.1.0
[pip3] pyzmq==26.2.1
[pip3] torch==2.5.1+rocm6.2
[pip3] torchaudio==2.5.1+rocm6.2
[pip3] torchvision==0.20.1+rocm6.2
[pip3] transformers==4.49.0
[conda] Could not collect
ROCM Version: 6.2.41134-65d174c3e
Neuron SDK Version: N/A
vLLM Version: 0.7.4.dev189+gae122b1cb
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
============================ ROCm System Management Interface ============================
================================ Weight between two GPUs =================================
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7
GPU0 0 15 15 15 15 15 15 15
GPU1 15 0 15 15 15 15 15 15
GPU2 15 15 0 15 15 15 15 15
GPU3 15 15 15 0 15 15 15 15
GPU4 15 15 15 15 0 15 15 15
GPU5 15 15 15 15 15 0 15 15
GPU6 15 15 15 15 15 15 0 15
GPU7 15 15 15 15 15 15 15 0
================================= Hops between two GPUs ==================================
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7
GPU0 0 1 1 1 1 1 1 1
GPU1 1 0 1 1 1 1 1 1
GPU2 1 1 0 1 1 1 1 1
GPU3 1 1 1 0 1 1 1 1
GPU4 1 1 1 1 0 1 1 1
GPU5 1 1 1 1 1 0 1 1
GPU6 1 1 1 1 1 1 0 1
GPU7 1 1 1 1 1 1 1 0
=============================== Link Type between two GPUs ===============================
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7
GPU0 0 XGMI XGMI XGMI XGMI XGMI XGMI XGMI
GPU1 XGMI 0 XGMI XGMI XGMI XGMI XGMI XGMI
GPU2 XGMI XGMI 0 XGMI XGMI XGMI XGMI XGMI
GPU3 XGMI XGMI XGMI 0 XGMI XGMI XGMI XGMI
GPU4 XGMI XGMI XGMI XGMI 0 XGMI XGMI XGMI
GPU5 XGMI XGMI XGMI XGMI XGMI 0 XGMI XGMI
GPU6 XGMI XGMI XGMI XGMI XGMI XGMI 0 XGMI
GPU7 XGMI XGMI XGMI XGMI XGMI XGMI XGMI 0
======================================= Numa Nodes =======================================
GPU[0] : (Topology) Numa Node: 0
GPU[0] : (Topology) Numa Affinity: 0
GPU[1] : (Topology) Numa Node: 0
GPU[1] : (Topology) Numa Affinity: 0
GPU[2] : (Topology) Numa Node: 0
GPU[2] : (Topology) Numa Affinity: 0
GPU[3] : (Topology) Numa Node: 0
GPU[3] : (Topology) Numa Affinity: 0
GPU[4] : (Topology) Numa Node: 1
GPU[4] : (Topology) Numa Affinity: 1
GPU[5] : (Topology) Numa Node: 1
GPU[5] : (Topology) Numa Affinity: 1
GPU[6] : (Topology) Numa Node: 1
GPU[6] : (Topology) Numa Affinity: 1
GPU[7] : (Topology) Numa Node: 1
GPU[7] : (Topology) Numa Affinity: 1
================================== End of ROCm SMI Log ===================================
LD_LIBRARY_PATH=/home/luka/git/vllm/.venv/lib/python3.10/site-packages/cv2/../../lib64:
NCCL_CUMEM_ENABLE=0
TORCHINDUCTOR_COMPILE_THREADS=1
CUDA_MODULE_LOADING=LAZY
🐛 Describe the bug
I found an issue with vLLM and block FP8 linear: the ROCm platform incorrectly relies on a cutlass execution path to reach its Triton kernel (`triton_scaled_mm`). Because the cutlass path is always disabled on ROCm, this kernel is never reached, and we instead fall back to either `w8a8_block_fp8_matmul` or `torch._scaled_mm`.
The way we got there:
- @rasmith added the Triton kernel `triton_scaled_mm` into `custom_ops.cutlass_scaled_mm` (not the right place for it, in my opinion) in 127c074.
- @hongxiayang added DeepSeek support using the cutlass path, where `cutlass_block_fp8_supported` was `True` by default, in c36ac98.
- @LucasWilkinson fixed the default of the `cutlass_block_fp8_supported` param to `cutlass_block_fp8_supported()`, which always returns `False` on ROCm, in 76abd0c.
As a result, `triton_scaled_mm` is currently never used.
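To make the chain above concrete, here is a heavily simplified, hypothetical model of the block-FP8 dispatch. None of these functions are vLLM's real API: the names only echo the ones in this issue, the platform check is reduced to a string, and only the block path is modeled (not the per-tensor fallback).

```python
from typing import Optional


def cutlass_block_fp8_supported(platform: str) -> bool:
    # Simplified: mirrors the behavior described in the issue (always False on ROCm).
    return platform != "rocm"


def cutlass_scaled_mm(platform: str) -> str:
    # The Triton kernel lives *inside* the cutlass entry point.
    if platform == "rocm":
        return "triton_scaled_mm"
    return "cutlass"


def apply_block_fp8_linear(platform: str, use_cutlass: Optional[bool] = None) -> str:
    # `use_cutlass` stands in for the `cutlass_block_fp8_supported` parameter.
    # With the corrected default, it is computed from the platform instead of True.
    if use_cutlass is None:
        use_cutlass = cutlass_block_fp8_supported(platform)
    if use_cutlass:
        # The only branch that can ever reach the Triton kernel.
        return cutlass_scaled_mm(platform)
    # Block-FP8 fallback named in the issue.
    return "w8a8_block_fp8_matmul"
```

Calling `apply_block_fp8_linear("rocm")` lands on `w8a8_block_fp8_matmul`; only forcing `use_cutlass=True` (the old default) reaches `triton_scaled_mm`.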
I think the path forward is to move `triton_scaled_mm` out of `custom_ops.cutlass_scaled_mm`. This should likely be done as part of a larger refactoring of the FP8 code, including the new `Fp8LinearOp` added in #14390. Additionally, it would be good to (at least somewhat) unify `triton_scaled_mm` with `w8a8_block_fp8_matmul`, which is the fallback for `apply_block_fp8_linear`.
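One way to picture the proposed direction is an ordered list of kernel candidates in which `triton_scaled_mm` is a first-class option instead of being hidden inside the cutlass entry point. This is a hypothetical sketch: `KernelCandidate`, `CANDIDATES`, and `choose_kernel` are illustrative names rather than vLLM code, and the support predicates are placeholders.

```python
from dataclasses import dataclass
from typing import Callable, Sequence


@dataclass(frozen=True)
class KernelCandidate:
    # Hypothetical kernel entry: a name plus its own support check.
    name: str
    is_supported: Callable[[str], bool]


# Ordered by preference. Placeholder predicates; real checks would look at
# the platform, device capability, and quantization scheme.
CANDIDATES: Sequence[KernelCandidate] = (
    KernelCandidate("cutlass", lambda platform: platform == "cuda"),
    KernelCandidate("triton_scaled_mm", lambda platform: platform in ("cuda", "rocm")),
    KernelCandidate("torch._scaled_mm", lambda platform: True),
)


def choose_kernel(platform: str, candidates: Sequence[KernelCandidate] = CANDIDATES) -> str:
    # Return the first candidate whose own check accepts this platform.
    for candidate in candidates:
        if candidate.is_supported(platform):
            return candidate.name
    raise ValueError(f"no scaled-mm kernel supports platform {platform!r}")
```

Because each kernel owns its support check, disabling cutlass on a platform can no longer silently hide an unrelated kernel: `choose_kernel("rocm")` skips cutlass and selects `triton_scaled_mm`.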
This should likely be done as part of the refactoring mentioned in #11785 to use the `ScaledMMKernel` abstraction for FP8 kernels.
This issue has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this issue should remain open. Thank you!
Not stale
Can I take this up? It seems like a refactor plus tests for different conditions.
Yes! I actually started some work on this in #19434 that you might find useful. Also take a look at #8913 to understand the broader goal of the refactor.
@ProExpertProg I am a beginner and this looks like a good first issue. Since nobody has picked this up yet, can I give it a try?
@shivampr please go ahead! And let me know if you need any help
Hi @ProExpertProg,
I’m working on enabling TritonScaledMMLinearKernel on ROCm so vLLM can fall back to Triton on AMD GPUs when AITriton isn’t available.
Commit: https://github.com/shivampr/vllm/commit/9b1d63eeb4fd9340b8b1a3220a1320f1bcfd0433
What I changed
Enabled TritonScaledMMLinearKernel in the ROCm/CUDA dispatch path and wired up the Triton FP8 scaled matmul (triton_scaled_mm) with symmetric quantization so Triton is used as a fallback on AMD when AITriton/CUTLASS isn’t available.
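For reference, the math a symmetric scaled matmul implements fits in a few lines of plain Python: each operand is quantized without a zero point as `x ≈ q * scale`, the quantized operands are multiplied, and the accumulator is rescaled by `scale_a * scale_b`. This is a sketch under stated assumptions: it uses per-tensor scales and integer rounding to ±127 as a stand-in for the FP8 cast, and `quantize_symmetric` / `scaled_mm` are hypothetical helpers, not the Triton kernel's signature.

```python
def quantize_symmetric(values, qmax=127):
    """Per-tensor symmetric quantization: x ~= q * scale, with q in [-qmax, qmax]."""
    amax = max(abs(v) for row in values for v in row)
    # Guard against an all-zero tensor.
    scale = amax / qmax if amax > 0 else 1.0
    # Integer rounding stands in for the FP8 cast used by the real kernel.
    q = [[max(-qmax, min(qmax, round(v / scale))) for v in row] for row in values]
    return q, scale


def scaled_mm(a_q, b_q, scale_a, scale_b):
    """Multiply quantized operands, then rescale each accumulator once."""
    k = len(b_q)
    n = len(b_q[0])
    out = []
    for row in a_q:
        # Accumulate in the quantized domain (no zero-point correction needed).
        acc = [sum(row[t] * b_q[t][j] for t in range(k)) for j in range(n)]
        out.append([v * scale_a * scale_b for v in acc])
    return out
```

Rescaling once at the end is valid because per-tensor scales factor out of the sum: the sum of `(qa * sa) * (qb * sb)` equals `sa * sb` times the sum of `qa * qb`. Asymmetric schemes add a zero-point term that this sketch omits.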
Where I’m stuck while testing (RunPod MI300X)
- No vLLM ROCm Docker image available
- Source build keeps failing due to ROCm toolchain
- Torch matmul errors on gfx942 (`HIP error: invalid device function`)
Questions
- Which Docker/base env do you recommend for ROCm vLLM dev?
- Any suggested workflow/CI to test custom ROCm kernels?
- Is a minimal integration test that confirms TritonScaledMM is selected on ROCm sufficient for validation?
Thanks for any guidance!
I usually develop on bare metal, but the rocm/vllm-dev container should work too. Can you create issues for the problems you're running into? You can also post in #sig-amd in the vLLM dev Slack.
Is a minimal integration test that confirms TritonScaledMM is selected on ROCm sufficient for validation?
Yeah, I think so, if you can swing it. If it's too difficult or intrusive to test, that's okay too.
Thanks for the quick reply!
~I tried getting access through AMD's website but GPUs weren't available. On RunPod there wasn't a pre-built vLLM ROCm image, and when I tried building from source I ran into installation issues. I'll open a ticket with the details as you suggested.~
~Since I couldn't test on actual hardware, I created a minimal integration test that mocks the ROCm platform and verifies TritonScaledMMLinearKernel gets selected correctly. It runs locally without needing real ROCm/GPU - just uses mocked dependencies to test the kernel selection logic. The test passes and confirms the fallback path works as expected.~
~I know it's not ideal compared to testing on real MI300X hardware, but I figured some test coverage was better than none given the access constraints. Let me know if this works; I'm looking forward to your feedback.~
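A mocked kernel-selection test along the lines described above might look like the following. It is a hypothetical sketch: `_Platform`, `current_platform`, `has_optional_rocm_backend`, `OptionalRocmBackendKernel`, and `select_scaled_mm_kernel` stand in for vLLM's real platform detection and kernel-selection code, which differ.

```python
import unittest
from unittest import mock


class _Platform:
    """Hypothetical platform probe; vLLM's real platform API differs."""

    def is_rocm(self) -> bool:
        return False

    def has_optional_rocm_backend(self) -> bool:
        # Stands in for "is the AMD-specific accelerated backend installed?"
        return False


current_platform = _Platform()


def select_scaled_mm_kernel() -> str:
    """Hypothetical selection: CUTLASS off ROCm, Triton as the ROCm fallback."""
    if current_platform.is_rocm():
        if current_platform.has_optional_rocm_backend():
            return "OptionalRocmBackendKernel"
        return "TritonScaledMMLinearKernel"
    return "CutlassScaledMMLinearKernel"


class KernelSelectionTest(unittest.TestCase):
    def test_rocm_without_backend_falls_back_to_triton(self):
        # Pretend we are on ROCm with no optional backend available.
        with mock.patch.object(current_platform, "is_rocm", return_value=True), \
             mock.patch.object(current_platform, "has_optional_rocm_backend", return_value=False):
            self.assertEqual(select_scaled_mm_kernel(), "TritonScaledMMLinearKernel")

    def test_non_rocm_uses_cutlass(self):
        with mock.patch.object(current_platform, "is_rocm", return_value=False):
            self.assertEqual(select_scaled_mm_kernel(), "CutlassScaledMMLinearKernel")
```

Because it only uses `unittest` and `unittest.mock`, it runs without a GPU, and `pytest` will also collect the `TestCase` class. It checks only the selection logic, not the kernel's numerics.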
Thanks for the earlier feedback!
I’ve now re-tested this PR on an actual AMD MI300X (ROCm 7.0) environment using the rocm/vllm-dev image. The new validation confirms that:
- `TritonScaledMMLinearKernel` is correctly selected as the ROCm fallback
- The Triton scaled-mm kernel executes correctly and produces numerically stable results
- Full vLLM inference via the OpenAI-compatible API runs successfully (Qwen2.5-0.5B-Instruct on ROCm)
The earlier mocked test remains for quick CI validation, but the implementation is now fully verified on real ROCm hardware.