flux
[BUG] Fuse Reduction on SM 90
TORCH_CHECK(
!fuse_reduction || input_dtype == at::ScalarType::Half,
"Fuse reduction only support float16 type on SM80 due to instruction limitation.");
This check explicitly restricts fused reduction to float16 regardless of GPU architecture, even though the error message mentions only SM80.
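To make the mismatch concrete, here is a minimal standalone sketch of the predicate. `DType`, `current_check_passes`, and `arch_gated_check_passes` are hypothetical names introduced only for illustration; they are not part of flux or PyTorch. The second function is an assumption about what an architecture-gated check could look like if the restriction were meant for SM80 only. It says nothing about whether the SM90 kernels actually handle bfloat16 correctly.

```cpp
// Hypothetical stand-in for at::ScalarType, limited to the two dtypes discussed here.
enum class DType { Half, BFloat16 };

// Mirrors the condition inside the TORCH_CHECK above.
// Note that no architecture term appears in it.
bool current_check_passes(bool fuse_reduction, DType input_dtype) {
  return !fuse_reduction || input_dtype == DType::Half;
}

// Hypothetical variant (an assumption, not flux code): apply the float16
// restriction only on SM80, as the wording of the error message suggests.
// sm_arch is the compute capability as an integer, e.g. 80 or 90.
bool arch_gated_check_passes(bool fuse_reduction, DType input_dtype, int sm_arch) {
  const bool restricted = fuse_reduction && sm_arch == 80;
  return !restricted || input_dtype == DType::Half;
}
```

With the current condition, bfloat16 plus fuse_reduction is rejected on every architecture, H100 (SM90) included. The gated variant would let that combination through on SM90, which is only useful if the underlying kernel supports it.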
When I disable this check and use fuse_reduction = True with bfloat16 on NVIDIA H100s, I get a memory error.
@zheng-ningxin Is fuse_reduction broken on SM90?