cutlass
cutlass copied to clipboard
Fix sgemm_sm80 example bug
Hey I was running sgemm_sm80.cu example and printing out the tensor layouts, the code triggered a segfault when I added more printing logs, and compute-sanitizer shows it was due to ldsm.
I found that:
tCrA : ptr[16b](0x7f54b9fff9e0) o ((_2,_2),_4,(_2,_2,_2)):((_1,_2),_4,(_16,_32,_64))
tXrA : ptr[16b](0x7f54b9fff9e0) o (((_4,_2),_1),_4,_4):(((_1,_16),_0),_4,_32)
tCrB : ptr[16b](0x7f54b9fffae0) o (_2,_8,(_2,_2,_2)):(_1,_2,(_16,_32,_64))
tXrB : ptr[16b](0x7f54b9fffae0) o (((_4,_2),_1),_4,_4):(((_1,_16),_0),_4,_32)
where the copy pipeline stage (4) number doesn't match MMA's (8)
https://github.com/NVIDIA/cutlass/blob/8206e7a0f57a9a057cdd2c3bb4899bd5154a82e1/examples/cute/tutorial/sgemm_sm80.cu#L233-L234
https://github.com/NVIDIA/cutlass/blob/8206e7a0f57a9a057cdd2c3bb4899bd5154a82e1/examples/cute/tutorial/sgemm_sm80.cu#L278-L280
k_block_next exceeds size<2>(tXrA) in tXrA(_, _, k_block_next).
After extending k-dim:
tCrA : ptr[16b](0x7f6471fff9c0) o ((_2,_2,_2),_4,(_2,_2)):((_1,_2,_4),_8,(_32,_64))
tXrA : ptr[16b](0x7f6471fff9c0) o ((_8,_1),_4,_4):((_1,_0),_8,_32)
tCrB : ptr[16b](0x7f6471fffac0) o ((_2,_2),_8,(_2,_2)):((_1,_2),_4,(_32,_64))
tXrB : ptr[16b](0x7f6471fffac0) o ((_8,_1),_4,_4):((_1,_0),_8,_32)