Unified attention CK Tile kernel
Authors: @Chi-Chu319 @juuso-oskari
This PR implements a unified attention kernel written in CK Tile. It builds on top of fmha_v3 (composable_kernel/example/ck_tile/01_fmha), with the pipeline remaining largely the same. This PR implements the following features introduced in the Triton unified attention kernel:
Reduced launch grid size, implemented in composable_kernel/example/ck_tile/01_unified_attention/unified_attention_impl.hpp:
```cpp
// args.num_tokens is the cumulative number of tokens across all sequences
index_t total_num_q_blocks = args.num_tokens / BLOCK_Q + args.num_seqs;
dim3 grids = Kernel::GridSize2D(args.num_kv_heads, total_num_q_blocks);

return launch_kernel(config, make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
```
This launches significantly fewer programs than the previous grid = (num_seqs, max_seqlen // BLOCK_M, num_q_heads), which contained many empty programs (not all sequences are of length max_seqlen).
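For a rough sense of the reduction, here is a small host-side sketch. The batch parameters (sequence lengths, head counts, block sizes) are hypothetical and only illustrative, not taken from this PR:

```cpp
// Hypothetical example: sequence lengths {1, 7, 300}, BLOCK_Q = BLOCK_M = 16,
// 32 query heads, 8 KV heads. Compare old and new program counts.
#include <cstdio>

int main()
{
    const int seq_lens[]  = {1, 7, 300};
    const int num_seqs    = 3;
    const int BLOCK_Q     = 16, BLOCK_M = 16;
    const int num_q_heads = 32, num_kv_heads = 8;

    int num_tokens = 0, max_seqlen = 0;
    for(int len : seq_lens)
    {
        num_tokens += len;
        max_seqlen = len > max_seqlen ? len : max_seqlen;
    }

    // old grid: (num_seqs, ceil(max_seqlen / BLOCK_M), num_q_heads)
    const int old_programs = num_seqs * ((max_seqlen + BLOCK_M - 1) / BLOCK_M) * num_q_heads;
    // new grid: (num_kv_heads, num_tokens / BLOCK_Q + num_seqs)
    const int new_programs = num_kv_heads * (num_tokens / BLOCK_Q + num_seqs);

    std::printf("old grid: %d programs, new grid: %d programs\n", old_programs, new_programs);
    // prints: old grid: 1824 programs, new grid: 176 programs
    return 0;
}
```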
However, since the current sequence index can no longer be derived from the program id, we do a binary search at the beginning of the kernel to find our sequence index (used to look up the sequence length, which determines the inner-loop length).
This is implemented at composable_kernel/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp:
```cpp
// Binary search to find the sequence index for a given global index
CK_TILE_DEVICE static constexpr ck_tile::index_t
find_seq_idx(const int32_t* query_start_len_ptr,
             ck_tile::index_t target_idx,
             ck_tile::index_t num_seqs,
             ck_tile::index_t block_q,
             bool use_q_block_mode)
{
    ck_tile::index_t left  = 0;
    ck_tile::index_t right = num_seqs;
    while(left < right)
    {
        ck_tile::index_t mid = (left + right) / 2;
        ck_tile::index_t val = query_start_len_ptr[mid];
        // In q-block mode, val / block_q + mid is the first global q-block index owned by
        // sequence mid; otherwise compare against the cumulative token start directly.
        ck_tile::index_t mid_val = use_q_block_mode ? (val / block_q + mid) : val;
        if(mid_val <= target_idx)
        {
            left = mid + 1;
        }
        else
        {
            right = mid;
        }
    }
    return left - 1;
}
```
```cpp
// usage inside the kernel
const auto [kv_head_idx, q_block_global_idx] = GetTileIndex(pid, kargs);

// grid size is (num_kv_heads, total_num_q_blocks)
// total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs
// q.shape[0] is the total number of query tokens across all batches
const index_t seq_idx = find_seq_idx(
    kargs.query_start_len_ptr, q_block_global_idx, kargs.num_seqs, BLOCK_Q, true); // which seq am I
```
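As a sanity check of this mapping, here is a minimal host-side sketch that restates the same binary search in plain C++ and verifies which sequence owns a few global q-block indices. The sequence lengths and the helper name find_seq_idx_host are hypothetical and not part of the PR:

```cpp
// Host-side restatement of find_seq_idx (same logic as the device code above).
#include <cassert>
#include <cstdint>

static int32_t find_seq_idx_host(const int32_t* query_start_len_ptr,
                                 int32_t target_idx,
                                 int32_t num_seqs,
                                 int32_t block_q,
                                 bool use_q_block_mode)
{
    int32_t left = 0, right = num_seqs;
    while(left < right)
    {
        const int32_t mid     = (left + right) / 2;
        const int32_t val     = query_start_len_ptr[mid];
        const int32_t mid_val = use_q_block_mode ? (val / block_q + mid) : val;
        if(mid_val <= target_idx) { left = mid + 1; }
        else                      { right = mid; }
    }
    return left - 1;
}

int main()
{
    // Cumulative query starts for hypothetical sequence lengths {1, 7, 300}, BLOCK_Q = 16.
    const int32_t cu_q[] = {0, 1, 8, 308};
    // Sequence 0 owns global q-block 0, sequence 1 owns block 1,
    // sequence 2 owns blocks 2..21 (total_num_q_blocks = 308 / 16 + 3 = 22).
    assert(find_seq_idx_host(cu_q, 0, 3, 16, true) == 0);
    assert(find_seq_idx_host(cu_q, 1, 3, 16, true) == 1);
    assert(find_seq_idx_host(cu_q, 21, 3, 16, true) == 2);
    return 0;
}
```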
In order to process more query tokens per load in decode settings (where the query length is small, often only 1), we group query tokens along the head dimension: in the GQA setting, up to num_queries_per_kv query heads share the same key/value head, so their query rows can be packed into the same tile. The total number of grouped rows for a tile load is BLOCK_M = BLOCK_Q * num_queries_per_kv.
We do this in the kernel implementation by transforming the tensor view for Q in DRAM:
```cpp
const auto q_dram = [&]() {
    const auto q_dram_base = make_naive_tensor_view<address_space_enum::global>(
        q_ptr,
        make_tuple(cur_batch_query_len, num_queries_per_kv, HEAD_SIZE),
        make_tuple(kargs.query_stride_0, kargs.query_stride_1, 1),
        number<UnifiedAttentionPipeline::kAlignmentQ>{},
        number<1>{});
    // align seqlen with BLOCK_Q and head dim with HEAD_SIZE_PADDED
    const auto q_dram_pad = pad_tensor_view(
        q_dram_base,
        // block sizes
        make_tuple(BLOCK_Q, 1, HEAD_SIZE_PADDED),
        sequence<true, false, kPadHeadDimQ>{}); // pads to (query_len_padded, num_queries_per_kv, HEAD_SIZE_PADDED)
    const auto q_dram_merged = transform_tensor_view(
        q_dram_pad,
        make_tuple(make_merge_transform(make_tuple(query_len_padded, num_queries_per_kv)),
                   make_pass_through_transform(HEAD_SIZE_PADDED)),
        make_tuple(sequence<0, 1>{}, sequence<2>{}),
        make_tuple(sequence<0>{}, sequence<1>{}));
    // flattens the first two dims; head idx is the fastest-changing component of the merged dim
    return q_dram_merged;
}();
```
This way, the pipeline can remain untouched and use BLOCK_M as its tile size.
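To illustrate the row ordering produced by the merge, here is a small host-side sketch (hypothetical values, not kernel code): each row m of the merged (BLOCK_M, HEAD_SIZE_PADDED) tile decomposes into a query token and a query head within the current KV-head group, with the head index varying fastest.

```cpp
// Hypothetical sketch of the merged row layout: the head index is the fastest-changing
// component, so consecutive rows cover all query heads that share one KV head before
// moving on to the next query token.
#include <cstdio>

int main()
{
    const int num_queries_per_kv = 4;                            // e.g. 32 query heads / 8 KV heads
    const int BLOCK_Q            = 4;                            // query tokens per tile
    const int BLOCK_M            = BLOCK_Q * num_queries_per_kv; // merged tile rows

    for(int m = 0; m < BLOCK_M; ++m)
    {
        const int token_idx     = m / num_queries_per_kv; // slower-changing: query token
        const int head_in_group = m % num_queries_per_kv; // faster-changing: query head in the KV group
        std::printf("row %2d -> token %d, head %d\n", m, token_idx, head_in_group);
    }
    return 0;
}
```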
TODO
- Fix build error:
  ```
  static assertion failed due to requirement 'const ck_tile::sequence<0, 8, 8>{}[ck_tile::constant<1>{}] == 0': TensorLengths{}[number<1>{}] == 0
  ```
  originating from `q_dram_window = make_tile_window_linear(` in `rocm/composable_kernel/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp`
- Fix testing and benchmarking
- Performance tuning (tile distribution, block sizes)
Build
```bash
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
../script/cmake-ck-dev.sh ../ <arch>
make tile_example_unified_attention -j1
```