Meta-init Llama, then pipeline, then materialize
Models can be too big to fully materialize on a single device. Therefore we need to:
- create the model's "skeleton" on the meta device,
- partition it into pipeline stages so that each stage fits on its device, and
- materialize each stage's parameters from the checkpoint (see the sketch below).
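Here is a minimal sketch of those three steps, assuming 2 stages and a split before decoder layer 16; the split spec, input shapes, and the loading step are illustrative assumptions, not the exact code in meta_init.py:

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from torch.distributed.pipelining import pipeline, SplitPoint

model_id = "meta-llama/Llama-2-7b-chat-hf"

# 1) Create the model "skeleton" on the meta device: no real storage is allocated.
config = AutoConfig.from_pretrained(model_id)
with torch.device("meta"):
    model = AutoModelForCausalLM.from_config(config)

# 2) Partition the model into pipeline stages; splitting before layer 16 gives
#    2 stages with 16 of the 32 decoder layers each.
example_input = torch.randint(0, config.vocab_size, (2, 4), device="meta")
pipe = pipeline(
    model,
    mb_args=(example_input,),
    split_spec={"model.layers.16": SplitPoint.BEGINNING},
)

# 3) Materialize each stage: allocate real storage, then fill it from the checkpoint.
for stage_idx in range(2):
    stage_mod = pipe.get_stage_module(stage_idx)  # parameters still on meta device
    stage_mod.to_empty(device="cpu")              # allocate real, uninitialized storage
    # ...copy the relevant tensors from the checkpoint shards into stage_mod here...
```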
This is a demo based on the Llama-2-7b-chat-hf model and its checkpoint on the Hugging Face Model Hub. Before running the script, please download the following files into the same directory as this script:
- pytorch_model.bin.index.json
- pytorch_model-00001-of-00002.bin
- pytorch_model-00002-of-00002.bin
Download link: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/tree/main
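Alternatively, a small snippet like the one below can fetch the same files programmatically (it assumes huggingface_hub is installed and that your account has been granted access to the gated Llama-2 repo):

```python
# Download the weight index and the two shards next to this script.
from huggingface_hub import hf_hub_download

repo_id = "meta-llama/Llama-2-7b-chat-hf"
for filename in (
    "pytorch_model.bin.index.json",
    "pytorch_model-00001-of-00002.bin",
    "pytorch_model-00002-of-00002.bin",
):
    hf_hub_download(repo_id=repo_id, filename=filename, local_dir=".")
```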
Your directory should look like this:
- meta_init.py
- pytorch_model.bin.index.json
- pytorch_model-00001-of-00002.bin
- pytorch_model-00002-of-00002.bin
How to run this script:
$ python meta_init.py
I haven't used a distributed runtime because I only have a MacBook at hand, but the script shows how to load each stage module from the HF checkpoint. Feel free to modify it to run in a distributed way by distributing the for loop at [Note 3]; one possible shape of such a change is sketched below.
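For reference, here is a rough sketch of distributing that loop, with one stage per rank launched via torchrun. This is an assumption about how the script could be adapted (the checkpoint-loading step is elided), not code that exists in it:

```python
# Launch with e.g.: torchrun --nproc-per-node 2 meta_init_dist.py
import torch
import torch.distributed as dist
from transformers import AutoConfig, AutoModelForCausalLM
from torch.distributed.pipelining import pipeline, SplitPoint, ScheduleGPipe

dist.init_process_group(backend="gloo")          # "nccl" on GPU machines
rank = dist.get_rank()
device = torch.device("cpu")                     # or torch.device(f"cuda:{rank}")

# Same meta-device skeleton and split as in the sketch above.
config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
with torch.device("meta"):
    model = AutoModelForCausalLM.from_config(config)
example_input = torch.randint(0, config.vocab_size, (2, 4), device="meta")
pipe = pipeline(
    model,
    mb_args=(example_input,),
    split_spec={"model.layers.16": SplitPoint.BEGINNING},
)

# Instead of looping over all stages, each rank materializes only its own stage.
stage_mod = pipe.get_stage_module(rank)
stage_mod.to_empty(device=device)
# ...load this rank's shard(s) of the checkpoint into stage_mod here...

stage = pipe.build_stage(rank, device)           # runtime wrapper around the stage
schedule = ScheduleGPipe(stage, n_microbatches=1)

input_ids = torch.randint(0, config.vocab_size, (2, 4), device=device)
if rank == 0:
    schedule.step(input_ids)                     # first stage feeds the input
else:
    logits = schedule.step()                     # last stage produces the logits
```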
My torch version:
torch 2.5.0.dev20240722
I installed it with:
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
Cc: @lessw2020 @muellerzr @SunMarc @H-Huang @wconstab @LucasLLC
Run logs:
(base) kw2501@kw2501-mbp llama % python meta_init.py
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)
world_size=2
layers_per_rank = 16
Loading weights into stage 0
Fully updated state dict
class GraphModule(torch.nn.Module):
    def forward(self, input_ids: "i64[2, 4]"):
        # No stacktrace found for following nodes
        model = self.model(input_ids); input_ids = None
        getitem: "i64[1, 4]" = model[0]
        getitem_1: "f32[4, 5]" = model[1]
        getitem_2: "f32[2, 4, 4096]" = model[2]; model = None
        return (getitem_2, getitem, getitem_1)
Loading weights into stage 1
Fully updated state dict
class GraphModule(torch.nn.Module):
    def forward(self, add_95: "f32[2, 4, 4096]", unsqueeze: "i64[1, 4]", mul: "f32[4, 5]"):
        # No stacktrace found for following nodes
        model: "f32[2, 4, 4096]" = self.model(mul, unsqueeze, add_95); mul = unsqueeze = add_95 = None
        lm_head: "f32[2, 4, 32000]" = self.lm_head(model); model = None
        # File: /opt/anaconda3/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py:1194 in forward, code: logits = logits.float()
        _to_copy_default_162: "f32[2, 4, 32000]" = torch.ops.aten._to_copy.default(lm_head, dtype = torch.float32); lm_head = None
        return _to_copy_default_162