[BUG] Dual meaning of `max_position_embeddings`: it determines both the embedding shape and the YaRN scaling base
**Describe the bug**
When using MLA on a sequence whose length differs from `config.max_position_embeddings`, a tensor shape mismatch error is thrown while applying the positional embeddings, stemming from this line.
Effectively, the seq_len dimension of the embedding matrix has to match the actual sequence length being passed into the model, which is generally different from the `old_context_len` parameter. `old_context_len` is used in the computation of the YaRN scaling, for example here. So changing this parameter in the config to match the context length at training time makes the positional embeddings wrong.
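A minimal toy sketch of the mismatch (plain PyTorch, not Megatron-LM code; the dimensions are made up for illustration). The rotary table is sized from the config value, while the query tensor carries the actual sequence length:

```python
# Toy illustration only: one config value plays two roles, and the precomputed
# rotary table no longer matches the tensors it is applied to.
import torch

max_position_embeddings = 4096   # sizes the rotary table AND acts as the YaRN base length
seq_len = 8192                   # actual sequence length fed to the model
dim = 64

inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
positions = torch.arange(max_position_embeddings).float()
freqs = torch.outer(positions, inv_freq)        # [4096, dim // 2]
emb = torch.cat([freqs, freqs], dim=-1)         # [4096, dim]

q = torch.randn(seq_len, 1, 1, dim)             # [s, b, h, d] layout
# 8192 vs 4096 on the sequence dimension -> RuntimeError at this line.
q_rot = q * emb.cos()[:, None, None, :]
```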
**To Reproduce**
- Try a forward pass using MLA on a sequence whose length is different from `config.max_position_embeddings`.
- A tensor shape mismatch error is raised inside the application of the positional embeddings to q and k.
- Changing `config.max_position_embeddings` makes the error go away, but a different embedding is computed and the output is incorrect.
**Expected behavior**
The size of the embedding matrix should adapt to the sequence length of the Q / K tensors (or to `config.max_sequence_length`) instead of `config.max_position_embeddings`.
**Stack trace/logs**
Can provide if needed.
**Environment (please complete the following information):**
- Megatron-LM commit ID: aa719a0b0145481fb9212c577ee9a3f000fd16da + internal patches
- PyTorch version: 2.5.1
- CUDA version: 12.2
- NCCL version: 2.21.5
**Proposed fix**
I was able to patch this by replacing the above line with `rotary_pos_emb = self.rotary_pos_emb(max_seq_len=q_len)`, though I do not believe this is a general solution for inference time.
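For what it's worth, a self-contained toy version of that idea (not the actual Megatron-LM class; names and shapes are assumptions): build the rotary table from the call-time `max_seq_len` instead of a fixed config value.

```python
# Toy sketch: size the rotary table from the sequence actually being processed.
import torch


class ToyRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, max_seq_len: int) -> torch.Tensor:
        # The table always matches the incoming sequence length.
        positions = torch.arange(max_seq_len, dtype=torch.float32)
        freqs = torch.outer(positions, self.inv_freq)
        return torch.cat([freqs, freqs], dim=-1)   # [max_seq_len, dim]


rope = ToyRotaryEmbedding(dim=64)
q_len = 8192
emb = rope(max_seq_len=q_len)                      # mirrors rotary_pos_emb(max_seq_len=q_len)
print(emb.shape)                                   # torch.Size([8192, 64])
```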
**Additional context**
N/A
Hi, thanks for your issue. We were aware of this bug and have already come up with a fix for the 0.11 release. It will further be integrated with the other pos_emb functions in the next release.
@BoxiangW It seems that it didn't ship with 0.11.0, any update on this?
Hi @yzlnew, it should be fixed with https://github.com/NVIDIA/Megatron-LM/blob/00efe37a85194a521789778ae47299ce8c054dc0/megatron/core/transformer/multi_latent_attention.py#L363
@BoxiangW I'm afraid that's not relevant to this specific issue.
For `YarnRotaryEmbedding`, it is still initialized with `original_max_position_embeddings=config.max_position_embeddings`, but for DeepSeek-V3 it should be set to 4096 regardless of the sequence length it is trained on.
https://github.com/NVIDIA/Megatron-LM/blob/00efe37a85194a521789778ae47299ce8c054dc0/megatron/core/transformer/multi_latent_attention.py#L96-L108
I think you are talking about a different issue. This `self.config.max_position_embeddings` is defined here: https://github.com/NVIDIA/Megatron-LM/blob/00efe37a85194a521789778ae47299ce8c054dc0/megatron/core/transformer/transformer_config.py#L1125
I think YaRN only needs this `original_max_position_embeddings` for its computation.
@BoxiangW Yes, maybe add a separate config for `original_max_position_embeddings`; otherwise it gets overwritten by the `max_position_embeddings` passed from the arguments.
Hi @BoxiangW, I also observed this overwrite issue, caused by the name collision between `max_position_embeddings` and `original_max_position_embeddings` (which is named `max_position_embeddings` in `MLATransformerConfig`).
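A hypothetical illustration of that overwrite (the dataclass and values below are made up; this is not the actual `MLATransformerConfig` or argument plumbing):

```python
# Hypothetical sketch of the name collision: the YaRN "original" length and the
# run-time context length share one field, so the CLI value clobbers both.
from dataclasses import dataclass


@dataclass
class ToyMLAConfig:
    max_position_embeddings: int = 4096   # intended as YaRN's original 4k length


cfg = ToyMLAConfig()
cfg.max_position_embeddings = 163840       # value propagated from the training arguments

# Anything that reads this field as the YaRN base length now sees the wrong number.
print(cfg.max_position_embeddings)         # 163840, but YaRN needs 4096 here
```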
Thanks for the feedback. In that case, I can rename https://github.com/NVIDIA/Megatron-LM/blob/00efe37a85194a521789778ae47299ce8c054dc0/megatron/core/transformer/transformer_config.py#L1125 to `original_max_position_embeddings` to avoid the conflict.
@yzlnew @lostkevin I am trying to understand the root cause of this issue a little bit more. Could you share a simple reproduction script for this issue? Thanks!
@BoxiangW It's a little bit complicated and also confusing. I'll try my best.
First, let's walk through the original config from DeepSeek.
https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/config.json
Basically, we need
"rope_scaling": {
"beta_fast": 32,
"beta_slow": 1,
"factor": 40,
"mscale": 1.0,
"mscale_all_dim": 1.0,
"original_max_position_embeddings": 4096,
"type": "yarn"
},
"max_position_embeddings": 163840,
for YaRN computation.
For DeepSeek-V3, the model is first heavily pretrained at 4096 sequence length and then gradually trained on longer sequences for long-context extension. The YaRN parameters are set to $s=40, \alpha=1, \beta=32$ for the two extension stages, which makes the ideal context length `max_position_embeddings` = 4,096 × 40 = 163,840.
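For reference, that arithmetic written out, together with the attention scaling the config's `mscale = 1.0` implies under the standard YaRN mscale formula (the second value is my own computation, not stated above):

$$
L_{\text{ext}} = s \cdot L_{\text{orig}} = 40 \times 4096 = 163{,}840,
\qquad
\text{mscale} = 0.1 \cdot m \cdot \ln(s) + 1 = 0.1 \ln(40) + 1 \approx 1.369
$$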
As for YaRN, `original_max_position_embeddings` is used to find the correction range, so it should not be set to the `max_position_embeddings` passed during training.
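A small runnable sketch of that correction-range computation (formulas from the YaRN paper; the function names and the dim/base values are my assumptions, roughly matching DeepSeek-V3's `qk_rope_head_dim=64` and `rope_theta=10000`). It shows how the range shifts when the extended training length is used in place of the original 4k length:

```python
# YaRN correction range: which rotary dimensions get interpolated vs. kept as-is.
import math


def find_correction_dim(num_rotations: float, dim: int, base: float, original_max_pos: int) -> float:
    # Dimension index whose wavelength completes `num_rotations` turns over original_max_pos tokens.
    return (dim * math.log(original_max_pos / (num_rotations * 2 * math.pi))) / (2 * math.log(base))


def find_correction_range(beta_fast: float, beta_slow: float, dim: int, base: float, original_max_pos: int):
    low = math.floor(find_correction_dim(beta_fast, dim, base, original_max_pos))
    high = math.ceil(find_correction_dim(beta_slow, dim, base, original_max_pos))
    return max(low, 0), min(high, dim - 1)


dim, base = 64, 10000.0
print(find_correction_range(32, 1, dim, base, 4096))    # uses the original 4k length -> (10, 23)
print(find_correction_range(32, 1, dim, base, 163840))  # extended length leaks in   -> (23, 36)
```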
And in the transformers implementation, the factor is also updated using this value, but I don't know whether that is necessary:
https://github.com/huggingface/transformers/blob/e94a4807df45ec2967bb7f88c3d008f40fc9d550/src/transformers/modeling_rope_utils.py#L260-L267
As for best practice when training DeepSeek, I think it is safe to set `rope_type=rope` when training with 4k length. For longer sequences, set $s=40, \alpha=1, \beta=32$ with `original_max_position_embeddings=4096`.
Any progress here? Jumped here from https://github.com/alibaba/Pai-Megatron-Patch/tree/main/megatron_patch/fixes/yarn_args
Hi, this issue should be fixed by the 25.07 container.
https://github.com/NVIDIA/Megatron-LM/blob/76144fe1106e4fb0e69aa75b7a6ab66e71e8f37f/megatron/core/transformer/transformer_config.py#L1288 is the fix. We will deprecate the old `max_position_embeddings` config in the next release, 25.09.
https://github.com/NVIDIA/Megatron-LM/commit/cb6ab12c49abfb767d82e7b07b57f16163e5d2e2 is merged into main and the MCore 0.14.0 release (NeMo 24.09 container), completely deprecating `max_position_embeddings` in MLA. Closing this issue; please feel free to re-open.