                        `LlamaRotaryEmbedding` `inv_freq` buffer is left uninitialized by `init_empty_weights` + `load_checkpoint_and_dispatch`
System Info
- transformers version: 4.46.0.dev0
- Platform: Linux-5.15.0-122-generic-x86_64-with-glibc2.35
- Python version: 3.10.14
- Huggingface_hub version: 0.24.7
- Safetensors version: 0.4.5
- Accelerate version: 0.34.2
- PyTorch version (GPU?): 2.5.0.dev20240912+cu124 (True)
- Using distributed or parallel set-up in script?: 
- Using GPU in script?: yes
- GPU type: NVIDIA H100 80GB HBM3
Who can help?
@ArthurZucker
Information
- [X] The official example scripts
- [X] My own modified scripts
Tasks
- [X] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- [ ] My own task or dataset (give details below)
Reproduction
import accelerate
import torch
import torch.distributed.algorithms._checkpoint.checkpoint_wrapper
import torch.distributed.fsdp.wrap
import torch.optim
import transformers
import transformers.models.llama.modeling_llama

torch.cuda.set_device(device := torch.device("cuda:0"))
torch.set_default_dtype(dtype := torch.bfloat16)

# Build the model on the meta device so no weights are materialized yet.
with accelerate.init_empty_weights():
    config = transformers.LlamaConfig.from_pretrained(
        pretrained_model_path := "/models/meta-llama/llama_3.1/",
        attn_implementation="flash_attention_2",
        torch_dtype=dtype,
    )
    model = transformers.LlamaForCausalLM(config)
    assert isinstance(model, transformers.LlamaForCausalLM)

# Materialize parameters and buffers on the GPU with uninitialized memory.
model.to_empty(device="cuda")

# Load the checkpoint weights. Non-persistent buffers such as `inv_freq` are not
# in the checkpoint, so they are left as whatever `to_empty` allocated.
accelerate.load_checkpoint_and_dispatch(
    model,
    checkpoint=pretrained_model_path,
    dtype=dtype,
)

# Inspect every rotary embedding's `inv_freq` buffer.
for name, submodule in model.named_modules():
    if isinstance(submodule, transformers.models.llama.modeling_llama.LlamaRotaryEmbedding):
        print(f"{name=}: {submodule.inv_freq.sum()=}")
name='model.layers.0.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.1.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.2.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.3.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.4.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.5.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.6.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.7.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.8.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.9.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.10.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.11.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.12.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.13.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.14.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.15.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.16.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.17.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.18.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.19.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.20.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.21.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.22.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.23.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.24.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.25.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.26.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.27.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.28.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.29.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.30.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.31.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
This happens silently and only shows up as very strange model behavior (e.g. poor generation quality).
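As a stopgap, recomputing the buffer after loading seems to recover correct behavior. This is only a sketch: it assumes the module still carries the internal `config`, `rope_init_fn`, and `rope_kwargs` attributes shown at the linked lines, which are not a stable API, and it only fixes `inv_freq`.

# Workaround sketch: recompute `inv_freq` for every rotary embedding after
# `load_checkpoint_and_dispatch`, mirroring what `__init__` does today.
# Assumes `config`, `rope_init_fn`, and `rope_kwargs` exist as in the linked
# modeling_llama.py source; these are internal attributes.
for submodule in model.modules():
    if isinstance(submodule, transformers.models.llama.modeling_llama.LlamaRotaryEmbedding):
        inv_freq, submodule.attention_scaling = submodule.rope_init_fn(
            submodule.config, device, **submodule.rope_kwargs
        )
        submodule.inv_freq.copy_(inv_freq)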
Expected behavior
`LlamaRotaryEmbedding` already has its `inv_freq` initialization logic neatly defined in one place:
https://github.com/huggingface/transformers/blob/b54109c7466f6e680156fbd30fa929e2e222d730/src/transformers/models/llama/modeling_llama.py#L119-L122
One idea would be to define a `reset_parameters()` method on `LlamaRotaryEmbedding`, call it from `__init__`, and have `load_checkpoint_and_dispatch` (or the post-load path) call it on any submodule that defines it.
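A minimal sketch of that shape, with the body copied from the linked initialization code (the exact signature and attribute names are an assumption, not the actual implementation):

import torch.nn as nn

class LlamaRotaryEmbedding(nn.Module):
    # Sketch: factor the linked init code into reset_parameters() so that
    # __init__ can call it and loaders can call it again after materializing
    # the module from the meta device.
    def reset_parameters(self, device=None):
        # Recompute inv_freq exactly as __init__ currently does.
        inv_freq, self.attention_scaling = self.rope_init_fn(
            self.config, device, **self.rope_kwargs
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

The loading side would then call `reset_parameters()` only for modules whose non-persistent buffers were not restored from the checkpoint, rather than leaving them at whatever `to_empty` produced.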