
`LlamaRotaryEmbedding` `inv_freq` buffer is left uninitialized by `init_empty_weights` + `load_checkpoint_and_dispatch`

Open · ringohoffman opened this issue 1 year ago · 0 comments

System Info

  • transformers version: 4.46.0.dev0
  • Platform: Linux-5.15.0-122-generic-x86_64-with-glibc2.35
  • Python version: 3.10.14
  • Huggingface_hub version: 0.24.7
  • Safetensors version: 0.4.5
  • Accelerate version: 0.34.2
  • PyTorch version (GPU?): 2.5.0.dev20240912+cu124 (True)
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?: yes
  • GPU type: NVIDIA H100 80GB HBM3

Who can help?

@ArthurZucker

Information

  • [X] The official example scripts
  • [X] My own modified scripts

Tasks

  • [X] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • [ ] My own task or dataset (give details below)

Reproduction

import accelerate
import torch
import torch.distributed.algorithms._checkpoint.checkpoint_wrapper
import torch.distributed.fsdp.wrap
import torch.optim
import transformers
import transformers.models.llama.modeling_llama

torch.cuda.set_device(device := torch.device("cuda:0"))
torch.set_default_dtype(dtype := torch.bfloat16)

# Build the model skeleton with all parameters and buffers on the meta device.
with accelerate.init_empty_weights():
    config = transformers.LlamaConfig.from_pretrained(
        pretrained_model_path := "/models/meta-llama/llama_3.1/",
        attn_implementation="flash_attention_2",
        torch_dtype=dtype,
    )
    model = transformers.LlamaForCausalLM(config)
    assert isinstance(model, transformers.LlamaForCausalLM)

# Materialize parameters and buffers on the GPU as uninitialized storage.
model.to_empty(device="cuda")
# Load the checkpoint. inv_freq is a non-persistent buffer, so it is not present in
# the checkpoint and is never overwritten here.
accelerate.load_checkpoint_and_dispatch(
    model,
    checkpoint=pretrained_model_path,
    dtype=dtype,
)

# Every rotary embedding's inv_freq buffer is left with whatever to_empty() allocated
# (all zeros on this machine).
for name, submodule in model.named_modules():
    if isinstance(submodule, transformers.models.llama.modeling_llama.LlamaRotaryEmbedding):
        print(f"{name=}: {submodule.inv_freq.sum()=}")

name='model.layers.0.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.1.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.2.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.3.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.4.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.5.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.6.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.7.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.8.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.9.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.10.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.11.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.12.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.13.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.14.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.15.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.16.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.17.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.18.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.19.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.20.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.21.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.22.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.23.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.24.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.25.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.26.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.27.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.28.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.29.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.30.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.layers.31.self_attn.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)
name='model.rotary_emb': submodule.inv_freq.sum()=tensor(0., device='cuda:0', dtype=torch.float32)

This happens silently and only shows up downstream as very strange model behavior (e.g. poor generation quality).
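Until this is addressed, one possible workaround (continuing from the reproduction above) is to re-run the RoPE initialization on each rotary embedding after loading. This is only a sketch: it assumes `LlamaRotaryEmbedding` still carries the `rope_init_fn` and `config` attributes that its `__init__` sets on current main.

# Workaround sketch: recompute inv_freq in place, since the non-persistent buffer
# is never restored from the checkpoint by load_checkpoint_and_dispatch.
for name, submodule in model.named_modules():
    if isinstance(submodule, transformers.models.llama.modeling_llama.LlamaRotaryEmbedding):
        # attention_scaling is a plain Python float set in __init__, so it is
        # unaffected by the meta device and does not need to be recomputed.
        inv_freq, _ = submodule.rope_init_fn(submodule.config, device)
        submodule.inv_freq.copy_(inv_freq.to(submodule.inv_freq.device))
        print(f"{name=}: {submodule.inv_freq.sum()=}")  # now non-zero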

Expected behavior

`LlamaRotaryEmbedding` already has clearly defined initialization code for `inv_freq`:

https://github.com/huggingface/transformers/blob/b54109c7466f6e680156fbd30fa929e2e222d730/src/transformers/models/llama/modeling_llama.py#L119-L122

One idea would be to define a `reset_parameters()` method on `LlamaRotaryEmbedding`, call it from `__init__`, and have `load_checkpoint_and_dispatch` call it on any submodule that defines it; see the sketch below.
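As an illustration only (not a tested patch), the method could simply wrap the existing `rope_init_fn`-based initialization. The subclass name below is hypothetical, and the attribute names (`rope_init_fn`, `rope_kwargs`, `attention_scaling`, `original_inv_freq`) mirror the linked source:

import torch
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

class PatchedLlamaRotaryEmbedding(LlamaRotaryEmbedding):
    def reset_parameters(self, device: torch.device | None = None) -> None:
        # Re-run the same RoPE initialization that __init__ performs today.
        inv_freq, self.attention_scaling = self.rope_init_fn(
            self.config, device, **getattr(self, "rope_kwargs", {})
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

`load_checkpoint_and_dispatch` (or a small post-load hook) could then call `reset_parameters()` on every submodule that defines it, similar to the `reset_parameters()` convention FSDP relies on when materializing meta-device modules.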

ringohoffman · Oct 18 '24 06:10