candle
Improve performance for long sequence generation (kvconcat kernel).
The current implementation of key-value concatenation relies on Tensor::cat, which uses the ucopy kernel to compute a strided index for every output element from the input shape and strides (see below), and it also requires a double transpose. This PR implements a kvconcat kernel that bypasses the index computation, which speeds up generation by around 10%-15%, especially for long sequence generation.
__device__ unsigned int get_strided_index(
    unsigned int idx,
    const size_t num_dims,
    const size_t *dims,
    const size_t *strides
) {
    unsigned int strided_i = 0;
    for (unsigned int d = 0; d < num_dims; d++) {
        unsigned int dim_idx = num_dims - 1 - d;
        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
        idx /= dims[dim_idx];
    }
    return strided_i;
}
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
    unsigned strided_i = get_strided_index(i, num_dims, dims, strides);
    TYPENAME x = inp ? inp[strided_i] : out[i];
    out[i] = FUNC;
}
The current implementation: element-wise copy with a per-element strided-index computation.
// chunk_l equals chunk_r for dim2 concat
template <typename T>
__device__ __forceinline__ void kvconcat_dim2_kernel(T *ltensor, T *rtensor, T *out,
    const size_t chunk_l, const size_t chunk_r, const size_t lstride, const size_t rstride) {
    int thread_id = GetThreadIdx();
    int out_stride = lstride + rstride;
    // Each thread copies one output element: idx selects the row of the
    // concatenated output, j is the offset within that row.
    int idx = thread_id / out_stride;
    int j = thread_id % out_stride;
    T *pLeft = ltensor + idx * lstride;
    T *pRight = rtensor + idx * rstride;
    T *pOut = out + idx * out_stride;
    if (idx < chunk_l) {
        // The first lstride elements of each row come from the left tensor,
        // the remaining rstride elements from the right tensor.
        if (j < lstride)
            pOut[j] = pLeft[j];
        else
            pOut[j] = pRight[j - lstride];
    }
}
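For context, here is a minimal sketch of how this device function could be wrapped and launched from the host, assuming GetThreadIdx() expands to the usual blockIdx.x * blockDim.x + threadIdx.x; the wrapper name, the f32 specialization, and the launch configuration are illustrative and not taken from the PR.

// Illustrative wrapper and launch only; not part of the PR.
__global__ void kvconcat_dim2_f32(float *ltensor, float *rtensor, float *out,
    const size_t chunk_l, const size_t chunk_r,
    const size_t lstride, const size_t rstride) {
    kvconcat_dim2_kernel<float>(ltensor, rtensor, out, chunk_l, chunk_r, lstride, rstride);
}

void launch_kvconcat_dim2_f32(float *ltensor, float *rtensor, float *out,
    size_t chunk_l, size_t chunk_r, size_t lstride, size_t rstride) {
    // One thread per output element: chunk_l rows of (lstride + rstride) elements.
    size_t numel = chunk_l * (lstride + rstride);
    const int threads = 256;
    const int blocks = (int)((numel + threads - 1) / threads);
    kvconcat_dim2_f32<<<blocks, threads>>>(ltensor, rtensor, out,
                                           chunk_l, chunk_r, lstride, rstride);
}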
Thanks for suggesting this. I agree that there is currently an issue with how long the kv-cache concatenation takes, so I've made #1855 to actually remove the transpose step when doing Tensor::cat; this way it will benefit all operations rather than just the kv-cache. It can probably still be improved a bit, e.g. by avoiding the remaining modulo or using cudaMemcpy2D, but it should already bring quite a few benefits without requiring any model change.
There is probably still room for highly optimized kv-cache kernels, but for that I would suggest making an external crate to handle it and experiment with it; you can already find a bunch of these, e.g. candle-rotary for fast rotary embeddings.
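As a rough illustration of the cudaMemcpy2D idea mentioned above (not how #1855 is implemented), a last-dim concatenation of two contiguous tensors could be expressed as two pitched copies; the function name, the f32 specialization, and the variable names below are assumptions for the sketch.

// Illustrative sketch: concatenate two contiguous tensors along the last
// dimension using two pitched device-to-device copies. Names are made up.
#include <cuda_runtime.h>

cudaError_t concat_last_dim_f32(const float *left, const float *right, float *out,
                                size_t num_rows, size_t lstride, size_t rstride) {
    size_t out_stride = lstride + rstride;
    // Copy the left chunk of every row: source rows are lstride floats apart,
    // destination rows are out_stride floats apart.
    cudaError_t err = cudaMemcpy2D(out, out_stride * sizeof(float),
                                   left, lstride * sizeof(float),
                                   lstride * sizeof(float), num_rows,
                                   cudaMemcpyDeviceToDevice);
    if (err != cudaSuccess) return err;
    // Copy the right chunk of every row, offset by lstride in the output.
    return cudaMemcpy2D(out + lstride, out_stride * sizeof(float),
                        right, rstride * sizeof(float),
                        rstride * sizeof(float), num_rows,
                        cudaMemcpyDeviceToDevice);
}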
Have you tested the kvconcat approach against #1855 in terms of performance improvements? I think the kvconcat approach is simpler and more straightforward, and the changes in #1855 seem quite heavy. I'm not sure whether the copy2d kernel in #1855 can benefit operations other than Tensor::cat. The proposed kvconcat approach is very similar to candle_nn::ops::softmax_last_dim, which is widely used in candle models.