composable_kernel

threadwise_tensor_slice_transfer_v5r1 issue

Open • joye opened this issue 9 months ago • 3 comments

(https://github.com/ROCm/composable_kernel/blob/764164b488a9009842c0ce4b14aa74d49eec5e6a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp#L147C1-L157C20)

This part of the code seems incorrect as an implementation of the space-filling curve algorithm. For example, if `ordered_src_access_idx = Sequence<1, 0, 0, 0>` and `ordered_src_access_lengths = Sequence<2, 2, 1, 1>`, then at `i = 3` the value of `tmp` is:

`ordered_src_access_idx[0] * ordered_src_access_lengths[0] * ordered_src_access_lengths[1] * ordered_src_access_lengths[2] + ordered_src_access_idx[0] * ordered_src_access_lengths[1] * ordered_src_access_lengths[2] + ordered_src_access_idx[1] * ordered_src_access_lengths[2] + ordered_src_access_idx[2]`

so `ordered_src_access_idx[0]` appears twice in the result.
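The duplicated term can be reproduced outside CK with a small constexpr sketch. It assumes a loop of the shape implied by the expansion in the issue: `tmp` is seeded with `ordered_src_access_idx[0]`, then `tmp = tmp * len[j] + idx[j]` is applied for `j = 0 .. i-1`. The function names are hypothetical, `std::array` stands in for CK's `Sequence` and `static_for`, and this is an illustration rather than the CK source.

```cpp
#include <array>
#include <cstddef>

// Hypothetical stand-in for the loop described in the issue: tmp is seeded
// with idx[0], and the Horner-style loop then starts at j = 0, so idx[0]
// is folded in a second time on the first iteration.
template <std::size_t N>
constexpr long prefix_value_as_described(const std::array<long, N>& idx,
                                         const std::array<long, N>& len,
                                         std::size_t i)
{
    long tmp = idx[0];                   // seed with the first index
    for (std::size_t j = 0; j < i; ++j)
        tmp = tmp * len[j] + idx[j];     // j = 0 adds idx[0] again
    return tmp;
}

// The expanded form quoted in the issue for i = 3 (requires N >= 3).
template <std::size_t N>
constexpr long expansion_from_issue(const std::array<long, N>& idx,
                                    const std::array<long, N>& len)
{
    return idx[0] * len[0] * len[1] * len[2]
         + idx[0] * len[1] * len[2]
         + idx[1] * len[2]
         + idx[2];
}

// Inputs from the issue:
// ordered_src_access_idx = <1, 0, 0, 0>, ordered_src_access_lengths = <2, 2, 1, 1>.
constexpr std::array<long, 4> kIdx{1, 0, 0, 0};
constexpr std::array<long, 4> kLen{2, 2, 1, 1};

static_assert(prefix_value_as_described(kIdx, kLen, 3) ==
                  expansion_from_issue(kIdx, kLen),
              "the seeded loop reproduces the quoted expansion");
static_assert(prefix_value_as_described(kIdx, kLen, 3) == 6,
              "1*2*2*1 + 1*2*1 + 0*1 + 0");
```

Compared with the same loop seeded with zero, the extra contribution is `ordered_src_access_idx[0]` times the product of the first `i` lengths (here 1 * 2 * 2 * 1 = 4). Whether that changes behavior depends on how `tmp` is used after the loop, which the thread does not show.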

joye · May 07 '24 09:05

@joye An internal ticket has been created to investigate this issue. Thanks!

ppanchad-amd · Aug 27 '24 14:08

A draft PR was created for this issue: https://github.com/ROCm/composable_kernel/pull/1492

ozturkosu · Aug 28 '24 22:08

Hello @joye, thank you for reaching out to the Composable Kernel team. We looked into this internally and believe the implementation is logically correct, although initializing `tmp` to zero might make it more readable. Is there a way to reproduce any errors you are observing? If not, we would like to keep the current implementation.
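The zero-initialized alternative mentioned above can be sketched the same way. In this hypothetical version (made-up function names, with `std::array` in place of CK's `Sequence` and `static_for`), each index in the prefix `idx[0..i-1]` is folded in exactly once. The result equals the row-major offset of `(idx[0], ..., idx[i-1])` computed from explicit strides.

```cpp
#include <array>
#include <cstddef>

// Hypothetical sketch: tmp starts at zero, so each of idx[0..i-1] is folded
// in exactly once (Horner's scheme).
template <std::size_t N>
constexpr long prefix_value_zero_seeded(const std::array<long, N>& idx,
                                        const std::array<long, N>& len,
                                        std::size_t i)
{
    long tmp = 0;
    for (std::size_t j = 0; j < i; ++j)
        tmp = tmp * len[j] + idx[j];     // shift by len[j], then add idx[j]
    return tmp;
}

// Reference: row-major offset of (idx[0], ..., idx[i-1]) in a grid of shape
// (len[0], ..., len[i-1]), using explicit strides.
template <std::size_t N>
constexpr long prefix_value_explicit(const std::array<long, N>& idx,
                                     const std::array<long, N>& len,
                                     std::size_t i)
{
    long sum = 0;
    for (std::size_t j = 0; j < i; ++j)
    {
        long stride = 1;                 // product of len[j+1 .. i-1]
        for (std::size_t k = j + 1; k < i; ++k)
            stride *= len[k];
        sum += idx[j] * stride;
    }
    return sum;
}

// Inputs from the issue.
constexpr std::array<long, 4> kIdx{1, 0, 0, 0};
constexpr std::array<long, 4> kLen{2, 2, 1, 1};

static_assert(prefix_value_zero_seeded(kIdx, kLen, 3) ==
                  prefix_value_explicit(kIdx, kLen, 3),
              "Horner form matches explicit strides");
static_assert(prefix_value_zero_seeded(kIdx, kLen, 3) == 2,
              "offset of (1, 0, 0) in a 2 x 2 x 1 grid");
```

For the inputs from the issue this gives 2 at `i = 3`. The expansion quoted in the issue gives 6 for the same inputs, a difference of `ordered_src_access_idx[0] * ordered_src_access_lengths[0] * ordered_src_access_lengths[1] * ordered_src_access_lengths[2]` = 4.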

hsadasiv · Sep 05 '24 03:09

Hey @joye, I will be closing this issue for now since there does not seem to be any further actionable item at the moment. Please feel free to ask follow-up questions or reopen the issue. Thanks!

tcgu-amd · Sep 30 '24 13:09