
Enhance PartitionedBlockwiseReduction interface to allow more diverse reduction use cases

Open · rosenrodt opened this issue on Aug 16 '22 · 1 comment

In the fused attention kernel implementation, we encountered an "M0_K_M1 reduce K" problem that the original PartitionedBlockwiseReduction does not quite capture. To that end we introduced an ad-hoc v2 version that allows a user-defined thread cluster descriptor, which chains the 1D thread id to M0_K_M1 and then to M_K, so that it in effect reduces the middle dimension K of M0_K_M1.

This approach is cumbersome for the user of PartitionedBlockwiseReduction, who is responsible for constructing the tensor adaptor that chains the thread_id -> M0_K_M1 -> M_K transformation (fused attention kernel L657-L669). The aim is to enhance the interface so that it can handle arbitrary reduction dimensions with an arbitrary thread cluster.

After a brief discussion with @qianfengz, the new interface should take the template arguments Sequence<> ThreadCluster, Sequence<> ReduceOrder, and index_t ReduceLastNDim, where the last N dimensions of ReduceOrder are the ones to be reduced.

For example, the "M0_K_M1 reduce K" problem can be captured by ThreadCluster = Sequence<M0, K, M1>, ReduceOrder = Sequence<0, 2, 1>, ReduceLastNDim = 1. "M0_K_M1 reduce M1" is the same except that ReduceOrder = Sequence<0, 1, 2>. A compilable sketch of these instantiations follows.
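For concreteness, here is a minimal compilable sketch of the proposed parameter list. The Sequence stand-in, the V3 name, and the static_asserts are illustrative assumptions for this issue, not actual composable_kernel definitions:

#include <cstddef>

// Minimal stand-in for CK's Sequence, only to make the argument list concrete.
template <std::size_t... Is>
struct Sequence
{
    static constexpr std::size_t Size() { return sizeof...(Is); }
};

// Hypothetical shape of the enhanced interface (implementation omitted).
template <typename ThreadCluster, typename ReduceOrder, std::size_t ReduceLastNDim>
struct PartitionedBlockwiseReductionV3
{
    static_assert(ThreadCluster::Size() == ReduceOrder::Size(),
                  "ReduceOrder must permute all thread-cluster dimensions");
    static_assert(ReduceLastNDim <= ReduceOrder::Size(),
                  "cannot reduce more dimensions than the cluster has");
};

// "M0_K_M1 reduce K" with an M0 x K x M1 = 2 x 4 x 8 thread cluster:
using ReduceK  = PartitionedBlockwiseReductionV3<Sequence<2, 4, 8>, Sequence<0, 2, 1>, 1>;
// "M0_K_M1 reduce M1": same cluster, ReduceOrder = Sequence<0, 1, 2>:
using ReduceM1 = PartitionedBlockwiseReductionV3<Sequence<2, 4, 8>, Sequence<0, 1, 2>, 1>;

int main() {} // nothing to run; the checks above are compile-time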

rosenrodt — Aug 16 '22 11:08

The above is the description of the enhanced interface. To ease the implementation of this more generic interface, a mapping function needs to be provided that maps the local thread id to the 2D id (thread_cluster_m_id, thread_cluster_k_id), so that the threads in the block can exchange data through the 2D LDS buffer. The mapping looks like the following:

template <typename ThreadClusterLengths, typename ThreadClusterArrangeOrder, index_t ReduceLastNDim>
Tuple<index_t, index_t> get_thread_m_k_id(index_t thread_local_1d_id)
{
   // xxxx (decompose thread_local_1d_id over the arranged cluster lengths,
   // then split the multi-index into flattened m and k ids; see sketch below)
}
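As a point of reference, here is a standalone host-side sketch of what that body might compute. It assumes the usual cluster-descriptor semantics (lengths reordered by the arrange order, 1D id decomposed row-major with the last arranged dimension varying fastest); the function signature and names here are illustrative, not CK code:

#include <array>
#include <cstdio>
#include <utility>

// Standalone illustration (not CK code): decompose the 1D thread id over the
// arranged cluster lengths, then flatten the leading dims into m_id and the
// trailing reduce_last_n_dim dims into k_id.
template <std::size_t NDim>
std::pair<int, int> get_thread_m_k_id(int thread_local_1d_id,
                                      const std::array<int, NDim>& cluster_lengths,
                                      const std::array<int, NDim>& arrange_order,
                                      std::size_t reduce_last_n_dim)
{
    // reorder lengths so the dimensions to be reduced come last
    std::array<int, NDim> arranged{};
    for(std::size_t i = 0; i < NDim; ++i)
        arranged[i] = cluster_lengths[arrange_order[i]];

    // decompose the 1D id row-major, last arranged dimension varying fastest
    std::array<int, NDim> idx{};
    for(std::size_t i = NDim; i-- > 0;)
    {
        idx[i] = thread_local_1d_id % arranged[i];
        thread_local_1d_id /= arranged[i];
    }

    // flatten kept dims into m_id, reduce dims into k_id
    const std::size_t split = NDim - reduce_last_n_dim;
    int m_id = 0, k_id = 0;
    for(std::size_t i = 0; i < split; ++i)
        m_id = m_id * arranged[i] + idx[i];
    for(std::size_t i = split; i < NDim; ++i)
        k_id = k_id * arranged[i] + idx[i];
    return {m_id, k_id};
}

int main()
{
    // "M0_K_M1 reduce K": cluster 2 x 4 x 8 (M0, K, M1), ReduceOrder {0, 2, 1}
    const auto [m, k] = get_thread_m_k_id<3>(37, {2, 4, 8}, {0, 2, 1}, 1);
    std::printf("m_id = %d, k_id = %d\n", m, k); // prints m_id = 9, k_id = 1
}

With ReduceOrder = Sequence<0, 2, 1>, threads that differ only in their K index collapse onto the same thread_cluster_m_id by construction, which is exactly what the 2D LDS exchange needs.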

qianfengz — Aug 16 '22 12:08