Tensor shape mismatch in `get_rope_index` when handling truncated sequences in Qwen2-VL
Issue Description
When using VERL with the Qwen2-VL model, we're hitting a tensor shape mismatch error during training. The error occurs because the get_rope_index function in [qwen2_vl.py (line 123)] generates position embeddings whose length exceeds the model's maximum context window (4096 tokens), even though the input sequences themselves are correctly truncated.
Error Message
RuntimeError: shape mismatch: value tensor of shape [3, 5040] cannot be broadcast to indexing result of shape [3, 4096]
The error is raised at [qwen2_vl.py (line 123)]:
position_ids[..., attention_mask == 1] = llm_positions.to(position_ids.device)
Steps to Reproduce
- Configure VERL for training with the Qwen2-VL model
- Set up a dataset with sequences longer than 4096 tokens
- Set data.truncation='left' in the configuration
- Start training with run_ppo(config) (a minimal config sketch follows this list)
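For concreteness, here is a minimal sketch of the configuration described above. The use of OmegaConf and the data.max_prompt_length key are my assumptions (VERL config layouts vary by version); only data.truncation='left' and run_ppo(config) come from this report.

# Hypothetical reproduction config; keys other than data.truncation are assumed.
from omegaconf import OmegaConf

config = OmegaConf.create({
    "data": {
        "truncation": "left",       # left-truncate over-long sequences (from this report)
        "max_prompt_length": 4096,  # assumed key for the 4096-token context window
    },
})

# run_ppo(config)  # training then fails inside a DataLoader worker (trace below)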
Cause
The root cause is that while the input sequences are correctly truncated to fit the model's context window (4096 tokens), the get_rope_index function in qwen2_vl.py still generates position embeddings for the untruncated sequence length. This creates a mismatch between:
- llm_positions with shape [3, 5040] (full sequence length)
- position_ids[..., attention_mask == 1] with shape [3, 4096] (truncated sequence length)
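The failure is easy to reproduce outside VERL with plain PyTorch. The snippet below is a simplified stand-in for what get_rope_index does (the real tensor construction is more involved); only the shapes are taken from the error above:

import torch

# Standalone illustration of the failing assignment, not VERL code.
seq_len_truncated = 4096  # sequence length after data truncation
seq_len_full = 5040       # length the position ids were computed for

# position_ids: (3, batch, seq_len); attention_mask: (batch, seq_len)
position_ids = torch.zeros(3, 1, seq_len_truncated, dtype=torch.long)
attention_mask = torch.ones(1, seq_len_truncated, dtype=torch.long)
llm_positions = torch.arange(seq_len_full).unsqueeze(0).expand(3, -1)

# The boolean mask selects (3, 4096) slots, but the value tensor is (3, 5040):
# RuntimeError: shape mismatch: value tensor of shape [3, 5040] cannot be
# broadcast to indexing result of shape [3, 4096]
position_ids[..., attention_mask == 1] = llm_positions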
Full Stack Trace
RuntimeError: Caught RuntimeError in DataLoader worker process 3.
Original Traceback (most recent call last):
File "/home/user/miniconda3/envs/verl_yiming2/lib/python3.10/site-packages/torchdata/stateful_dataloader/worker.py", line 242, in *worker*loop
data = fetcher.fetch(index) # type: ignore[union-attr]
File "/home/user/miniconda3/envs/verl_yiming2/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/user/miniconda3/envs/verl_yiming2/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 52, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/data/yiming/verl/verl/verl/utils/dataset/rl_dataset.py", line 202, in **getitem**
position_ids = get_rope_index(
File "/data/yiming/verl/verl/verl/models/transformers/qwen2_vl.py", line 123, in get_rope_index
position_ids[..., attention_mask == 1] = llm_positions.to(position_ids.device)
RuntimeError: shape mismatch: value tensor of shape [3, 5040] cannot be broadcast to indexing result of shape [3, 4096]
Proposed Solution
Add a check in the get_rope_index function so the generated position IDs never exceed the number of unmasked (truncated) input positions. The fix goes immediately before [qwen2_vl.py (line 123)]:
# Before line 123: clamp llm_positions to the number of unmasked positions
max_length = position_ids[..., attention_mask == 1].shape[1]
if llm_positions.shape[1] > max_length:
    # Keep the trailing positions, consistent with data.truncation='left'
    llm_positions = llm_positions[:, -max_length:]

# Original line 123:
position_ids[..., attention_mask == 1] = llm_positions.to(position_ids.device)
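A quick standalone check of the slicing behavior (my own sketch, not part of the proposed patch). Slicing from the end keeps the trailing positions, which lines up with data.truncation='left'; with right truncation you would keep the leading positions (llm_positions[:, :max_length]) instead.

import torch

# Verify the clamping logic in isolation (shapes from the error message).
llm_positions = torch.arange(5040).unsqueeze(0).expand(3, -1)
max_length = 4096

truncated = llm_positions[:, -max_length:]
assert truncated.shape == (3, 4096)
assert truncated[0, 0].item() == 5040 - 4096  # leading positions dropped, as with left truncation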
Additional Context
This issue specifically affects multimodal models like Qwen2-VL that have complex position embedding generation logic. The current implementation doesn't account for truncated sequences when working with multimodal content.
Same problem here.
This doesn't require a 4096-token context window to trigger. My error is: shape mismatch: value tensor of shape [3, 1919] cannot be broadcast to indexing result of shape [3, 512]
Any updates on this? What is the right fix?
Any updates?