
[BUG] Fused GEMM example gives wrong result with some shapes

Open · danthe3rd opened this issue · 7 comments

Describe the bug
Fused GEMM example gives the wrong result for some values of problemSize1.K.

Steps/Code to reproduce bug
Set the following problem sizes in examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm80_shmem.cu:

cutlass::gemm::GemmCoord gemm_f16_sm80_problem_size_0(128*640, 48, 576);
cutlass::gemm::GemmCoord gemm_f16_sm80_problem_size_1(128*640, 256, 48);

Build & run

$ make 13_fused_two_gemms_f16_sm80_shmem && ./examples/13_two_tensor_op_fusion/13_fused_two_gemms_f16_sm80_shmem

Device: NVIDIA A100-SXM4-40GB
Arch: SM80
Test: gemm f16 shmem staging
Running Non-fused back-to-back FP16 TN GEMMs...
gemm 0 time 0.0858931 ms
gemm 1 time 0.0452915 ms
Non-fusion time 0.131185 ms
Pass
Running Fused back-to-back FP16 TN GEMMs with shared memory staging...
Fusion time 0.138086 ms
/scratch/dhaziza/xformers/third_party/cutlass/examples/13_two_tensor_op_fusion/b2b_gemm_run.h 686: CHECK_TRUE failed
Dumping results in error_B2bGemm_device_fused.txt

Environment details
Bare-metal, A100

Additional context
I've tracked down what happens in the fused multistage MMA. In the second matmul:

  • The predicated iterator over C first iterates over the partial tile of C (C[32:, :]) and then over the first tile (C[:32, :])
  • The iterator over shared memory iterates in the expected order (i.e. (A@B)[:32, :] first, then (A@B)[32:, :])

This means that, slicing along K, we calculate (A@B)[0:32] @ C[32:64] + (A@B)[32:64] @ C[0:32] instead of (A@B)[0:32] @ C[0:32] + (A@B)[32:64] @ C[32:64], and that's why we get the wrong result.
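To make the mis-pairing concrete, here is a toy reduction (plain C++, not CUTLASS code) over K = 48 with a K tile of 32, zero-padded to 64 to stand in for predication; the bug then reduces to swapping the two K tiles on the C side:

#include <cstdio>
#include <vector>

int main() {
  const int K = 48, TileK = 32, KPadded = 64;
  // Zero-padding beyond K models the masked (predicated) residual tile.
  std::vector<float> a(KPadded, 0.f), c(KPadded, 0.f);
  for (int k = 0; k < K; ++k) { a[k] = float(k + 1); c[k] = float(k % 7); }

  float correct = 0.f, buggy = 0.f;
  for (int k = 0; k < KPadded; ++k) {
    correct += a[k] * c[k];
    // The C-side iterator visits the residual tile first, so tile 0 of A
    // is paired with tile 1 of C and vice versa:
    int k_swapped = (k < TileK) ? k + TileK : k - TileK;
    buggy += a[k] * c[k_swapped];
  }
  printf("correct = %.1f, swapped tile order = %.1f\n", correct, buggy);
  return 0;
}

With these inputs the two sums differ, mirroring the failing CHECK_TRUE above.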

danthe3rd · Jul 14 '22 12:07

@jwang323

hwu36 · Jul 14 '22 15:07

Right now we assume the second GEMM's problem size K is a multiple of the threadblock tile size K. We can fix it pretty quickly. Until then, you can use problem sizes that fit the assumption (see the sketch below).
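A minimal sketch of "fitting the assumption", assuming the operands can be zero-padded so the extra K entries do not change the result; round_up is a hypothetical helper, not a CUTLASS API:

#include <cstdio>

// Round K2 up to a multiple of the threadblock tile K (32 here).
int round_up(int x, int multiple) {
  return ((x + multiple - 1) / multiple) * multiple;
}

int main() {
  int k2 = 48, tile_k = 32;
  // 48 -> 64; the extra 16 K entries of both operands must be zero-filled.
  printf("padded K2 = %d\n", round_up(k2, tile_k));
  return 0;
}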

hwu36 · Jul 14 '22 21:07

Indeed, it works when K2 < 32 or K2 % 32 == 0.

> We can fix it pretty quickly

Oh cool! That would be great :)

danthe3rd · Jul 14 '22 22:07

Just wanted to follow up on this quickly: do you have an ETA for when this could be available? Or maybe you can give me a few pointers to fix it (the proper way) and I would be happy to contribute a PR.

I see a few options (a sketch of the residual-first schedule follows below):

  1. Change the shared-memory iterator (WarpIteratorA1) to follow the same pattern as the TileIteratorB when the tile is a partial one
  2. Change how TileIteratorB1 iterates
  3. Change how we write to shared memory with smem_iterator_D0_ so that the WarpIteratorA1 reads the partial tile first

But I'm not exactly sure which of these is the best way to proceed. Thanks a lot :)
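For reference, a rough sketch (hypothetical code, not the actual CUTLASS iterator implementation) of the residual-first K schedule that the predicated global-memory iterators follow; whichever option is chosen, both operands of the second GEMM have to agree on this schedule:

#include <cstdio>

int main() {
  const int K = 48, TileK = 32;
  int k_tiles  = (K + TileK - 1) / TileK;    // 2 mainloop iterations
  int residual = K - (k_tiles - 1) * TileK;  // 16-wide partial tile

  // Iteration 0 covers the (predicated) residual range at the end of K;
  // later iterations cover full tiles from the start.
  printf("iter 0: K[%d, %d) (residual)\n", K - residual, K);
  for (int t = 1; t < k_tiles; ++t)
    printf("iter %d: K[%d, %d)\n", t, (t - 1) * TileK, t * TileK);
  return 0;
}

This prints K[32, 48) for iteration 0 and K[0, 32) for iteration 1, matching the order observed above for the iterator over C.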

danthe3rd · Jul 21 '22 19:07

I'm working on a solution that adds residual support to the warp-tile iterator (either WarpIteratorA1 for shmem-resident fusion or FragmentIteratorA1 for RF-resident fusion). The vector iterator for bias/scaling also needs to change to support the residual.

However, please note that this doesn't change CUTLASS's requirement that THREAD_BLOCK_TILE_N % 32 == 0, due to shared memory loading patterns. There will therefore still be threadblock tile quantization for problem_size_0_n = 48, leading to possible performance loss.
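A back-of-envelope illustration of that quantization, assuming the compliant N tile covering n = 48 is 64 wide (the real tile shape depends on the kernel configuration):

#include <cstdio>

int main() {
  int n = 48, tile_n = 64;
  int tiles  = (n + tile_n - 1) / tile_n;  // 1 threadblock tile along N
  int padded = tiles * tile_n;             // 64
  printf("wasted N fraction: %d/%d = %.0f%%\n",
         padded - n, padded, 100.0 * (padded - n) / padded); // 16/64 = 25%
  return 0;
}

That is, roughly a quarter of the N-dimension compute goes to padding in this configuration.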

ETA is at the end of next week.

jwang323 · Jul 21 '22 19:07

> please note that this doesn't change CUTLASS's requirement that THREAD_BLOCK_TILE_N % 32 == 0, due to shared memory loading patterns

Yes, that's totally understandable. n=48 was just an example btw :) Thanks for the update!

danthe3rd · Jul 22 '22 06:07

@jwang323 @danthe3rd are we able to close this issue?

mnicely · Aug 10 '22 09:08

@hwu36 please comment on the availability of this patch.

jwang323 · Aug 10 '22 18:08

The corresponding PR: https://github.com/NVIDIA/cutlass/pull/590

danthe3rd · Aug 10 '22 20:08