cutlass
[BUG] Copy_Atom with DefaultCopy causes misaligned address
Describe the bug
As of b7508e337938137a699e486d8997646980acfc58, `Copy_Atom<DefaultCopy, float>` causes a misaligned address error.
Steps/Code to reproduce bug
```cpp
#include <cstdio>
#include <cute/tensor.hpp>
using namespace cute;

__global__ void kernel(int m, int k, float* a, int lda) {
  const auto mA = make_tensor(make_gmem_ptr(a), make_layout(make_shape(m, k), LayoutLeft{}));  // (m, k)
  const auto cA = make_identity_tensor(make_shape(m, k));
  constexpr auto CtaShape = make_shape(_128{}, _128{}, _8{});
  const auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _);
  const auto ctaA = local_tile(mA, CtaShape, cta_coord, make_step(_1{}, _, _1{}));
  const auto stripe_gA = local_tile(ctaA, make_tile(_128{}, _8{}), make_coord(blockIdx.x, _));

  constexpr int VecSize = 4;  // NOTE: the VecSize is for later STS, not LDG
  const auto tiled_copy = make_tiled_copy(
      Copy_Atom<DefaultCopy, float>{},
      make_layout(make_shape(Int<128 / VecSize>{}, _8{})),
      make_layout(make_shape(Int<VecSize>{})));
  auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x);

  auto staging_a = make_fragment_like<float>(Int<VecSize>{});
  copy(tiled_copy, thr_copy.partition_S(stripe_gA(_, _, _0{}, 0)), staging_a);

  if (thread(255)) {
    print_tensor(staging_a);
  }
}

int main() {
  int size = 129;  // misaligned address iff size % 4 != 0
  float* dev_buffer;
  cudaMalloc(&dev_buffer, sizeof(float) * size * size);
  cudaMemset(dev_buffer, 0, sizeof(float) * size * size);

  kernel<<<dim3(1, 1), 256>>>(size, size, dev_buffer, size);
  cudaDeviceSynchronize();

  cudaError_t status = cudaGetLastError();
  if (status != cudaSuccess) {
    fprintf(stderr, "CUDA Error on %s:%d\n", __FILE__, __LINE__);
    fprintf(stderr, "CUDA Error Code : %d\n Error String: %s\n", status, cudaGetErrorString(status));
    return -1;
  }
}
```
produces:
```
CUDA Error Code : 716
Error String: misaligned address
```
Expected behavior
Since `DefaultCopy` only assumes that the data is aligned to a byte boundary (the template argument is in bits), the code should not produce an error:
```cpp
using AutoVectorizingCopy = AutoVectorizingCopyWithAssumedAlignment<8>;
using DefaultCopy = AutoVectorizingCopy;
```
Based on the line `int size = 129; // misaligned address iff size % 4 != 0`, changing `size` to 128, 132, or any other multiple of 4 makes the error go away.
Environment details:
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Tue_Aug_15_22:02:13_PDT_2023
Cuda compilation tools, release 12.2, V12.2.140
Build cuda_12.2.r12.2/compiler.33191640_0
- RTX 4090
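The alignment assumption applies only to dynamic layouts. The layouts you are passing to `copy` are fully static, so CuTe treats them as proven to be aligned and vectorizes the access. The misalignment comes from the pointer itself: with `size = 129`, the start address of some threads' partitions is not aligned to the vector width.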
There are several ways to avoid vectorization:
- Use `UniversalCopy<float>` instead of `DefaultCopy` to disable all auto-vectorization, which appears to be what you want here.
- Use `copy_vec<float>(src, dst)`.
- Write the copy out as a for loop:
  ```cpp
  for (int i = 0; i < size(src); ++i) {
    dst(i) = src(i);
  }
  ```
- Pass in a tile of data from gmem that keeps the dynamic stride, so that its layout reflects the potential misalignment.

CuTe cannot detect misaligned pointers at runtime and dynamically branch between vectorized and non-vectorized copy paths.
With the following debug prints (run with `size = 1024`):
```cpp
if (thread(255)) {
  print(stripe_gA(_, _, _0{}, 0));
  print(thr_copy.partition_S(stripe_gA(_, _, _0{}, 0)));
  print_tensor(staging_a);
}
```
the output is:
```
gmem_ptr[32b](0xb04c00000) o (_128,_8):(_1,1024)
gmem_ptr[32b](0xb04c071f0) o ((_1,_4),_1,_1):((_0,_1),_0,_0)
ptr[32b](0x7fa274fffce0) o _4:_1:
  0.00e+00
  0.00e+00
  0.00e+00
  0.00e+00
```
So it is `thr_copy.partition_S` that eliminates the dynamic stride: the partitioned tensor's layout is fully static, so the copy gets vectorized.
This should be documented very carefully, though.
Correct. I agree that ideally `cute::copy` should be more conservative and all alignment assumptions should be opt-in. We're currently reviewing `copy` and its Atoms and would like to do a full redesign of how these dispatches and assumptions work; we'll certainly keep this in mind.
This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.
Closing due to inactivity