
Meta init Llama, then pipeline, then materialize

kwen2501 opened this issue 1 year ago · 1 comment

Models can be big. Therefore we would need to:

  • create the model's "skeleton" on the meta device,
  • partition it so that it can fit on each device, and
  • materialize each partition (the first two steps are sketched below).
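
A rough sketch of the meta-init and partition steps, assuming the torch.distributed.pipelining APIs available in recent nightlies (the exact split annotation in meta_init.py may differ):

    import torch
    from transformers import AutoConfig, AutoModelForCausalLM
    from torch.distributed.pipelining import SplitPoint, pipeline

    model_id = "meta-llama/Llama-2-7b-chat-hf"
    config = AutoConfig.from_pretrained(model_id)

    # 1. Skeleton on the meta device: parameter shapes/dtypes only, no real storage.
    with torch.device("meta"):
        model = AutoModelForCausalLM.from_config(config)

    # 2. Partition: with world_size=2, cut before layer 16 so each stage holds 16 layers.
    world_size = 2
    layers_per_rank = config.num_hidden_layers // world_size
    split_spec = {f"model.layers.{layers_per_rank}": SplitPoint.BEGINNING}

    example_input = torch.randint(0, config.vocab_size, (2, 4), dtype=torch.int64)
    pipe = pipeline(model, mb_args=(example_input,), split_spec=split_spec)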

This is a demo based on the Llama-2-7b-chat-hf model and its checkpoint on the Hugging Face Model Hub.

Before running the script, please download the following files into the same directory as this script:

  • pytorch_model.bin.index.json
  • pytorch_model-00001-of-00002.bin
  • pytorch_model-00002-of-00002.bin

Download link: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/tree/main
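
If you prefer to fetch the files programmatically, something like the following works (it uses huggingface_hub, which is installed alongside transformers; note the repo is gated, so you must have accepted the license and be logged in with a token):

    from huggingface_hub import hf_hub_download

    repo_id = "meta-llama/Llama-2-7b-chat-hf"
    for filename in (
        "pytorch_model.bin.index.json",
        "pytorch_model-00001-of-00002.bin",
        "pytorch_model-00002-of-00002.bin",
    ):
        # Downloads into the current directory (requires access to the gated repo).
        hf_hub_download(repo_id=repo_id, filename=filename, local_dir=".")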

Your directory should contain the three checkpoint files above alongside meta_init.py (screenshot of the directory listing omitted).

How to run this script: $ python meta_init.py

I haven't used a distributed runtime because I only have a MacBook at hand, but I tried to show how to load each stage module from the HF checkpoint shards. Feel free to modify the script to run in a distributed way by distributing the for loop at [Note 3]; a sketch of the per-stage loading follows.
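
Roughly, the per-stage loading in [Note 3] looks like the sketch below. Details may differ from meta_init.py; it assumes `pipe` and `world_size` from the earlier sketch, that `Pipe.get_stage_module` is available, and that each stage module keeps the original fully-qualified parameter names:

    import json
    import torch

    with open("pytorch_model.bin.index.json") as f:
        weight_map = json.load(f)["weight_map"]  # param FQN -> shard filename

    for stage_idx in range(world_size):
        print(f"Loading weights into stage {stage_idx}")
        stage_mod = pipe.get_stage_module(stage_idx)

        # Group the parameters this stage needs by the shard that holds them,
        # so each shard file is read from disk at most once.
        by_shard = {}
        for fqn in stage_mod.state_dict():
            if fqn in weight_map:  # skip buffers not stored in the checkpoint
                by_shard.setdefault(weight_map[fqn], []).append(fqn)

        state_dict = {}
        for shard_file, fqns in by_shard.items():
            shard = torch.load(shard_file, map_location="cpu", weights_only=True)
            state_dict.update({fqn: shard[fqn] for fqn in fqns})

        # assign=True replaces the stage's meta tensors with the loaded real tensors;
        # strict=False tolerates buffers that were skipped above.
        stage_mod.load_state_dict(state_dict, strict=False, assign=True)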

My torch version: torch 2.5.0.dev20240722. I installed it with: pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu

Cc: @lessw2020 @muellerzr @SunMarc @H-Huang @wconstab @LucasLLC

kwen2501 · Jul 23 '24 15:07

Run logs:

(base) kw2501@kw2501-mbp llama % python meta_init.py

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)
world_size=2
layers_per_rank = 16
Loading weights into stage 0
Fully updated state dict
class GraphModule(torch.nn.Module):
    def forward(self, input_ids: "i64[2, 4]"):
        # No stacktrace found for following nodes
        model = self.model(input_ids);  input_ids = None
        getitem: "i64[1, 4]" = model[0]
        getitem_1: "f32[4, 5]" = model[1]
        getitem_2: "f32[2, 4, 4096]" = model[2];  model = None
        return (getitem_2, getitem, getitem_1)
        
Loading weights into stage 1
Fully updated state dict
class GraphModule(torch.nn.Module):
    def forward(self, add_95: "f32[2, 4, 4096]", unsqueeze: "i64[1, 4]", mul: "f32[4, 5]"):
        # No stacktrace found for following nodes
        model: "f32[2, 4, 4096]" = self.model(mul, unsqueeze, add_95);  mul = unsqueeze = add_95 = None
        lm_head: "f32[2, 4, 32000]" = self.lm_head(model);  model = None
        
         # File: /opt/anaconda3/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py:1194 in forward, code: logits = logits.float()
        _to_copy_default_162: "f32[2, 4, 32000]" = torch.ops.aten._to_copy.default(lm_head, dtype = torch.float32);  lm_head = None
        return _to_copy_default_162

kwen2501 · Jul 23 '24 15:07