[BUG] [CuTe DSL] Cold L2 leads to IMA in grouped_blockscaled_gemm
Which component has the problem?
CuTe DSL
Bug Report
Describe the bug
When running grouped_blockscaled_gemm.py with use_cold_l2=True, I run into an illegal memory access (IMA) in many configurations once warmup_iterations + iterations exceeds 10, including but not limited to:
- MXF4 and NVF4, in all test cases
- MXFP8 with >= 4 groups

None of these IMAs occurs without cold L2 or with 10 or fewer total iterations.
Steps/Code to reproduce bug
```python
import cutlass

# Assumes grouped_blockscaled_gemm.py is on the Python path (e.g. run from its directory).
from grouped_blockscaled_gemm import run

run(
    num_groups=4,
    problem_sizes_mnkl=[(256, 256, 256, 1), (512, 512, 512, 1), (1024, 1024, 1024, 1), (2048, 2048, 2048, 1)],
    ab_dtype=cutlass.Float8E4M3FN,
    sf_dtype=cutlass.Float8E8M0FNU,
    sf_vec_size=32,
    c_dtype=cutlass.Float32,
    a_major="m",
    b_major="n",
    c_major="m",
    mma_tiler_mn=(128, 128),
    cluster_shape_mn=(1, 1),
    warmup_iterations=5,
    iterations=10,  # IMA for any warmup_iterations + iterations > 10
    skip_ref_check=True,
    use_cold_l2=True,
)
```
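When narrowing down the "> 10 total iterations" threshold, it helps to run each configuration in its own process. An illegal memory access is a sticky CUDA error: after the first IMA, every later launch in the same process fails too, which would hide where the boundary actually is. The following is a minimal sketch, not part of CUTLASS. `REPRO_TEMPLATE`, `build_script`, `run_isolated`, and `sweep` are hypothetical names. It assumes grouped_blockscaled_gemm.py is importable from the working directory and that `run` accepts the same arguments as in the reproducer in this report.

```python
import subprocess
import sys
import textwrap

# Hypothetical template: the reproducer from this report with the iteration
# counts left as placeholders. Assumes grouped_blockscaled_gemm.py is
# importable from the current working directory.
REPRO_TEMPLATE = textwrap.dedent("""\
    import cutlass
    from grouped_blockscaled_gemm import run
    run(
        num_groups=4,
        problem_sizes_mnkl=[(256, 256, 256, 1), (512, 512, 512, 1),
                            (1024, 1024, 1024, 1), (2048, 2048, 2048, 1)],
        ab_dtype=cutlass.Float8E4M3FN,
        sf_dtype=cutlass.Float8E8M0FNU,
        sf_vec_size=32,
        c_dtype=cutlass.Float32,
        a_major="m",
        b_major="n",
        c_major="m",
        mma_tiler_mn=(128, 128),
        cluster_shape_mn=(1, 1),
        warmup_iterations={warmup},
        iterations={iterations},
        skip_ref_check=True,
        use_cold_l2=True,
    )
""")


def build_script(warmup: int, iterations: int) -> str:
    """Fill the template with concrete iteration counts."""
    return REPRO_TEMPLATE.format(warmup=warmup, iterations=iterations)


def run_isolated(script: str, timeout: float = 600.0) -> tuple[int, str]:
    """Run `script` in a fresh Python interpreter; return (exit code, stderr).

    One process per configuration, because an illegal memory access is a
    sticky CUDA error that also breaks later launches in the same process.
    """
    proc = subprocess.run(
        [sys.executable, "-c", script],
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    return proc.returncode, proc.stderr


def sweep(totals, warmup: int = 5, runner=run_isolated) -> dict:
    """Map each total (warmup_iterations + iterations) to the runner's result."""
    results = {}
    for total in totals:
        iterations = total - warmup
        if iterations < 1:
            raise ValueError(f"total={total} leaves no timed iterations")
        results[total] = runner(build_script(warmup, iterations))
    return results
```

For example, `sweep(range(8, 14))` runs totals 8 through 13 with 5 warmup iterations each. A nonzero exit code only means the run failed. Check the returned stderr to confirm it was an illegal-address error rather than, say, an import failure.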
Environment details
- B200