[BUG] `get_gpt_layer_local_spec` fails to initialize correct `DotProductAttention` in MLA mode
Describe the Bug
The function `get_gpt_layer_local_spec` claims to support `multi_latent_attention=True`, but it fails to initialize a correct `DotProductAttention` instance. Specifically, it does not account for `k_channels` and `v_channels`, which are required by `MLASelfAttention`.
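A quick way to see the mismatch is to inspect what the local spec wires in as core attention. This is an illustrative snippet, separate from the reproduction script below; the `getattr` unwrap is only there in case the entry is a `ModuleSpec` rather than a bare class:

```python
import inspect

from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec

spec = get_gpt_layer_local_spec(multi_latent_attention=True)
core_attention = spec.submodules.self_attention.submodules.core_attention

# Unwrap in case the entry is a ModuleSpec rather than a bare class.
core_attention_cls = getattr(core_attention, "module", core_attention)

# On core_r0.11.0 this resolves to DotProductAttention, whose constructor
# has no k_channels / v_channels parameters, while MLASelfAttention passes
# them when building its core attention module.
print(core_attention_cls)
print(inspect.signature(core_attention_cls.__init__))
```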
To Reproduce
Run the following minimal script:
```python
import torch

from megatron.core import parallel_state
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
from megatron.core.tensor_parallel import model_parallel_cuda_manual_seed
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.multi_latent_attention import MLASelfAttention
from megatron.core.transformer.transformer_config import MLATransformerConfig

random_seed = 42

# Single-rank process group so the script can be run directly with `python`.
torch.distributed.init_process_group(
    init_method="tcp://127.0.0.1:12340",
    rank=0,
    world_size=1,
)
parallel_state.initialize_model_parallel(1, 1)
model_parallel_cuda_manual_seed(random_seed)

if __name__ == "__main__":
    config = MLATransformerConfig(
        num_layers=1,
        hidden_size=512,
        num_attention_heads=8,
    )
    # Fails: the local spec's core attention does not accept the
    # k_channels / v_channels arguments that MLASelfAttention requires.
    attention = MLASelfAttention(
        config,
        submodules=get_gpt_layer_local_spec(
            multi_latent_attention=True
        ).submodules.self_attention.submodules,
        layer_number=0,
        attn_mask_type=AttnMaskType.causal,
    )
```
Expected Behavior
The call to `MLASelfAttention` should succeed when using the submodules from `get_gpt_layer_local_spec(multi_latent_attention=True)`.
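For comparison, the equivalent construction appears to go through with the Transformer Engine layer spec, whose core attention wrapper does accept the extra head-dimension arguments. A minimal sketch, reusing `config` and the imports from the script above and assuming Transformer Engine is installed:

```python
from megatron.core.models.gpt.gpt_layer_specs import (
    get_gpt_layer_with_transformer_engine_spec,
)

# Sketch only: the TE-backed core attention takes k_channels / v_channels,
# so this construction is expected to succeed where the local spec fails.
te_submodules = get_gpt_layer_with_transformer_engine_spec(
    multi_latent_attention=True
).submodules.self_attention.submodules

attention = MLASelfAttention(
    config,
    submodules=te_submodules,
    layer_number=0,
    attn_mask_type=AttnMaskType.causal,
)
```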
Environment
- Megatron-LM commit: core_r0.11.0
- PyTorch version: 2.6.0
- CUDA version: 12.5
- NCCL version: 2.21.5
💡 Proposed Fix
The `DotProductAttention` module returned by the local spec should accept and properly initialize `k_channels` and `v_channels`, which are required for MLA-based attention mechanisms.
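A minimal standalone sketch of the behavior the fix needs (this mirrors, but is not copied from, the real class; attribute and parameter names here are assumptions): accept optional `k_channels` / `v_channels`, fall back to `config.kv_channels` when they are absent, and use them for the head sizes and the default softmax scaling.

```python
import math


class DotProductAttentionSketch:
    """Illustration only, not the Megatron-LM class: shows how k_channels /
    v_channels could be accepted and defaulted so the local spec also works
    for MLA."""

    def __init__(self, config, layer_number, attn_mask_type,
                 attention_type="self", attention_dropout=None,
                 softmax_scale=None, k_channels=None, v_channels=None):
        self.config = config
        self.layer_number = layer_number
        self.attn_mask_type = attn_mask_type

        # MLA decouples the query/key head size from the value head size,
        # so neither can be derived from config.kv_channels alone.
        self.k_channels = k_channels if k_channels is not None else config.kv_channels
        self.v_channels = v_channels if v_channels is not None else config.kv_channels

        # Scale by the query/key head dimension unless a scale is given.
        self.softmax_scale = (
            softmax_scale if softmax_scale is not None
            else 1.0 / math.sqrt(self.k_channels)
        )
```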