
Does Qwen_2_5_VL support variable length attention computation?

Open yingtongxiong opened this issue 7 months ago • 8 comments

Feature request

Add support for variable-length attention computation in Qwen_2_5_VL.

Motivation

Hello, I am trying to run qwen25_vl with packed samples. However, it seems that this function only passes the attention_mask, not the position_ids: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L908. So I passed the position_ids to this function and hit an illegal memory access. I eventually found that the position_ids had been expanded 3 times along dim 0. So how can I use the position_ids if I want to use varlen flash attention? Would anyone be able to help me with this?
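A minimal sketch of why forwarding the expanded tensor as-is can break, assuming text-only packed inputs (where the three RoPE rows are identical). `cu_seqlens_like_fa2_helper` is a hypothetical re-implementation of the boundary logic of the older `prepare_fa2_from_position_ids` quoted later in this thread: a sequence starts wherever the position is 0, and the total entry count closes the last one. Flattening a `(3, batch, seq_len)` tensor gives three times as many entries as there are tokens, so the final offset points past the end of the packed q/k/v. This is one plausible route to an illegal memory access, not a confirmed diagnosis.

```python
import torch

def cu_seqlens_like_fa2_helper(position_ids: torch.Tensor) -> torch.Tensor:
    # Hypothetical re-implementation of the old helper's boundary logic:
    # a new sequence starts wherever the position is 0, and the total
    # number of entries is appended as the final offset.
    flat = position_ids.flatten()
    idx = torch.arange(flat.numel(), dtype=torch.int32)
    total = torch.tensor([flat.numel()], dtype=torch.int32)
    return torch.cat((idx[flat == 0], total))

# Two text-only sequences (lengths 5 and 8) packed into one row: 13 tokens.
packed_2d = torch.cat((torch.arange(5), torch.arange(8))).unsqueeze(0)  # (1, 13)
# Assumed 3D layout for text-only tokens: the same positions in all three rows.
packed_3d = packed_2d.unsqueeze(0).expand(3, -1, -1)  # (3, 1, 13)

print(cu_seqlens_like_fa2_helper(packed_2d))  # offsets stay within the 13 tokens
print(cu_seqlens_like_fa2_helper(packed_3d))  # final offset is 39, past the 13 tokens
```

With the 2D input the offsets describe two sequences; with the 3D input the helper "sees" six sequences and a total of 39 tokens.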

Your contribution

no

yingtongxiong, May 08 '25

@yingtongxiong Qwen VL position ids are different from those of simple LLMs, so simply passing position_ids to FA2 for packing will not solve the issue. We'll probably need to pass a different set of position_ids or infer them from the 3D ids. I will take a look at it.
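For intuition, the sketch below builds simplified 3D (temporal, height, width) positions for one sample laid out as text, one image, then text, in the spirit of Qwen2-VL's multimodal RoPE. It is an illustrative assumption, not the transformers implementation: `toy_mrope_positions` is a hypothetical helper, the image grid is assumed to be already spatially merged, and video and the exact offsets used by the real code are ignored. Text tokens share one index across all three rows, image tokens get a constant temporal index with height/width indices taken from the patch grid, and the text after the image resumes from the largest index used so far plus one.

```python
import torch

def toy_mrope_positions(n_text_before: int, grid_hw: tuple[int, int], n_text_after: int) -> torch.Tensor:
    """Hypothetical, simplified 3D positions for: text, one image frame, text.

    Returns a (3, seq_len) tensor with rows (temporal, height, width).
    """
    h, w = grid_hw
    # Text tokens: the same index in all three rows.
    before = torch.arange(n_text_before).expand(3, -1)
    start = n_text_before
    # Image tokens (single frame): temporal row constant, height/width follow the grid.
    t_idx = torch.zeros(h * w, dtype=torch.long)
    h_idx = torch.arange(h).repeat_interleave(w)
    w_idx = torch.arange(w).repeat(h)
    image = torch.stack([t_idx, h_idx, w_idx]) + start
    # Text after the image continues from the largest position used so far, plus one.
    nxt = int(image.max()) + 1
    after = torch.arange(nxt, nxt + n_text_after).expand(3, -1)
    return torch.cat([before, image, after], dim=1)

print(toy_mrope_positions(2, (2, 2), 2))
```

Even in this toy version the three rows disagree on the image tokens, which is why a single 2D `position_ids` tensor cannot carry the RoPE information on its own.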

zucchini-nlp, May 08 '25

@zucchini-nlp Thank you very much. I see that verl passes position_ids[0] to flash attention, but I am not sure that is correct.
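A quick check of what the first row of a 3D layout can look like, using hand-written toy values for two text tokens, one 2x2 image, and two more text tokens (assumed values, not produced by the real model). Row 0 still starts at 0, so boundary detection on `== 0` finds the start, but the temporal index stays constant across the image tokens, so `max() + 1` (which the older FA2 helper in this thread uses as the maximum length) gives 6 for an 8-token sequence. Whether verl's usage is affected depends on how it computes the lengths, which this sketch does not check.

```python
import torch

# Toy 3D positions (temporal, height, width): 2 text tokens, a 2x2 image, 2 text tokens.
rope_position_ids = torch.tensor([
    [0, 1, 2, 2, 2, 2, 4, 5],  # temporal row: constant over the image tokens
    [0, 1, 2, 2, 3, 3, 4, 5],  # height row
    [0, 1, 2, 3, 2, 3, 4, 5],  # width row
])

row0 = rope_position_ids[0]
seq_len = row0.numel()                        # real number of tokens
max_from_positions = int(row0.max()) + 1      # length inferred as max() + 1
starts = torch.nonzero(row0 == 0).flatten()   # sequence starts detected via == 0

print(seq_len, max_from_positions, starts.tolist())
```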

yingtongxiong, May 09 '25

Hello @zucchini-nlp and team,

I've been reviewing this feature request (#38007), and I find the challenge of implementing variable-length attention for the Qwen_2_5_VL model very interesting.

Before I start exploring a potential solution, could you please confirm that no one else is actively working on this? I would be happy to take it on. If you have any initial guidance or specific requirements for the implementation (e.g., a preferred attention backend such as sdpa or flash_attention_2), that would also be very helpful.

I look forward to contributing.

Pankajku-mar, Jun 10 '25

@Pankajku-mar No, afaik no one is working on it, so feel free to contribute 🤗

I can't suggest where to start exploring since I haven't looked at the issue yet. We will probably just need to pass another arg for the 2D position ids used by FA2, under a different name so it doesn't clash with the 3D positions used for RoPE. After that, you can check correctness by running this test on the Qwen-VL models:

https://github.com/huggingface/transformers/blob/aa798b7ac9ff5018b3578eb927dc438671ab6a3e/tests/test_modeling_common.py#L4131-L4136
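A minimal sketch of that idea for text-only samples. `pack_text_samples` and the key names `rope_position_ids` and `packed_position_ids` are hypothetical, chosen only to illustrate keeping the two tensors under different names; they are not the argument names the eventual fix uses. The 3D tensor keeps feeding RoPE, while the separately named 2D tensor restarts at 0 for each sample so a varlen FA2 path can recover the sequence boundaries.

```python
import torch

def pack_text_samples(samples: list[list[int]]) -> dict[str, torch.Tensor]:
    """Hypothetical packer for text-only samples (no images).

    Returns one packed row plus two position tensors under different names:
    - rope_position_ids: (3, 1, total); for text-only tokens the
      temporal/height/width rows are identical.
    - packed_position_ids: (1, total); restarts at 0 for each sample so
      sequence boundaries can be recovered without an attention mask.
    """
    input_ids = torch.tensor([tok for s in samples for tok in s]).unsqueeze(0)
    packed = torch.cat([torch.arange(len(s)) for s in samples]).unsqueeze(0)
    rope = packed.unsqueeze(0).expand(3, -1, -1)
    return {
        "input_ids": input_ids,
        "rope_position_ids": rope,
        "packed_position_ids": packed,
    }

batch = pack_text_samples([[11, 12, 13], [21, 22, 23, 24]])
print(batch["packed_position_ids"])      # restarts at 0 for the second sample
print(batch["rope_position_ids"].shape)  # three RoPE rows over the same 7 tokens
```

Samples containing images would need real 3D positions for the image tokens; only the 2D tensor stays this simple.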

zucchini-nlp, Jun 11 '25

OK, thanks for this, @zucchini-nlp.

gspeter-max, Jun 12 '25

```python
import torch


def prepare_fa2_from_position_ids(query, key, value, position_ids):
    """Proposed version: treats `position_ids` as a 0/1 mask (1 = valid token)."""
    # Flatten (batch, seq_len, heads, head_dim) -> (batch * seq_len, heads, head_dim).
    query = query.contiguous().view(-1, query.size(-2), query.size(-1))
    key = key.contiguous().view(-1, key.size(-2), key.size(-1))
    value = value.contiguous().view(-1, value.size(-2), value.size(-1))

    # Logic under test: per-row length = number of valid (1) entries.
    seqlens_in_batch = position_ids.sum(dim=-1, dtype=torch.int32)
    indices_q = torch.nonzero(position_ids.flatten(), as_tuple=False).flatten()
    # Drop padding tokens so the number of rows matches cu_seq_lens[-1].
    query, key, value = query[indices_q], key[indices_q], value[indices_q]
    cu_seq_lens = torch.nn.functional.pad(
        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32),
        (1, 0),
    )
    max_seqlens_in_batch = seqlens_in_batch.max().item()

    return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_seqlens_in_batch, max_seqlens_in_batch))


def prepare_fa2_from_position_ids1(query, key, value, position_ids):
    """Existing version: expects packed positions that restart at 0 for each sequence."""
    query = query.view(-1, query.size(-2), query.size(-1))
    key = key.contiguous().view(-1, key.size(-2), key.size(-1))
    value = value.contiguous().view(-1, value.size(-2), value.size(-1))
    position_ids = position_ids.flatten()
    indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)

    # A sequence starts wherever the position is 0; the total token count closes the last one.
    cu_seq_lens = torch.cat(
        (
            indices_q[position_ids == 0],
            torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
        )
    )

    max_length = position_ids.max() + 1
    return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))


# --- The test inputs ---

# A batch of 2 sentences:
# sentence A has length 3, sentence B has length 4, so the padded length is 4.
batch_size = 2
seq_len = 4
num_heads = 8
head_dim = 16

# Fake Q, K, V tensors: only their shape matters, not their content.
query = torch.randn(batch_size, seq_len, num_heads, head_dim)
key = torch.randn(batch_size, seq_len, num_heads, head_dim)
value = torch.randn(batch_size, seq_len, num_heads, head_dim)

position_ids_as_mask = torch.tensor([
    [1, 1, 1, 0],  # Sentence A: length 3, one padding token
    [1, 1, 1, 1],  # Sentence B: length 4, no padding
], dtype=torch.int32)

print("--- Testing your function with a FAKE position_ids that acts like a mask ---")
print("Input position_ids (behaving like a mask):\n", position_ids_as_mask)

# --- Run the proposed function ---
(
    out_q, out_k, out_v,
    out_indices_q,
    (out_cu_seqlens, _),
    (out_max_seqlens, _),
) = prepare_fa2_from_position_ids(query, key, value, position_ids_as_mask)

print("\n--- new FUNCTION'S OUTPUT ---")
print(f"Calculated seqlens_in_batch: {out_cu_seqlens.diff()}")
print(f"Calculated cu_seq_lens: {out_cu_seqlens}")
print(f"Calculated max_seqlens_in_batch: {out_max_seqlens}")

# --- Run the existing function ---
(
    out_q1, out_k1, out_v1,
    out_indices_q1,
    (out_cu_seqlens1, _),
    (out_max_seqlens1, _),
) = prepare_fa2_from_position_ids1(query, key, value, position_ids_as_mask)

print("\n--- old FUNCTION'S OUTPUT ---")
print(f"Calculated seqlens_in_batch: {out_cu_seqlens1.diff()}")
print(f"Calculated cu_seq_lens: {out_cu_seqlens1}")
print(f"Calculated max_seqlens_in_batch: {out_max_seqlens1}")
```

Outputs:

```
--- Testing your function with a FAKE position_ids that acts like a mask ---
Input position_ids (behaving like a mask):
 tensor([[1, 1, 1, 0],
        [1, 1, 1, 1]], dtype=torch.int32)

--- new FUNCTION'S OUTPUT ---
Calculated seqlens_in_batch: tensor([3, 4], dtype=torch.int32)
Calculated cu_seq_lens: tensor([0, 3, 7], dtype=torch.int32)
Calculated max_seqlens_in_batch: 4

--- old FUNCTION'S OUTPUT ---
Calculated seqlens_in_batch: tensor([5], dtype=torch.int32)
Calculated cu_seq_lens: tensor([3, 8], dtype=torch.int32)
Calculated max_seqlens_in_batch: 2
```

True (expected) output for this input:

```
Calculated seqlens_in_batch: tensor([3, 4], dtype=torch.int32)
Calculated cu_seq_lens: tensor([0, 3, 7], dtype=torch.int32)
Calculated max_seqlens_in_batch: 4
```

gspeter-max, Jun 14 '25

Hi @zucchini-nlp, I found a bug in this function: it returns incorrect output. I'm creating a PR to fix it, but I don't have access to the model weights to test it fully.

Could you help verify the fix or point me to the right contributors to tag?

Thanks!

gspeter-max, Jun 14 '25

https://github.com/huggingface/transformers/blob/d5d007a1a0f0c11a726a54c8f00bd71825f84d02/src/transformers/modeling_flash_attention_utils.py#L206

If the docstring is correct:

    position_ids (`torch.Tensor`):
        Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.

gspeter-max, Jun 14 '25

@gspeter-max It is not an incorrect shape: Qwen is expected to have 3D position ids so that RoPE can be applied to all 3 dimensions of an input image. So yes, maybe the docs for Qwen2VLForConditionalGeneration.forward and Qwen2Model.forward don't really describe the shape.

zucchini-nlp, Jun 16 '25

Yeah, exactly, but I am mostly talking about the behaviour of this position_ids argument. The docstring describes it like an attention mask, but in practice I think position_ids means something slightly different: I think of it in terms of position embeddings.

And the function computes these values incorrectly.

Can you guide me on this, @zucchini-nlp?

gspeter-max, Jun 16 '25

Oh yeah, I didn't pay attention to the FA2 docs. Indeed, the position ids are supposed to be simply the positions of each input token, in packed form. For example, if we pack two sequences we can end up with the positions below, which lets us get rid of attention masks: from them we can see where each sequence begins and ends, and which position each token is at.

ids = [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5, 6, 7]
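With positions laid out like this, the sequence boundaries fall wherever the position resets to 0. The sketch below uses a hypothetical helper, `varlen_args_from_packed_positions` (not the transformers function), to turn such a packed tensor into the cumulative offsets and maximum length that varlen flash-attention kernels take. It derives the maximum from the actual segment lengths rather than from `max() + 1`, and assumes position 0 appears only at the start of each packed sequence.

```python
import torch

def varlen_args_from_packed_positions(position_ids: torch.Tensor):
    """Hypothetical helper: packed positions -> (cu_seqlens, max_seqlen).

    Assumes each packed sequence starts at position 0 and that 0 appears
    nowhere else.
    """
    position_ids = position_ids.flatten()
    starts = torch.nonzero(position_ids == 0).flatten()
    total = torch.tensor([position_ids.numel()], device=position_ids.device)
    cu_seqlens = torch.cat((starts, total)).to(torch.int32)
    # Longest segment, computed from the offsets themselves.
    max_seqlen = int(cu_seqlens.diff().max())
    return cu_seqlens, max_seqlen

ids = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5, 6, 7])
cu_seqlens, max_seqlen = varlen_args_from_packed_positions(ids)
print(cu_seqlens.tolist(), max_seqlen)
```

For the example above this gives two sequences of lengths 5 and 8.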

So for Qwen, the idea is to make this work when one wants to do packing with FA2. Users should be able to pass packed positions and fall back to FA2 attention computation without any masks.

zucchini-nlp, Jun 16 '25

Closing as resolved by https://github.com/huggingface/transformers/pull/39447. Qwen2-VL now supports packing as long as the inputs are correctly handled.

We don't have a specific collator for Qwen-type models; feel free to take a look at the forward pass with dummy packed inputs here.

zucchini-nlp, Jul 23 '25