[BUG] Fused GEMM example gives wrong result with some shapes
Describe the bug
Fused GEMM example gives the wrong result for some values of problemSize1.K.
Steps/Code to reproduce bug
Set the following problem sizes in examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm80_shmem.cu
cutlass::gemm::GemmCoord gemm_f16_sm80_problem_size_0(128*640, 48, 576);
cutlass::gemm::GemmCoord gemm_f16_sm80_problem_size_1(128*640, 256, 48);
Build & run
$ make 13_fused_two_gemms_f16_sm80_shmem && ./examples/13_two_tensor_op_fusion/13_fused_two_gemms_f16_sm80_shmem
Device: NVIDIA A100-SXM4-40GB
Arch: SM80
Test: gemm f16 shmem staging
Running Non-fused back-to-back FP16 TN GEMMs...
gemm 0 time 0.0858931 ms
gemm 1 time 0.0452915 ms
Non-fusion time 0.131185 ms
Pass
Running Fused back-to-back FP16 TN GEMMs with shared memory staging...
Fusion time 0.138086 ms
/scratch/dhaziza/xformers/third_party/cutlass/examples/13_two_tensor_op_fusion/b2b_gemm_run.h 686: CHECK_TRUE failed
Dumping results in error_B2bGemm_device_fused.txt
Environment details (please complete the following information): Bare-metal, A100
Additional context
I've tracked down what happens in the fused multistage MMA. In the second matmul:
- The predicated iterator over C will first iterate over the partial tile of C (C[32:, :]) and then the first tile (C[:32, :]).
- The iterator over shared memory will iterate in the right order (e.g. (A@B)[:32, :], then (A@B)[32:, :]).
This means that we calculate (A@B)[0:32] @ C[32:64] + (A@B)[32:64] @ C[0:32] instead of (A@B)[0:32] @ C[0:32] + (A@B)[32:64] @ C[32:64], which is why we get the wrong result.
@jwang323
Currently, we assume the second GEMM's problem size K is a multiple of the threadblock tile size K. We can fix it pretty quickly; until then, you can use a problem size that fits this assumption.
Yes, indeed: it works when K2 < 32 or K2 % 32 == 0.
We can fix it pretty quickly
Oh cool! That would be great :)
Just wanted to follow up on this quickly - do you have an ETA on when this could be available? Or maybe you can give me a few pointers to fix it (the proper way) and I would be happy to contribute a PR.
I see a few options:
(1) Change the shared-memory iterator (WarpIteratorA1) to follow the same pattern as the TileIteratorB in case it's a partial one
(2) Change how TileIteratorB1 iterates
(3) Change how we write to shared memory with smem_iterator_D0_, so that WarpIteratorA1 reads the partial tile first
But I'm not exactly sure what the best way to proceed is. Thanks a lot :)
I'm working on a solution to add residual support for the warp-tile iterator (either WarpIteratorA1 for shmem-resident fusion or FragmentIteratorA1 for RF-resident fusion). The vector iterator for bias/scaling also needs to change to support residuals.
However, please note that this doesn't change CUTLASS's requirement that THREAD_BLOCK_TILE_N % 32 == 0 due to shared memory loading patterns. Therefore there will still be threadblock tile quantization for problem_size_0_n = 48, leading to a possible performance loss.
ETA is at the end of next week.
please note that this doesn't change cutlass's requirement that THREAD_BLOCK_TILE_N % 32 == 0 due to shared memory loading patterns
Yes, that's totally understandable. n=48 was just an example btw :)
Thanks for the update!
@jwang323 @danthe3rd are we able to close this issue?
@hwu36 please comment on the availability of this patch.
The corresponding PR: https://github.com/NVIDIA/cutlass/pull/590