
[QST] Permutation layout for contiguous stores

Open · capybara-club opened this issue 10 months ago · 1 comment

I wrote out this permutation for a TN TF32 16x8x8 MMA. I was trying to make the threads' writes contiguous when storing the output N-major. I have the following TiledMMA with a CTA tile size of (128,128,32):

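    #include <cute/tensor.hpp>
    #include <cute/atom/mma_atom.hpp>
    using namespace cute;

    TiledMMA mma = make_tiled_mma(
        // TF32 m16n8k8 atom with FP32 accumulators (A and B both K-major: TN)
        SM80_16x8x8_F32TF32TF32F32_TN{},
        // 4x1 warps: four atoms stacked along M, one along N
        Layout<Shape<_4,_1>, Stride<_1,_4>>{},
        // Permutation tile (M, N, K)
        Tile<
            Layout<Shape<_16>, Stride<_1>>,         // M: identity over 16 rows
            Layout<Shape<_2,_16>, Stride<_1,_8>>,   // N: index 2i+j -> column 8i+j (j in {0,1}), 32 wide
            _8                                      // K: 8
        >{});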

[Image: diagram of the permuted TiledMMA layout]

Besides working out the swizzle for B (I can see visually what it should do), am I interpreting this diagram correctly? With the right warp shfl_sync, should I be able to write 128B contiguous using 16B stores per thread (again, N-major)? Is this permutation costing me anything in a non-obvious way that I might be missing?
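For reference, the arithmetic behind the 128B / 16B-per-thread target, assuming FP32 (4-byte) accumulators are stored without conversion:

$$
\frac{128\ \text{B}}{16\ \text{B per store}} = 8\ \text{threads},
\qquad
\frac{16\ \text{B}}{4\ \text{B per FP32 value}} = 4\ \text{contiguous values per thread}
$$

So each participating thread needs four consecutive N-major columns whose first column index is a multiple of 4 (16B alignment, assuming the output row itself starts 16B-aligned).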

Edit: For example, does this break any of the cp.async or ldmatrix assumptions?

Thanks!

capybara-club, Feb 22 '25

@Junkai-Wu

hwu36, Mar 04 '25

As stated in https://github.com/NVIDIA/cutlass/issues/2140, you can check the permutation patterns for MN-major and K-major layouts with different contiguous bit widths here: https://github.com/NVIDIA/cutlass/blob/main/include/cute/atom/mma_traits_sm90_gmma.hpp#L74-L84. For the detailed swizzle pattern, see the implementation of `struct Swizzle` in CuTe.

Junkai-Wu, Mar 14 '25

This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.

github-actions[bot], Apr 13 '25

This issue has been labeled inactive-90d due to no recent activity in the past 90 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed.

github-actions[bot], Jul 12 '25