RoPE problem in MLA decode kernel
In flashinfer.decode.BatchDecodeMlaWithPagedKVCacheWrapper, the "rope_theta" and "rope_scaling" parameters default to None. However, after reading the Triton interface code, I found that rope_theta is set to 1e4 when it is None. Moreover, the scaling appears to be Llama-style linear RoPE interpolation rather than the "yarn" scaling used in the DeepSeek models.
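To make the distinction concrete, here is a loose numpy sketch of the two scaling schemes, not DeepSeek's or flashinfer's exact formulas: linear interpolation divides every rotary frequency by the same factor, whereas YaRN-style "NTK-by-parts" leaves high-frequency dimensions untouched and only interpolates the low-frequency ones (the beta values and the attention temperature term are omitted/illustrative here):

```python
import numpy as np

def linear_scaled_inv_freq(head_dim, theta=1e4, scale=4.0):
    # Llama-style "linear" position interpolation: every frequency is
    # divided by the same factor (equivalently, positions are divided).
    base = theta ** (-np.arange(0, head_dim, 2) / head_dim)
    return base / scale

def yarn_scaled_inv_freq(head_dim, theta=1e4, scale=4.0,
                         orig_ctx=4096, beta_fast=32, beta_slow=1):
    # Simplified YaRN-style NTK-by-parts: dimensions that rotate many
    # times over the original context (high frequency) keep their base
    # frequency, dimensions that rotate less than once are fully
    # interpolated, and a linear ramp blends the region in between.
    base = theta ** (-np.arange(0, head_dim, 2) / head_dim)
    rotations = orig_ctx * base / (2 * np.pi)  # rotations per dim
    ramp = np.clip((rotations - beta_slow) / (beta_fast - beta_slow), 0.0, 1.0)
    return base / scale * (1 - ramp) + base * ramp
```

The point is that a single global `rope_scaling` divisor cannot reproduce the per-dimension behavior, which is why pre-applying RoPE outside the kernel is attractive.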
I think it would be better if I could disable the RoPE computation inside the MLA decode kernel, apply RoPE before calling the kernel, and then store k_pe with RoPE already applied in the KV cache.
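As a minimal sketch of what "apply RoPE before the kernel" could look like (plain numpy, interleaved-pair rotation; the function name and layout are illustrative, not flashinfer's API):

```python
import numpy as np

def apply_rope(x, positions, theta=1e4):
    """Apply rotary position embedding along the last dimension of x.

    x:         (num_tokens, head_dim) array, head_dim must be even.
    positions: (num_tokens,) integer token positions.
    Rotates each pair (x[2i], x[2i+1]) by angle pos * theta**(-2i/head_dim).
    """
    head_dim = x.shape[-1]
    inv_freq = theta ** (-np.arange(0, head_dim, 2) / head_dim)  # (head_dim/2,)
    angles = positions[:, None] * inv_freq[None, :]              # (tokens, head_dim/2)
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

# Hypothetical usage: rotate k_pe once at prefill/append time, write the
# rotated values into the paged KV cache, and let the decode kernel skip
# its internal RoPE step entirely.
```

Doing the rotation at cache-append time also means any scaling scheme (linear, YaRN, etc.) can be swapped in without touching the kernel.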
Is there any misunderstanding in my description above? If not, is there an easy way to disable RoPE in the current kernel?
Thanks a lot.
Thanks for your suggestions. It would be easy to remove the PE handling from the implementation, and I will do that later. We are working on a faster version of the MLA decode kernel that uses tensor cores; would you mind leaving some suggestions on the user interface?
Sure. I'm very interested in MLA kernel. You can contact me via the email in my profile if needed.
If I have any other suggestions that differ from the current interface, I'll also propose them via GitHub issues.
Thx:)