
Improve performance for long sequence generation (kvconcat kernel).

Open guoqingbao opened this issue 1 year ago • 2 comments

The current implementation of key-value concatenation relies on Tensor::cat, which uses the ucopy kernel to compute a strided source index for every output tensor element from the dims and strides (see below), and it also requires a double transpose. This PR implements a kvconcat kernel that bypasses the per-element index computation and therefore speeds up generation by around 10%-15%, especially for long sequences.

// Maps a flat output index to the corresponding offset in a strided input.
__device__ unsigned int get_strided_index(
    unsigned int idx,
    const size_t num_dims,
    const size_t *dims,
    const size_t *strides
) {
    unsigned int strided_i = 0;
    for (unsigned int d = 0; d < num_dims; d++) {
        unsigned int dim_idx = num_dims - 1 - d;
        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
        idx /= dims[dim_idx];
    }
    return strided_i;
}

// TYPENAME and FUNC are macro parameters filled in when each copy/unary kernel is instantiated.
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
    unsigned int strided_i = get_strided_index(i, num_dims, dims, strides);
    TYPENAME x = inp ? inp[strided_i] : out[i];
    out[i] = FUNC;
}

The current implementation: an element-wise copy that recomputes a strided index for every output element.
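To make that per-element overhead concrete, below is a small host-side worked example of the index mapping above; the shape and strides are made up for illustration and are not taken from any model.

#include <cstddef>
#include <cstdio>

// Host-side replica of get_strided_index, hand-checked on a small example:
// a (2, 3, 4) view obtained by transposing a contiguous (2, 4, 3) buffer,
// i.e. dims = {2, 3, 4} and strides = {12, 1, 3} (in elements).
static unsigned int strided_index(unsigned int idx, size_t num_dims,
                                  const size_t *dims, const size_t *strides) {
    unsigned int strided_i = 0;
    for (unsigned int d = 0; d < num_dims; d++) {
        unsigned int dim_idx = num_dims - 1 - d;
        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
        idx /= dims[dim_idx];
    }
    return strided_i;
}

int main() {
    const size_t dims[3] = {2, 3, 4};
    const size_t strides[3] = {12, 1, 3};
    // Flat output index 5 is logical coordinate (0, 1, 1), stored at
    // 0*12 + 1*1 + 1*3 = 4 in the strided input: three divisions and three
    // modulos per element, repeated for every element the concat writes.
    printf("%u\n", strided_index(5, 3, dims, strides)); // prints 4
    return 0;
}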

// chunk_l equals chunk_r for dim2 concat; lstride and rstride are the
// per-chunk element counts of the left and right inputs.
template <typename T>
__device__ __forceinline__ void kvconcat_dim2_kernel(T *ltensor, T* rtensor, T *out,
    const size_t chunk_l, const size_t chunk_r, const size_t lstride, const size_t rstride) {
    int thread_id = GetThreadIdx();
    int out_stride = lstride + rstride;
    int idx = thread_id / out_stride;
    int j = thread_id % out_stride;
    T* pLeft = ltensor + idx * lstride;
    T* pRight = rtensor + idx * rstride;
    T* pOut = out + idx * out_stride;
    if (idx < chunk_l) {
        if (j < lstride)
            pOut[j] = pLeft[j];
        else
            pOut[j] = pRight[j - lstride];
    }
}
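
For completeness, a launch of this kernel could look roughly like the sketch below. It assumes GetThreadIdx() expands to the usual blockIdx.x * blockDim.x + threadIdx.x global index and that both inputs are contiguous; the __global__ wrapper, launcher name, and block size are illustrative rather than part of this PR.

#include <cuda_runtime.h>

// Thin __global__ entry point around the __device__ helper above (assumed).
template <typename T>
__global__ void kvconcat_dim2(T *ltensor, T *rtensor, T *out,
                              const size_t chunk_l, const size_t chunk_r,
                              const size_t lstride, const size_t rstride) {
    kvconcat_dim2_kernel<T>(ltensor, rtensor, out, chunk_l, chunk_r, lstride, rstride);
}

// Concatenate cached keys (b, h, seq_l, d) and new keys (b, h, seq_r, d) along
// dim 2, flattened to chunk = b * h rows of lstride = seq_l * d and
// rstride = seq_r * d contiguous elements respectively.
template <typename T>
void launch_kvconcat_dim2(T *ltensor, T *rtensor, T *out, size_t chunk,
                          size_t lstride, size_t rstride, cudaStream_t stream) {
    // One thread per output element; the only per-element index arithmetic is
    // one division and one modulo to split the flat thread id into (row, col).
    size_t numel = chunk * (lstride + rstride);
    unsigned int block = 256; // illustrative block size
    unsigned int grid = (unsigned int)((numel + block - 1) / block);
    kvconcat_dim2<T><<<grid, block, 0, stream>>>(
        ltensor, rtensor, out, chunk, chunk, lstride, rstride);
}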

guoqingbao · Mar 14 '24 11:03

Thanks for suggesting this. I agree that there is currently an issue with how long the kv cache takes, so I've made #1855 to remove the transpose bit when doing Tensor::cat; this way it will benefit all operations rather than just the kv-cache. It can probably still be improved a bit, e.g. by avoiding the remaining modulo or using cudaMemcpy2D, but it should already bring quite a few benefits without requiring any model change. There is probably still room for a highly optimized kv-cache kernel, but for that I would suggest making an external crate to handle it and experiment with it; you can already find a bunch of these, e.g. candle-rotary for fast rotary embeddings.
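
For a contiguous concat along the last (flattened) dimension, the cudaMemcpy2D route would amount to two strided 2D copies, roughly along these lines (just a sketch, not code from #1855; the wrapper name and layout assumptions are illustrative):

#include <cuda_runtime.h>

// Row i of `out` becomes [ left row i | right row i ]; one strided copy per input.
// Assumes row-major, contiguous inputs already resident on the device.
template <typename T>
cudaError_t concat_last_dim_memcpy2d(const T *ltensor, const T *rtensor, T *out,
                                     size_t chunk, size_t lstride, size_t rstride,
                                     cudaStream_t stream) {
    size_t dpitch = (lstride + rstride) * sizeof(T);   // output row pitch in bytes
    cudaError_t err = cudaMemcpy2DAsync(
        out, dpitch, ltensor, lstride * sizeof(T),     // dst pitch / src pitch
        lstride * sizeof(T), chunk,                    // row width in bytes, row count
        cudaMemcpyDeviceToDevice, stream);
    if (err != cudaSuccess) return err;
    return cudaMemcpy2DAsync(
        out + lstride, dpitch, rtensor, rstride * sizeof(T),
        rstride * sizeof(T), chunk, cudaMemcpyDeviceToDevice, stream);
}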

LaurentMazare · Mar 17 '24 07:03

Have you tested the kvconcat approach vs #1855 in terms of performance improvements? I think the kvconcat approach is simpler and more straightforward, and the changes in #1855 seem quite heavy. I'm not sure whether the copy2d kernel in #1855 can benefit operations other than Tensor::cat. The proposed kvconcat approach is very similar to candle_nn::ops::softmax_last_dim, which is widely used in candle models.

guoqingbao · Mar 18 '24 03:03