
ROPE problem in MLA decode kernel

Open lsw825 opened this issue 1 year ago • 2 comments

In flashinfer.decode.BatchDecodeMlaWithPagedKVCacheWrapper, "rope_theta" and "rope_scaling" are parameters with a default value of None. However, after reading the Triton interface code, I found that rope_theta is set to 1e4 when it is None. Moreover, the scaling appears to be Llama-style linear RoPE interpolation, rather than the YaRN scheme used by the DeepSeek models.
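To illustrate the distinction, here is a minimal numpy sketch of the behavior described above. The function names are mine, not flashinfer's; the point is that linear interpolation divides positions by the scaling factor, whereas YaRN rescales the inverse frequencies band-by-band, so the two diverge at long context lengths.

```python
import numpy as np

def rope_freqs(head_dim, rope_theta=1e4):
    # Standard RoPE inverse frequencies; 1e4 mirrors the default
    # the kernel falls back to when rope_theta is None.
    return 1.0 / rope_theta ** (np.arange(0, head_dim, 2) / head_dim)

def linear_scaled_angles(positions, head_dim, rope_scaling=1.0):
    # Llama-style linear interpolation: positions are divided by the
    # scaling factor before being multiplied with the inverse
    # frequencies. YaRN instead reshapes the frequency spectrum
    # itself, which this sketch does not implement.
    freqs = rope_freqs(head_dim)
    return np.outer(np.asarray(positions, dtype=np.float64) / rope_scaling, freqs)
```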

I think it would be better if I could disable the RoPE computation in the MLA decode kernel, apply RoPE before the kernel instead, and store k_pe with RoPE already applied in the KV cache.
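A minimal numpy sketch of that workaround, assuming a standard pairwise rotation layout (the actual layout flashinfer expects may differ; `apply_rope` here is illustrative, not flashinfer's API): rotate the k_pe slice by its position before writing it into the cache, so the kernel itself never needs to touch RoPE.

```python
import numpy as np

def apply_rope(x, positions, rope_theta=1e4):
    # x: [num_tokens, rope_dim], e.g. the k_pe part of the key.
    # Rotates each pair (x[2i], x[2i+1]) by a position-dependent
    # angle; the rotated result can then be stored in the KV cache.
    d = x.shape[-1]
    freqs = 1.0 / rope_theta ** (np.arange(0, d, 2) / d)
    angles = np.outer(np.asarray(positions, dtype=np.float64), freqs)
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out
```

Any RoPE variant (including YaRN) can be substituted at this step, since the cache simply stores the already-rotated keys.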

Is there any misunderstanding in my description above? If not, is there an easy way to disable RoPE in the current kernel?

Thanks a lot.

lsw825 avatar Jan 09 '25 12:01 lsw825

Thanks for your suggestions. It would be easy to remove PE from the implementation; I'll do that later. We are working on a faster version of the MLA decoding kernels that uses tensor cores. Would you mind leaving some suggestions on the user interface?

yzh119 avatar Jan 10 '25 00:01 yzh119

Sure. I'm very interested in the MLA kernel. You can contact me via the email in my profile if needed.

If I have any other suggestions that differ from the current interface, I'll also propose them via GitHub issues.

Thx:)

lsw825 avatar Jan 10 '25 02:01 lsw825