Liger-Kernel
[QnA]: Why are `cos` and `sin` expected to have size `hdim` rather than `hdim//2`?
I have a question regarding the qwen2_vl MRoPE. My understanding is as follows:
full_cos = torch.cat([cos_halfdim, cos_halfdim], dim=-1)
full_sin = torch.cat([sin_halfdim, sin_halfdim], dim=-1)
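A minimal PyTorch sketch of where these tensors come from, assuming the standard RoPE frequency schedule with base 10000 (the sizes are toy values, not Qwen2-VL's). Because `cos`/`sin` are applied elementwise, concatenating the half-size tensors gives exactly the same result as taking `cos`/`sin` of `torch.cat((freqs, freqs), dim=-1)`, and the two halves are identical copies:

```python
import torch

# Toy sizes, chosen only for illustration
head_dim = 8
seq_len = 4
base = 10000.0

# One inverse frequency per pair of channels, i.e. head_dim // 2 of them
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
positions = torch.arange(seq_len, dtype=torch.float32)

# freqs has shape [seq_len, head_dim // 2]
freqs = torch.outer(positions, inv_freq)
cos_halfdim, sin_halfdim = freqs.cos(), freqs.sin()

# Full-size tensors: the second half duplicates the first half
full_cos = torch.cat([cos_halfdim, cos_halfdim], dim=-1)
full_sin = torch.cat([sin_halfdim, sin_halfdim], dim=-1)

print(full_cos.shape)  # torch.Size([4, 8])
```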
However, judging from the unit tests and the code here,
https://github.com/linkedin/Liger-Kernel/blob/2845fe8363a6f40a265ec8102523fe4c0ded068e/src/liger_kernel/ops/qwen2vl_mrope.py#L7
`cos` and `sin` are expected to be `full_cos` and `full_sin` instead. Is there a reason not to pass only the first halves of `cos` and `sin` to reduce memory movement?
I believe the Triton kernel only reads the first half of `cos` and `sin`:
https://github.com/linkedin/Liger-Kernel/blob/2845fe8363a6f40a265ec8102523fe4c0ded068e/src/liger_kernel/ops/qwen2vl_mrope.py#L40-L41
https://github.com/linkedin/Liger-Kernel/blob/2845fe8363a6f40a265ec8102523fe4c0ded068e/src/liger_kernel/ops/qwen2vl_mrope.py#L52-L61
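Why half is enough can be checked with a plain PyTorch reference (a sketch, not the Liger kernel itself). `rotate_half` follows the Hugging Face definition; `rope_half` is a hypothetical helper that uses only the first half of `cos`/`sin`. With `full_cos = [c, c]` and `full_sin = [s, s]`, the first output half is `x1*c - x2*s` and the second is `x2*c + x1*s`, so both halves only ever need `c` and `s`. For MRoPE the kernel additionally picks the temporal/height/width section per channel, but the same argument applies within each section:

```python
import torch

def rotate_half(x):
    # Hugging Face-style helper: [x1, x2] -> [-x2, x1]
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((-x2, x1), dim=-1)

def rope_full(x, cos, sin):
    # Reference RoPE with full-size cos/sin of shape [..., head_dim]
    return x * cos + rotate_half(x) * sin

def rope_half(x, cos_half, sin_half):
    # Hypothetical equivalent that only uses cos/sin of shape [..., head_dim // 2]
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((x1 * cos_half - x2 * sin_half,
                      x2 * cos_half + x1 * sin_half), dim=-1)

torch.manual_seed(0)
seq_len, head_dim = 4, 8
angles = torch.rand(seq_len, head_dim // 2)
cos_half, sin_half = angles.cos(), angles.sin()
full_cos = torch.cat([cos_half, cos_half], dim=-1)
full_sin = torch.cat([sin_half, sin_half], dim=-1)

q = torch.randn(2, seq_len, head_dim)  # [batch, seq_len, head_dim]
out_full = rope_full(q, full_cos, full_sin)
out_half = rope_half(q, cos_half, sin_half)
print(torch.allclose(out_full, out_half))
```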
@Tcc0403 I am getting a segfault if I pass in cos_halfdim and sin_halfdim, where sin_halfdim
https://github.com/linkedin/Liger-Kernel/blob/2845fe8363a6f40a265ec8102523fe4c0ded068e/src/liger_kernel/ops/qwen2vl_mrope.py#L45-L50
When these pointers are computed, the kernel uses the full `hd` as the per-row offset, so passing in `cos_halfdim` and `sin_halfdim` makes it read past the end of those tensors and segfault.
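A simplified pure-Python model of that offset arithmetic, assuming (as in the linked lines) that each program offsets by `pid * hd` with `pid` ranging over `bs * sl` rows, and that the temporal, height, and width sections are `bs * sl * hd` elements apart. `max_cos_offset` is a hypothetical helper that returns the largest element offset read from `cos`:

```python
def max_cos_offset(bs, sl, hd):
    # Largest offset read: width section (index 2), last row,
    # last channel of the first half.
    n_rows = bs * sl
    section_stride = n_rows * hd
    last_row = n_rows - 1
    return 2 * section_stride + last_row * hd + (hd // 2 - 1)

bs, sl, hd = 2, 16, 128
full_numel = 3 * bs * sl * hd          # cos of shape [3, bs, sl, hd]
half_numel = 3 * bs * sl * (hd // 2)   # cos of shape [3, bs, sl, hd // 2]
off = max_cos_offset(bs, sl, hd)

print(off, off < full_numel, off < half_numel)  # 12223 True False
```

With the half-size tensor, every row after the first is also misaligned (it starts `hd` elements in instead of `hd // 2`), so even reads that stay in bounds would fetch the wrong values.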
Oh yes, it will cause a segfault if you only pass the half-dim tensors. My point was that the second halves of `cos` and `sin` are never moved between global and on-chip memory.
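A back-of-the-envelope calculation of what that means, assuming bf16 (2 bytes per element), the three MRoPE sections, and a kernel that loads only `hd // 2` channels per row per section (ignoring cache-line granularity). `cos_sin_bytes` is a hypothetical helper; the sizes are example values. Passing half-size tensors would halve the allocation, while the bytes loaded by the kernel stay the same:

```python
def cos_sin_bytes(bs, sl, hd, elem_bytes=2, n_sections=3):
    # Returns (allocated_full, allocated_half, loaded) bytes for cos + sin together.
    rows = n_sections * bs * sl
    allocated_full = 2 * rows * hd * elem_bytes          # full-width cos and sin
    allocated_half = 2 * rows * (hd // 2) * elem_bytes   # half-width cos and sin
    loaded = 2 * rows * (hd // 2) * elem_bytes           # only first halves are read
    return allocated_full, allocated_half, loaded

print(cos_sin_bytes(bs=1, sl=8192, hd=128))  # (12582912, 6291456, 6291456)
```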
https://github.com/huggingface/transformers/blob/06f8004e5cd9d06cfbffc3f47afb6c2b43bcb3d2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L178
For example, note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim].
I think Liger's RoPE was designed to match the Hugging Face API (which assumes both tensors are full-size) while reading only the first half in the kernel code.
If you want to save some memory allocation by passing only half of `cos` and `sin`, I think we could add a check and introduce another parameter, such as `cos_hd`, to avoid segfaults when indexing those tensors.
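A rough PyTorch emulation of that idea, under the assumption that `cos_hd` (hypothetical, not an existing Liger argument) would replace `hd` as the row stride for `cos`/`sin` only, while the kernel keeps reading the first `hd // 2` channels. For MRoPE the section offset `bs * sl * hd` would need the same substitution; this sketch uses a single section. `read_first_half` is a hypothetical helper that mimics the pointer reads with flat indexing:

```python
import torch

def read_first_half(cos_flat, n_rows, cos_hd, hd):
    # Emulates kernel-style reads from a flat buffer: row r starts at
    # r * cos_hd, and only the first hd // 2 channels are read.
    # cos_hd is hd for full tensors and hd // 2 for half tensors.
    row_starts = torch.arange(n_rows).unsqueeze(1) * cos_hd
    channels = torch.arange(hd // 2).unsqueeze(0)
    return cos_flat[row_starts + channels]  # shape [n_rows, hd // 2]

n_rows, hd = 4, 8
cos_half = torch.rand(n_rows, hd // 2)
cos_full = torch.cat([cos_half, cos_half], dim=-1)

# Same values whether we pass the full tensor (stride hd) or the half one (stride hd // 2)
from_full = read_first_half(cos_full.flatten(), n_rows, cos_hd=hd, hd=hd)
from_half = read_first_half(cos_half.flatten(), n_rows, cos_hd=hd // 2, hd=hd)
print(torch.equal(from_full, from_half))

# Using the full stride on the half tensor goes out of bounds: PyTorch raises
# IndexError where the Triton kernel would read invalid memory instead
try:
    read_first_half(cos_half.flatten(), n_rows, cos_hd=hd, hd=hd)
except IndexError as err:
    print("out of bounds:", err)
```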