
LLaMA PRO training resume problem

Open germanjke opened this issue 1 year ago • 6 comments

Hello,

I'm currently training LLaMA PRO. I expanded the model from 32 to 40 layers and am training only the 8 newly added layers (every fifth layer), so I froze 32 of the 40 layers:

layer_freezing: 
    layer_names: [ 
    'model._fsdp_wrapped_module.model.layers.36._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.16._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.18._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.1._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.27._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.32._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.35._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.0._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.10._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.3._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.37._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.28._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.22._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.12._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.2._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.5._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.8._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.20._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.17._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.25._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.30._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.38._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.7._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.33._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.6._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.31._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.13._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.15._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.11._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.21._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.26._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.23._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param'
    ]
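
For reference, that list can be built programmatically; a minimal sketch, assuming 40 layers where every fifth layer (indices 4, 9, ..., 39) is the newly inserted trainable one, and using the FSDP flat-param name template from the list above (frozen_flat_param_names is just an illustrative helper, not an llm-foundry API):

def frozen_flat_param_names(n_layers: int = 40, block: int = 5) -> list[str]:
    layer_template = (
        "model._fsdp_wrapped_module.model.layers.{i}."
        "_fsdp_wrapped_module._checkpoint_wrapped_module._flat_param"
    )
    # The top-level flat param (embeddings, final norm, lm_head) is frozen as well.
    names = ["model._fsdp_wrapped_module._flat_param"]
    names += [
        layer_template.format(i=i)
        for i in range(n_layers)
        if (i + 1) % block != 0  # keep only the original (frozen) layers
    ]
    return names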

Training is going well, and only the layers I need are being trained.

But after a hardware failure, I attempted to resume training using load_path and encountered this error:

[rank6]:   File "/usr/lib/python3/dist-packages/torch/distributed/checkpoint/state_dict_loader.py", line 198, in local_step
[rank6]:     local_plan = planner.create_local_plan()
[rank6]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/usr/lib/python3/dist-packages/torch/distributed/checkpoint/default_planner.py", line 185, in create_local_plan
[rank6]:     return create_default_local_load_plan(self.state_dict, self.metadata)
[rank6]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/usr/lib/python3/dist-packages/torch/distributed/checkpoint/default_planner.py", line 235, in create_default_local_load_plan
[rank6]:     md = metadata.state_dict_metadata[fqn]
[rank6]:          ~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^
[rank6]: KeyError: 'state.optimizers.DecoupledAdamW.state.model.model.embed_tokens.weight.exp_avg'
[rank6]: Traceback (most recent call last): (RANK 14)
[rank6]:   File "/usr/lib/python3/dist-packages/torch/distributed/checkpoint/utils.py", line 163, in reduce_scatter
[rank6]:     local_data = map_fun()
[rank6]:                  ^^^^^^^^^
[rank6]:   File "/usr/lib/python3/dist-packages/torch/distributed/checkpoint/state_dict_loader.py", line 198, in local_step
[rank6]:     local_plan = planner.create_local_plan()
[rank6]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/usr/lib/python3/dist-packages/torch/distributed/checkpoint/default_planner.py", line 185, in create_local_plan
[rank6]:     return create_default_local_load_plan(self.state_dict, self.metadata)
[rank6]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/usr/lib/python3/dist-packages/torch/distributed/checkpoint/default_planner.py", line 235, in create_default_local_load_plan
[rank6]:     md = metadata.state_dict_metadata[fqn]
[rank6]:          ~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^
[rank6]: KeyError: 'state.optimizers.DecoupledAdamW.state.model.model.embed_tokens.weight.exp_avg'
[rank6]: Traceback (most recent call last): (RANK 15)
[rank6]:   File "/usr/lib/python3/dist-packages/torch/distributed/checkpoint/utils.py", line 163, in reduce_scatter
[rank6]:     local_data = map_fun()
[rank6]:                  ^^^^^^^^^
[rank6]:   File "/usr/lib/python3/dist-packages/torch/distributed/checkpoint/state_dict_loader.py", line 198, in local_step
[rank6]:     local_plan = planner.create_local_plan()
[rank6]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/usr/lib/python3/dist-packages/torch/distributed/checkpoint/default_planner.py", line 185, in create_local_plan
[rank6]:     return create_default_local_load_plan(self.state_dict, self.metadata)
[rank6]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/usr/lib/python3/dist-packages/torch/distributed/checkpoint/default_planner.py", line 235, in create_default_local_load_plan
[rank6]:     md = metadata.state_dict_metadata[fqn]
[rank6]:          ~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^
[rank6]: KeyError: 'state.optimizers.DecoupledAdamW.state.model.model.embed_tokens.weight.exp_avg'

My ep0-ba4500/.metadata looks like this:

(The file is a binary pickle of torch.distributed.checkpoint.metadata.Metadata; the readable state_dict_metadata keys in it include:)

state.model.model.model.embed_tokens.weight
state.model.model.model.layers.2.mlp.up_proj.weight
state.model.model.model.layers.2.input_layernorm.weight
state.model.model.model.layers.2.post_attention_layernorm.weight
state.model.model.model.layers.3.self_attn.q_proj.weight
state.model.model.model.layers.3.self_attn.o_proj.weight
state.model.model.model.layers.3.mlp.up_proj.weight
state.model.model.model.layers.3.input_layernorm.weight
etc...
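
For what it's worth, those keys can also be listed programmatically instead of reading the pickle by hand; a minimal sketch, assuming the checkpoint directory ep0-ba4500 shown above:

from torch.distributed.checkpoint import FileSystemReader

# Read the .metadata file of the sharded (DCP) checkpoint directory.
reader = FileSystemReader("ep0-ba4500")
metadata = reader.read_metadata()

# Print only the optimizer-state FQNs, to see which parameters
# actually have DecoupledAdamW state saved in the checkpoint.
for fqn in sorted(metadata.state_dict_metadata):
    if fqn.startswith("state.optimizers"):
        print(fqn)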

Have you experienced similar issues? My guess is that the checkpoint simply has no DecoupledAdamW state for the frozen parameters (e.g. model.embed_tokens.weight), yet the load planner still asks for it on resume.

germanjke, May 23 '24 12:05