Only half of the parameters are saved when PP is applied

Open dmammfl opened this issue 1 year ago • 7 comments

I'm currently training the Llama-3-8B model on 2 GPUs with pipeline parallelism only. However, when I save a checkpoint on each rank, only half of that checkpoint is saved (layer 1 is saved, layer 2 is not, layer 3 is saved, layer 4 is not, ..., layer 15 is saved).

I think dcp.save only works well with DTensor, not plain Tensor. I need your insight on this. Thanks a lot!

dmammfl avatar Jul 22 '24 23:07 dmammfl

Maybe we should add state_dict hooks to PP to emit DTensor on PP's submesh so that DCP works with PP alone? @fegin @H-Huang @wconstab

wanchaol avatar Jul 23 '24 19:07 wanchaol

Hmm, we shouldn't really need DTensor to solve the problem of layer 0 being saved and layer 1 not being saved. The FQNs should be preserved and not conflict, so we should be able to save both. From the pattern, I assume this is using virtual pipeline stages, layers 0, 2, 4, ... are on gpu0, and only gpu0 is correctly saving things?

In the 3D case with PP, we expect that gpu0 would save DTensor including any TP/DP replication/sharding. However, we do not rely on DTensor for dealing with layer0 vs layer1.

wconstab avatar Aug 05 '24 19:08 wconstab

dcp.save() works with both DTensor and plain Tensor. Rank 0 determines what to save on each rank. If tensors are not duplicated (the FQNs are different), all of the tensors will be saved.
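
As a minimal standalone sketch of that behavior (not torchtitan code; the FQNs, the /tmp path, and the gloo backend are just for illustration), each rank below saves plain Tensors under rank-unique FQNs and both end up in one checkpoint. Launch it with torchrun --nproc_per_node=2.

# Sketch: dcp.save with plain Tensors and rank-unique FQNs.
# Launch with: torchrun --nproc_per_node=2 dcp_pp_demo.py
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp

dist.init_process_group("gloo")
rank = dist.get_rank()

# Each rank contributes different keys, as PP would with unique per-stage FQNs.
state_dict = {f"model.layers.{rank}.weight": torch.full((4, 4), float(rank))}

dcp.save(state_dict, storage_writer=dcp.FileSystemWriter("/tmp/dcp_pp_demo"))

# Rank 0 coordinates the save plan; the metadata should list both keys.
if rank == 0:
    metadata = dcp.FileSystemReader("/tmp/dcp_pp_demo").read_metadata()
    print(sorted(metadata.state_dict_metadata.keys()))
    # ['model.layers.0.weight', 'model.layers.1.weight']

dist.destroy_process_group()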

fegin avatar Aug 05 '24 19:08 fegin

I tested this case, and figured out several points:

  1. When only PP is applied with degree 2 and the model is 15 GB, dcp.save should produce 2 checkpoint files of 7.5 GB each, but each checkpoint's size is only about 3.7 GB.

  2. With only PP applied, model.state_dict() on each rank contains exactly its shard of the model params (rank 0 has the layer 0~15 params, rank 1 has the layer 16~31 params).

  3. In the case of rank 1, although rank 1 holds the params for layers 16~31, its key names are "model.layer.0.self_attn....", "model.layer.1.self_attn....", ..., "model.layer.15.self_attn....", exactly the same as rank 0's layer key names (except embed_token, lm_head, etc.).

  4. When I changed the layer key names to "PP0_model.layer.0.self_attn...." (and did the same for rank 1: "PP1_model.layer.0.self_attn...."), all of the state_dicts were saved properly, at around 7.5 GB each.

I think there is a key conflict in the _save_state_dict() method, so _save_state_dict() inside dcp.save() behaves incorrectly.
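
One way to confirm the conflict hypothesis is to gather every rank's state_dict keys and look for FQNs owned by more than one rank. A rough sketch (assuming an initialized process group; model_part and report_fqn_conflicts are illustrative names for the module owned by the local PP rank and the helper):

# Sketch: detect FQNs that appear on more than one rank before calling dcp.save.
import torch.distributed as dist

def report_fqn_conflicts(model_part):
    local_keys = set(model_part.state_dict().keys())
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, local_keys)

    if dist.get_rank() == 0:
        owners = {}
        for rank, keys in enumerate(gathered):
            for key in keys:
                owners.setdefault(key, []).append(rank)
        conflicts = {k: r for k, r in owners.items() if len(r) > 1}
        if conflicts:
            print(f"{len(conflicts)} FQNs appear on multiple ranks, e.g.:")
            for key, ranks in list(conflicts.items())[:5]:
                print(f"  {key} -> ranks {ranks}")
        else:
            print("All FQNs are unique across ranks.")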

dmammfl avatar Aug 06 '24 08:08 dmammfl

Point 1 looks suspicious, and just as you mentioned, there are key conflicts. We have tested the non-virtual pipeline and there were no key conflicts. Any insight about this, @wconstab, @H-Huang?

fegin avatar Aug 06 '24 16:08 fegin

Could you share the exact repro command so we can debug?

wconstab avatar Aug 06 '24 16:08 wconstab

I run "run_llama_train.sh", setting "pipeline_parallel_degree" as 2 and other parallel degrees are 1.

dmammfl avatar Aug 08 '24 02:08 dmammfl

Sorry for not getting back to you sooner.

In the case of rank 1, although rank 1 holds the params for layers 16~31, its key names are "model.layer.0.self_attn....", "model.layer.1.self_attn....", ..., "model.layer.15.self_attn....", exactly the same as rank 0's layer key names (except embed_token, lm_head, etc.).

I think this is the clue. In torchtitan, we changed the model definition to use a ModuleDict instead of a ModuleList for model.layers. The reason for doing this is so that we can delete layers 0..15 from 'model' on rank 1 and the keys (FQNs) of the remaining layers stay 16..31. If you are observing FQNs 0..15 on rank 1, that is probably the root of your problem. Our DCP design for saving with PP does assume unique FQNs per PP rank; same-named FQNs on different PP ranks would lead to a DCP save issue.
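
A toy illustration of why the ModuleDict matters (plain Linear layers standing in for TransformerBlocks; not torchtitan code):

import torch.nn as nn

# ModuleList: keeping only the second half renumbers the survivors back to 0..3,
# so rank 1's FQNs would collide with rank 0's.
layers_list = nn.ModuleList([nn.Linear(4, 4, bias=False) for _ in range(8)])
rank1_list = nn.ModuleList(list(layers_list)[4:])
print(list(rank1_list.state_dict().keys()))
# ['0.weight', '1.weight', '2.weight', '3.weight']

# ModuleDict: deleting entries keeps the original keys, so each PP rank's FQNs stay unique.
layers_dict = nn.ModuleDict({str(i): nn.Linear(4, 4, bias=False) for i in range(8)})
for i in range(4):
    del layers_dict[str(i)]
print(list(layers_dict.state_dict().keys()))
# ['4.weight', '5.weight', '6.weight', '7.weight']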

I noticed you filed the issue in July, and the changes to make model.layers a dict with unique FQNs landed earlier than that. So I'm not sure how this could have happened, but please let me know if you are still seeing the issue (or if it's with your own fork, maybe I can help update your model to work).

In torchtitan today, if I run this command: LOG_RANK=0,1 NGPU=2 ./run_llama_train.sh --experimental.pipeline_parallel_degree 2

I can see that rank 0 has unique FQNs compared to rank 1:

[rank0]:2024-11-01 15:40:09,518 - root - INFO - PP rank 0 is building stage_idx 0 with start_layer None, stop_layer layers.4: model chunk 
[rank0]:Transformer(                                                                                                                         
[rank0]:  (tok_embeddings): Embedding(2256, 256)                                                                                             
[rank0]:  (layers): ModuleDict(                                                                                                              
[rank0]:    (0): TransformerBlock(                                                                                                           
[rank0]:      (attention): Attention(                                                                                                        
[rank0]:        (wq): Linear(in_features=256, out_features=256, bias=False)                                             
...                                                                                   
[rank0]:    )                                                                                                                                
[rank0]:    (1): TransformerBlock(                                
[rank0]:      (attention): Attention(                              
...
[rank1]:2024-11-01 15:40:10,443 - root - INFO - PP rank 1 is building stage_idx 1 with start_layer layers.4, stop_layer None: model chunk 
[rank1]:Transformer(                                                                                                                         
[rank1]:  (tok_embeddings): None                                                                                                             
[rank1]:  (layers): ModuleDict(                                                                                                              
[rank1]:    (4): TransformerBlock(                                                                                                           
[rank1]:      (attention): Attention(                                                                                                        
[rank1]:        (wq): Linear(in_features=256, out_features=256, bias=False)                                              
...                                  
[rank1]:    )                                                                                                                                
[rank1]:    (5): TransformerBlock(                                                                                                           
[rank1]:      (attention): Attention(                                                                                                        
...

If I enable checkpointing with ./run_llama_train.sh --experimental.pipeline_parallel_degree 2 --checkpoint.enable_checkpoint --checkpoint.interval 10

I can see my checkpoint files as expected:

(pytorch-3.10) [[email protected] /data/users/whc/torchtitan (main)]$ ls -al outputs/checkpoint/step-10/
total 94960
drwxr-xr-x 1 whc users       66 Nov  1 15:45 .
drwxr-xr-x 1 whc users       14 Nov  1 15:45 ..
-rw-r--r-- 1 whc users 48376940 Nov  1 15:45 __0_0.distcp
-rw-r--r-- 1 whc users 48537788 Nov  1 15:45 __1_0.distcp
-rw-r--r-- 1 whc users   316189 Nov  1 15:45 .metadata

If I examine the metadata there, I see all layers are present:

import pickle

with open("outputs/checkpoint/step-10/.metadata", "rb") as f:
    x = pickle.load(f)
    sdm = x.state_dict_metadata
    # filter out a bunch of noise, just pick one particular weight and see that we get 8 copies of it as expected
    layer_keys = [k for k in sdm.keys() if "wk.weight" in k and "optim" not in k]

print(layer_keys)

Prints:

['model.layers.0.attention.wk.weight', 'model.layers.1.attention.wk.weight', 'model.layers.2.attention.wk.weight', 'model.layers.3.attention.wk.weight', 'model.layers.4.attention.wk.weight', 'model.layers.5.attention.wk.weight', 'model.layers.6.attention.wk.weight', 'model.layers.7.attention.wk.weight']

wconstab avatar Nov 01 '24 22:11 wconstab

Going to close the issue as 'cannot reproduce', but feel free to reopen if you have additional input or questions! @dmammfl

wconstab avatar Nov 01 '24 22:11 wconstab