Why change the order of make_block_ptr when V.dtype.element_ty == tl.float8e5?
In the fused attention tutorial, there is this line:
v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0)
I can't quite figure out why the order depends on the element type. The tutorial didn't give an example input of type float8e5.
According to this blog (https://mengyibai.com/p/order-in-triton-make-block-ptr/), the order only helps the compiler generate more efficient code, and it is equal to
np.argsort(strides)
The blog itself writes np.argsort(-strides), which seems to be a typo.
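To make the argsort rule concrete, here is a minimal sketch in plain Python. The helper name `block_ptr_order` and the shapes are made up for illustration; it only assumes, as the blog describes, that `order` lists the dimensions from smallest stride (fastest-varying) to largest, which is what np.argsort(strides) returns.

```python
def block_ptr_order(strides):
    """Hypothetical helper: dimension indices sorted by ascending stride,
    i.e. the pure-Python equivalent of np.argsort(strides)."""
    return tuple(sorted(range(len(strides)), key=lambda d: strides[d]))


rows, cols = 128, 64  # illustrative 2-D shape, not taken from the tutorial

# Row-major: stepping along dim 1 moves one element, so dim 1 is fastest.
row_major_strides = (cols, 1)
# Column-major: stepping along dim 0 moves one element, so dim 0 is fastest.
col_major_strides = (1, rows)

row_major_order = block_ptr_order(row_major_strides)
col_major_order = block_ptr_order(col_major_strides)

# Negating the strides flips the result for the row-major case, which is
# why np.argsort(-strides) looks like a typo.
negated_order = block_ptr_order(tuple(-s for s in row_major_strides))

print(row_major_order, col_major_order, negated_order)
```

Under that assumption, a row-major tensor gets order (1, 0) and a column-major one gets (0, 1), which are exactly the two values v_order can take in the tutorial line.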
Hi, I still don't get it. Why can changing the order be more efficient?
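The V matrix is used as the b argument to a wgmma instruction. wgmma allows fp16 inputs to be in either row-major or column-major format, but for FP8 types the a matrix must be row-major and the b matrix must be column-major. So for float8e5 inputs V has to be laid out column-major, and the order passed to make_block_ptr is (0, 1) to match that layout instead of the usual (1, 0).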
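A rough sketch of how the layout requirement turns into the two order values, in plain Python. This is an illustration under stated assumptions, not the tutorial's code: `contiguous_strides` and `v_layout` are hypothetical helpers, and it assumes the FP8 V is stored transposed in memory (contiguous as (head_dim, n_ctx)) and then viewed as (n_ctx, head_dim), which is one way to get a column-major b operand.

```python
def contiguous_strides(shape):
    """Element strides of a densely packed row-major array of this shape."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def v_layout(n_ctx, head_dim, is_fp8):
    """Hypothetical helper: strides and block-pointer order of a
    (n_ctx, head_dim) view of V."""
    if is_fp8:
        # Assumption: V is stored transposed, then viewed back as
        # (n_ctx, head_dim), so the view is column-major.
        stored = contiguous_strides((head_dim, n_ctx))
        strides = (stored[1], stored[0])
    else:
        # Plain row-major storage.
        strides = contiguous_strides((n_ctx, head_dim))
    # Dimensions sorted from smallest stride (fastest-varying) to largest.
    order = tuple(sorted(range(2), key=lambda d: strides[d]))
    return strides, order


fp16_strides, fp16_order = v_layout(n_ctx=128, head_dim=64, is_fp8=False)
fp8_strides, fp8_order = v_layout(n_ctx=128, head_dim=64, is_fp8=True)
print(fp16_strides, fp16_order)
print(fp8_strides, fp8_order)
```

In this sketch the order is derived from the strides rather than from the dtype; the tutorial line keys on the dtype, presumably because the FP8 path is the only one where V arrives in the transposed layout.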