cutlass
[QST] What does Sw<3,3,3> mean in print() output?
What is your question?
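auto tensor_3d = bSG_sD(_, 0, 0, 0);  // keep mode 0, fix the other three modes at 0
printf("tensor_3d\n");
print(tensor_3d);
printf("\n");
// View the same shared-memory pointer as a 64x256 tensor
auto tensor_2d = make_tensor(tensor_3d.data(), make_shape(64, 256));
printf("tensor_2d\n");
print(tensor_2d);
printf("\n");
print(tRS_sD);
printf("\n");
print(bSG_sD);
printf("\n");
print(gD_epi);
printf("\n");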
tensor_3d
smem_ptr[16b](0x7f8f00000c00) o Sw<3,3,3> o _0 o (((_64,_256),_2)):(((_1,_64),_16384))
tensor_2d
smem_ptr[16b](0x7f8f00000c00) o (64,256):(_1,64)
smem_ptr[16b](0x7f8f00001380) o ((_8,_1),_1,_16,_1):((_1,_0),_0,_1024,_0)
smem_ptr[16b](0x7f8f00000c00) o Sw<3,3,3> o _0 o (((_64,_256),_2),_1,_1,_1):(((_1,_64),_16384),_0,_0,_0)
I noticed the "Sw<3,3,3>" in these layouts. If I peel the tensor apart twice:
auto tensor_3d = bSG_sD(_, 0, 0, 0);
auto tensor_2d = size<0>(tensor_3d);
print(tensor_2d);
I get a value such as _32768 instead of a tensor. Why?
Put differently: what type does a slicing operation such as (_, 0, 0, 0) return, and which types can it be applied to?
Likewise, what type does make_tensor return, and which arguments does it accept?