
Unified attention CK Tile kernel

Open · juuso-oskari opened this issue 1 month ago

Authors: @Chi-Chu319 @juuso-oskari

This PR implements a unified attention kernel written in CK Tile. It builds on top of fmha_v3 (composable_kernel/example/ck_tile/01_fmha), and the pipeline remains largely the same. It implements the following features introduced in the Triton unified attention kernel:

Reduced launch grid size, implemented in composable_kernel/example/ck_tile/01_unified_attention/unified_attention_impl.hpp:

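// args.num_tokens is the cumulative number of tokens across all sequences.
// Each sequence needs ceil(len / BLOCK_Q) q-blocks; summed over sequences this is at most
// num_tokens / BLOCK_Q + num_seqs, so total_num_q_blocks is an upper bound on the blocks needed.
index_t total_num_q_blocks = args.num_tokens / BLOCK_Q + args.num_seqs;
dim3 grids                 = Kernel::GridSize2D(args.num_kv_heads, total_num_q_blocks);
return launch_kernel(config, make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));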

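This launches significantly fewer programs than the previous grid, grid=(num_seqs, max_seqlen // BLOCK_M, num_q_heads), which contained many empty programs because not all sequences have length max_seqlen. The head axis of the grid also shrinks from num_q_heads to num_kv_heads, because the query heads that share a KV head are handled by the same program (the grouping is described below).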
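To make the saving concrete, here is a small host-side sketch that counts programs for both grids. It is illustrative only: `count_programs` and `GridCounts` are hypothetical names, and it assumes the old launcher rounded max_seqlen / BLOCK_M up and that both grids use tiles of BLOCK_M rows (BLOCK_M query tokens of one head before, BLOCK_Q tokens times num_queries_per_kv heads now).

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical host-side helper (not part of this PR): counts programs launched
// by the old 3D grid and by the new 2D grid for one batch of query lengths.
struct GridCounts {
    int64_t old_programs;   // num_seqs * ceil(max_seqlen / BLOCK_M) * num_q_heads
    int64_t new_programs;   // num_kv_heads * (num_tokens / BLOCK_Q + num_seqs)
    int64_t needed_blocks;  // sum over sequences of ceil(len / BLOCK_Q), per KV head
};

GridCounts count_programs(const std::vector<int64_t>& seq_lens, int64_t block_q,
                          int64_t num_queries_per_kv, int64_t num_kv_heads) {
    assert(block_q > 0 && num_queries_per_kv > 0 && num_kv_heads > 0);
    const int64_t block_m = block_q * num_queries_per_kv;  // BLOCK_M
    const int64_t num_seqs = static_cast<int64_t>(seq_lens.size());
    const int64_t num_q_heads = num_kv_heads * num_queries_per_kv;

    int64_t num_tokens = 0, max_seqlen = 0, needed = 0;
    for (int64_t len : seq_lens) {
        num_tokens += len;
        max_seqlen = std::max(max_seqlen, len);
        needed += (len + block_q - 1) / block_q;  // ceil(len / BLOCK_Q)
    }

    GridCounts c{};
    c.old_programs = num_seqs * ((max_seqlen + block_m - 1) / block_m) * num_q_heads;
    c.new_programs = num_kv_heads * (num_tokens / block_q + num_seqs);
    c.needed_blocks = needed;
    return c;
}
```

For a mixed batch of three decode sequences (length 1) and one 500-token prefill, with BLOCK_Q = 16, num_queries_per_kv = 8 and 8 KV heads, `count_programs({1, 1, 1, 500}, 16, 8, 8)` reports 1024 programs for the old grid and 280 for the new one, which here equals 8 KV heads times the 35 q-blocks actually needed.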

However, because the sequence index can no longer be taken from the program ID, the kernel performs a binary search at the start to find its sequence index (used to look up the sequence length, which determines the inner-loop length).

This is implemented in composable_kernel/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp:

// Binary search to find the sequence index for a given global index.
// query_start_len_ptr holds num_seqs + 1 cumulative offsets (entry i is the first token of
// sequence i). In q-block mode, sequence i owns the global q-blocks starting at
// query_start_len_ptr[i] / block_q + i, and the search compares against that value.
CK_TILE_DEVICE static constexpr ck_tile::index_t
find_seq_idx(const int32_t* query_start_len_ptr,
             ck_tile::index_t target_idx,
             ck_tile::index_t num_seqs,
             ck_tile::index_t block_q,
             bool use_q_block_mode)
{
    ck_tile::index_t left  = 0;
    ck_tile::index_t right = num_seqs;
    while(left < right)
    {
        ck_tile::index_t mid     = (left + right) / 2;
        ck_tile::index_t val     = query_start_len_ptr[mid];
        ck_tile::index_t mid_val = use_q_block_mode ? (val / block_q + mid) : val;

        if(mid_val <= target_idx)
        {
            left = mid + 1;
        }
        else
        {
            right = mid;
        }
    }
    // Last sequence whose start is <= target_idx.
    return left - 1;
}

// Usage inside the kernel.
// The grid is (num_kv_heads, total_num_q_blocks), where
// total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs and
// q.shape[0] is the total number of query tokens across all sequences.
const auto [kv_head_idx, q_block_global_idx] = GetTileIndex(pid, kargs);
// Find which sequence this q-block belongs to.
const index_t seq_idx = find_seq_idx(
    kargs.query_start_len_ptr, q_block_global_idx, kargs.num_seqs, BLOCK_Q, true);
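The search can be checked on the host with a plain C++ port. This is a sketch, not the kernel code: `std::vector` stands in for the device pointer, and `locate_q_block`/`QBlock` are hypothetical helpers. The step after the search (computing the local block index and skipping blocks that start past the end of the sequence) is an assumption modeled on the Triton kernel, not something shown in this PR.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Host-side port of find_seq_idx, for illustration only. query_start_len holds
// num_seqs + 1 cumulative offsets: [0, len_0, len_0 + len_1, ..., num_tokens].
int32_t find_seq_idx(const std::vector<int32_t>& query_start_len, int32_t target_idx,
                     int32_t num_seqs, int32_t block_q, bool use_q_block_mode) {
    int32_t left = 0;
    int32_t right = num_seqs;
    while (left < right) {
        const int32_t mid = (left + right) / 2;
        const int32_t val = query_start_len[mid];
        // In q-block mode, sequence `mid` owns global q-blocks starting at val / block_q + mid.
        const int32_t mid_val = use_q_block_mode ? (val / block_q + mid) : val;
        if (mid_val <= target_idx) left = mid + 1;
        else right = mid;
    }
    return left - 1;  // last sequence whose start is <= target_idx
}

struct QBlock {
    int32_t seq_idx;
    int32_t local_idx;  // q-block index within the sequence
    bool empty;         // block starts past the sequence's last query token
};

// Hypothetical helper: what we assume the kernel does right after the search.
QBlock locate_q_block(const std::vector<int32_t>& query_start_len,
                      int32_t q_block_global_idx, int32_t block_q) {
    const int32_t num_seqs = static_cast<int32_t>(query_start_len.size()) - 1;
    assert(num_seqs > 0 && block_q > 0);
    const int32_t seq =
        find_seq_idx(query_start_len, q_block_global_idx, num_seqs, block_q, true);
    const int32_t first_block = query_start_len[seq] / block_q + seq;
    const int32_t local = q_block_global_idx - first_block;
    const int32_t query_len = query_start_len[seq + 1] - query_start_len[seq];
    return {seq, local, local * block_q >= query_len};
}
```

With sequence lengths 17 and 15 (offsets {0, 17, 32}) and BLOCK_Q = 16, the grid has 32 / 16 + 2 = 4 q-blocks: blocks 0 and 1 map to sequence 0, blocks 2 and 3 to sequence 1, and block 3 is empty because sequence 1 needs only one block. This is the slack introduced by the `+ num_seqs` upper bound.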

To process more query rows per tile load in decode settings (where the query length is small, often only 1), we group query heads into the token dimension. In the GQA setting, num_queries_per_kv query heads share the same key/value head, so for a given query token all of these heads can be processed against the same K/V tiles. The total number of rows in a tile load is therefore BLOCK_M = BLOCK_Q * num_queries_per_kv.
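A rough way to see why this helps in decode is to count how many rows of a BLOCK_M tile hold real data. The sketch below is hypothetical (`valid_rows_first_tile` is not in the PR) and assumes the ungrouped baseline fills a tile with up to BLOCK_M tokens of a single head, as in fmha_v3.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical illustration: valid (non-padding) rows in the first BLOCK_M-row tile of
// a sequence with `query_len` tokens, with or without stacking the num_queries_per_kv
// query heads that share a KV head into the row dimension.
int64_t valid_rows_first_tile(int64_t query_len, int64_t block_q,
                              int64_t num_queries_per_kv, bool grouped) {
    assert(block_q > 0 && num_queries_per_kv > 0 && query_len >= 0);
    const int64_t block_m = block_q * num_queries_per_kv;  // BLOCK_M
    if (grouped) {
        // Tile covers up to block_q tokens x num_queries_per_kv heads.
        const int64_t tokens = query_len < block_q ? query_len : block_q;
        return tokens * num_queries_per_kv;
    }
    // Ungrouped baseline (assumed): up to block_m tokens of one head.
    return query_len < block_m ? query_len : block_m;
}
```

For a decode step (query_len = 1) with BLOCK_Q = 16 and num_queries_per_kv = 8, the grouped tile has 8 useful rows out of 128, compared with 1 out of 128 without grouping.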

We do this in the kernel implementation by transforming the Q tensor view in DRAM:

const auto q_dram = [&]() {
    // View of the current sequence's queries: (tokens, query heads per KV head, head dim)
    const auto q_dram_base = make_naive_tensor_view<address_space_enum::global>(
        q_ptr,
        make_tuple(cur_batch_query_len, num_queries_per_kv, HEAD_SIZE),
        make_tuple(kargs.query_stride_0, kargs.query_stride_1, 1),
        number<UnifiedAttentionPipeline::kAlignmentQ>{},
        number<1>{});

    // Align seqlen to BLOCK_Q and head dim to HEAD_SIZE_PADDED
    const auto q_dram_pad = pad_tensor_view(
        q_dram_base,
        // block sizes
        make_tuple(BLOCK_Q, 1, HEAD_SIZE_PADDED),
        sequence<true, false, kPadHeadDimQ>{}
    ); // pads to (query_len_padded, num_queries_per_kv, HEAD_SIZE_PADDED)

    // Flatten the first two dims; the head index is the fastest-changing index in the merged dim.
    // query_len_padded (not shown) is expected to be cur_batch_query_len rounded up to a multiple of BLOCK_Q.
    const auto q_dram_merged = transform_tensor_view(
        q_dram_pad,
        make_tuple(make_merge_transform(make_tuple(query_len_padded, num_queries_per_kv)),
                   make_pass_through_transform(HEAD_SIZE_PADDED)),
        make_tuple(sequence<0, 1>{}, sequence<2>{}),
        make_tuple(sequence<0>{}, sequence<1>{}));
    return q_dram_merged;
}();

This way, the pipeline can remain untouched and use BLOCK_M as its tile size.
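The coordinate mapping performed by the merged view can be modelled on the host. This is a sketch under stated assumptions: `map_merged_q` and `QElem` are hypothetical, and it assumes the element offset is measured from `q_ptr` using `query_stride_0` (token stride) and `query_stride_1` (head stride) as in the view above. Row `r` of the merged view is token `r / num_queries_per_kv` and head `r % num_queries_per_kv`, because the merge transform makes the last merged dimension the fastest-changing one.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical host-side model of the merged Q view: maps (row, col) of the 2D
// [query_len_padded * num_queries_per_kv, HEAD_SIZE_PADDED] view back to the
// 3D layout [cur_batch_query_len, num_queries_per_kv, HEAD_SIZE].
struct QElem {
    int64_t token;    // query token within the current sequence
    int64_t head;     // query head within the KV-head group
    int64_t offset;   // element offset from q_ptr (only meaningful if !is_padding)
    bool is_padding;  // row or column added by pad_tensor_view
};

QElem map_merged_q(int64_t row, int64_t col, int64_t cur_batch_query_len,
                   int64_t num_queries_per_kv, int64_t head_size,
                   int64_t query_stride_0, int64_t query_stride_1) {
    assert(num_queries_per_kv > 0);
    const int64_t token = row / num_queries_per_kv;  // slower-changing merged index
    const int64_t head = row % num_queries_per_kv;   // fastest-changing merged index
    const bool pad = token >= cur_batch_query_len || col >= head_size;
    return {token, head, token * query_stride_0 + head * query_stride_1 + col, pad};
}
```

For example, with 3 query tokens, num_queries_per_kv = 4, HEAD_SIZE = 64 and a contiguous layout of 32 query heads (query_stride_0 = 2048, query_stride_1 = 64), row 5, column 10 maps to token 1, head 1, offset 2122, while row 12 or column 64 falls into the padded region.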

TODO

  • Fix build error:

    static assertion failed due to requirement 'const ck_tile::sequence<0, 8, 8>{}[ck_tile::constant<1>{}] == 0': TensorLengths{}[number<1>{}] == 0

    originating from

    q_dram_window = make_tile_window_linear(

    in `rocm/composable_kernel/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp`

  • Fix testing and benchmarking
  • Performance tuning (tile distribution, block sizes)

Build

# in the root of the composable_kernel repository
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
../script/cmake-ck-dev.sh  ../ <arch>
make tile_example_unified_attention -j1

juuso-oskari · Oct 30 '25 12:10