[QnA]: Why are `cos` and `sin` expected to be `hdim`, not `hdim//2`?

Open tjtanaa opened this issue 4 months ago • 3 comments

I have a question about the qwen2_vl MRope. My understanding is as follows:

full_cos = torch.cat([cos_halfdim, cos_halfdim], dim=-1)
full_sin = torch.cat([sin_halfdim, sin_halfdim], dim=-1)
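To make the redundancy concrete, here is a minimal sketch of that construction for a single position. It assumes standard RoPE frequencies with base 10000 and ignores the multimodal (temporal/height/width) sectioning; `half_dim_angles` and `full_cos_sin` are hypothetical helper names, and plain Python lists stand in for tensors so the sketch runs without PyTorch.

```python
import math

def half_dim_angles(position, head_dim, base=10000.0):
    """Rotary angles for one position: one angle per channel pair (head_dim // 2 values).

    Assumption: standard RoPE inverse frequencies with the given base.
    """
    half = head_dim // 2
    return [position / (base ** (2 * i / head_dim)) for i in range(half)]

def full_cos_sin(position, head_dim):
    """Duplicate the half-dim cos/sin along the last axis,
    mirroring torch.cat([x, x], dim=-1)."""
    angles = half_dim_angles(position, head_dim)
    cos_half = [math.cos(a) for a in angles]
    sin_half = [math.sin(a) for a in angles]
    return cos_half + cos_half, sin_half + sin_half

full_cos, full_sin = full_cos_sin(position=3, head_dim=8)
# The second half carries no new information: it is a copy of the first.
second_half_is_copy = full_cos[:4] == full_cos[4:] and full_sin[:4] == full_sin[4:]
```

Under these assumptions the full-dim tensors hold exactly the same values twice, which is what motivates the question.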

However, according to the unit tests and the code here,

https://github.com/linkedin/Liger-Kernel/blob/2845fe8363a6f40a265ec8102523fe4c0ded068e/src/liger_kernel/ops/qwen2vl_mrope.py#L7

the cos and sin passed in are full_cos and full_sin instead. Is there a reason not to pass only the half-dim cos and sin, to save memory movement?

tjtanaa avatar Aug 04 '25 09:08 tjtanaa

I believe the Triton code only reads the first half of cos and sin:

https://github.com/linkedin/Liger-Kernel/blob/2845fe8363a6f40a265ec8102523fe4c0ded068e/src/liger_kernel/ops/qwen2vl_mrope.py#L40-L41

https://github.com/linkedin/Liger-Kernel/blob/2845fe8363a6f40a265ec8102523fe4c0ded068e/src/liger_kernel/ops/qwen2vl_mrope.py#L52-L61

Tcc0403 avatar Aug 05 '25 18:08 Tcc0403

@Tcc0403 I get a segfault if I pass in cos_halfdim and sin_halfdim:

https://github.com/linkedin/Liger-Kernel/blob/2845fe8363a6f40a265ec8102523fe4c0ded068e/src/liger_kernel/ops/qwen2vl_mrope.py#L45-L50

These pointers are computed using the full hd, so passing in cos_halfdim and sin_halfdim causes a segfault.
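The out-of-bounds access can be illustrated with plain offset arithmetic. This is only a sketch, assuming each row of cos starts at flat offset `row * hd` and that only the first `hd // 2` columns of a row are loaded; `max_offset_read` and `is_in_bounds` are hypothetical helpers, not part of Liger-Kernel.

```python
def max_offset_read(num_rows, hd):
    """Largest flat offset touched when each row starts at row * hd
    and only the first hd // 2 columns of that row are loaded."""
    last_row_start = (num_rows - 1) * hd
    return last_row_start + (hd // 2 - 1)

def is_in_bounds(num_rows, hd, last_dim):
    """True if every read stays inside a contiguous buffer of shape [num_rows, last_dim]."""
    return max_offset_read(num_rows, hd) < num_rows * last_dim

num_rows, hd = 4, 8
full_ok = is_in_bounds(num_rows, hd, last_dim=hd)       # buffer holds 32 elements
half_ok = is_in_bounds(num_rows, hd, last_dim=hd // 2)  # buffer holds 16 elements
```

With a half-dim buffer the row stride assumed by the pointer arithmetic is twice the real one, so reads for later rows land past the end of the allocation even though only `hd // 2` columns per row are touched.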

tjtanaa avatar Aug 06 '25 12:08 tjtanaa

Oh yes, it will cause a segfault if you only pass the half-dim tensors. What I meant is that the second halves of cos and sin aren't involved in memory movement between global and on-chip memory.

https://github.com/huggingface/transformers/blob/06f8004e5cd9d06cfbffc3f47afb6c2b43bcb3d2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L178

For example, note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim].

I think Liger's RoPE was designed to match that API (assuming both are full shape), while only reading the first half in the kernel code.

If you want to save some memory allocation by passing only half of cos and sin, I think we can add a check and another parameter, such as cos_hd, to avoid the segfault when indexing those tensors.
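One way such a check might look on the host side is sketched below. It is an assumption about a possible design, not existing Liger-Kernel code: `resolve_cos_hd` and `row_offset` are hypothetical names, and shapes are passed as plain tuples so the sketch runs without PyTorch.

```python
def resolve_cos_hd(cos_shape, head_dim):
    """Return the row stride (the hypothetical `cos_hd`) to use when indexing cos/sin.

    Accepts either a full-dim or a half-dim last axis and rejects anything else.
    """
    last_dim = cos_shape[-1]
    if last_dim not in (head_dim, head_dim // 2):
        raise ValueError(
            f"cos/sin last dim must be {head_dim} or {head_dim // 2}, got {last_dim}"
        )
    return last_dim

def row_offset(row, cos_hd):
    """Flat offset of a row start, using the tensor's real last dim as the stride."""
    return row * cos_hd

full_stride = resolve_cos_hd((2, 16, 128), head_dim=128)  # full-dim input
half_stride = resolve_cos_hd((2, 16, 64), head_dim=128)   # half-dim input
```

The kernel would then take `cos_hd` as a separate argument for its cos/sin pointer arithmetic, while still using `hd` for q and k.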

Tcc0403 avatar Aug 06 '25 16:08 Tcc0403