[BUG] Multi-Head Latent Attention Error
Describe the bug
MLA raises AssertionError: Keys and values must have the same shape! when trying to reproduce DeepSeek-V2/V3. The error is caused by a shape mismatch between the key and value tensors passed to TEDotProductAttention:
MLA shape before adjust, query: torch.Size([4096, 1, 128, 192]), key: torch.Size([4096, 1, 128, 192]), value: torch.Size([4096, 1, 128, 128])
To Reproduce
MLA_ARGS=(
    --multi-latent-attention
    --qk-pos-emb-head-dim 64
    --qk-head-dim 128
    --q-lora-rank 1536
    --kv-lora-rank 512
    --v-head-dim 128
    --qk-layernorm
)
...
torchrun "${DISTRIBUTED_ARGS[@]}" pretrain_gpt.py \
    "${MODEL_ARGS[@]}" \
    "${MLA_ARGS[@]}" \
    "${MOE_ARGS[@]}" \
    "${DATA_ARGS[@]}" \
    "${TRAINING_ARGS[@]}" \
    "${MODEL_PARALLEL_ARGS[@]}" \
    "${LOGGING_ARGS[@]}"
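For reference, a minimal arithmetic sketch of why these flags produce the mismatched shapes above. The key-dim composition (no-RoPE part concatenated with the RoPE part per head) is an assumption based on how MLA builds its keys; the variable names simply mirror the CLI flags:

```python
# Head-dim arithmetic for the repro flags above (names mirror the CLI flags).
qk_head_dim = 128          # --qk-head-dim (no-RoPE part of query/key)
qk_pos_emb_head_dim = 64   # --qk-pos-emb-head-dim (RoPE part of query/key)
v_head_dim = 128           # --v-head-dim

# Assumption: MLA concatenates the no-RoPE and RoPE parts per head, so
key_head_dim = qk_head_dim + qk_pos_emb_head_dim  # 128 + 64 = 192

# query/key last dim is 192 while value last dim is 128, which trips the
# "Keys and values must have the same shape!" check in TEDotProductAttention.
assert key_head_dim != v_head_dim
```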
Expected behavior
Pretraining with MLA runs without the assertion error.
Stack trace/logs
See the AssertionError message and tensor shapes above.
Environment (please complete the following information):
- Megatron-LM 8ca9e57f9d0bb93fc61850ebdccb6b6e6fa36b64
- PyTorch version 2.2
- CUDA version 12.1
- NCCL version
Proposed fix
The value tensor should be zero-padded along its last dimension to match the key head dim in the get_query_key_value_tensors function of the MLASelfAttention module; a sketch follows below.
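A minimal sketch of that padding workaround, assuming it is applied inside get_query_key_value_tensors (the helper name pad_value_to_key_dim is hypothetical, not Megatron-LM code):

```python
import torch
import torch.nn.functional as F

def pad_value_to_key_dim(key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    """Zero-pad value's last dim (v_head_dim) up to key's last dim.

    Hypothetical helper; shapes follow the log above: key is
    [sq, b, np, 192] and value is [sq, b, np, 128].
    """
    pad = key.size(-1) - value.size(-1)  # e.g. 192 - 128 = 64
    if pad > 0:
        # F.pad's (left, right) pair pads the final dimension.
        value = F.pad(value, (0, pad))
    return value
```

Since the padded tail of value is all zeros, the corresponding tail columns of the attention output are zero as well, so the context layer would presumably need to be sliced back to v_head_dim ([..., :128] here) before the output projection.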
Additional context
I hit the same issue when using only this option: --multi-latent-attention
Thanks for reporting the issue. Could you share your TE (Transformer Engine) version?
Sorry, but we could not reproduce the bug. Could you please share your backtrace with us? Alternatively, you could update your TE version to 2.x and try again.
Marking as stale. No activity in 60 days.
This issue was closed because it has been inactive for 7 days since being marked as stale.
@AlbertZhangHIT @larrylu0426 are you still facing this issue?