Is there more detailed documentation for HF AutoTP training?
For example, the relationship between global_rank and tp_rank / dp_rank, and how to manually all-gather parameters to a specific rank (can we use deepspeed.zero.GatheredParameters?).
Also, for example, the list of supported models, and how to use this feature with DeepSpeed's native API.
We are planning to try HF AutoTP in OpenRLHF.
Thanks
@inkcherry, @delock, FYI
Hi, @hijkzzz, glad to see you're interested in this. This setup is TP-first. For example, with 4 ranks (0, 1, 2, 3) and tp_size=2: [0,1] and [2,3] are the TP groups, and [0,2] and [1,3] are the DP groups.
You can retrieve them using get_tensor_model_parallel_group(), get_data_parallel_group(), and other APIs from deepspeed.utils.groups.
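A minimal sketch of how the rank mapping can be queried, assuming the two group getters above return regular torch.distributed process groups (the helper function name below is just an example):

```python
import torch.distributed as dist
from deepspeed.utils import groups

def report_ranks():
    """Example helper: map global_rank to (tp_rank, dp_rank)."""
    # The two getters below are the deepspeed.utils.groups APIs mentioned above;
    # the rest is plain torch.distributed.
    tp_group = groups.get_tensor_model_parallel_group()
    dp_group = groups.get_data_parallel_group()

    global_rank = dist.get_rank()
    tp_rank = dist.get_rank(group=tp_group)      # position inside the TP group
    dp_rank = dist.get_rank(group=dp_group)      # position inside the DP group
    tp_size = dist.get_world_size(group=tp_group)

    # With the TP-first layout above: global_rank == dp_rank * tp_size + tp_rank
    print(f"global={global_rank} tp_rank={tp_rank} dp_rank={dp_rank} tp_size={tp_size}")
```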
To get the full parameters, you can use GatherReplacedLayerParams:
https://github.com/deepspeedai/DeepSpeed/blob/9926879b5986d6903f61f2683be68ab8c48c9766/deepspeed/runtime/engine.py#L3718
Note that it will perform partitioning again when exiting the context.
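A rough sketch of how this context manager might be wrapped to collect a full state_dict, assuming GatherReplacedLayerParams is importable from deepspeed.module_inject.layers and takes (params, module, enabled), as the linked engine code suggests; treat both the import path and the signature as assumptions:

```python
# Assumption: GatherReplacedLayerParams lives in deepspeed.module_inject.layers
# and takes (params, module, enabled=True), as the linked engine.py code suggests.
from deepspeed.module_inject.layers import GatherReplacedLayerParams

def gather_full_state_dict(engine):
    """Hypothetical helper: collect an un-sharded state_dict on rank 0."""
    state = {}
    for name, module in engine.module.named_modules():
        params = list(module.parameters(recurse=False))
        if not params:
            continue
        # Inside the context, TP-sharded ("replaced") parameters are gathered;
        # they are partitioned again when the context exits.
        with GatherReplacedLayerParams(params, module, enabled=True):
            if engine.global_rank == 0:
                for pname, p in module.named_parameters(recurse=False):
                    key = f"{name}.{pname}" if name else pname
                    state[key] = p.detach().cpu().clone()
    return state
```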
Hope this helps you. Feel free to discuss any other questions!
@inkcherry
Why is the size of the final saved model the same when I use deepspeed.zero.GatheredParameters to gather the state_dict, compared to not enabling Auto TP?
see https://github.com/OpenRLHF/OpenRLHF/blob/1ad22284caf8b50d2a043abdebdfc3318d17bdae/openrlhf/utils/deepspeed/deepspeed.py#L324
Hi @hijkzzz, are you referring to the saved weights being identical between the following two cases?
- Training with AutoTP, saved using deepspeed.zero.GatheredParameters
- Training without AutoTP

If so, that is unexpected, because the usage in case 1 is not the intended behavior.
For TP model partitioning:
If you are not using the transformers Trainer, please make sure that the TP initialization function is called. Reference: https://github.com/huggingface/transformers/pull/36825/files
You can manually partition the model by calling it after from_pretrained and before training, as in the sketch below.
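A hedged sketch of manual partitioning outside the Trainer, assuming deepspeed.tp_model_init(model, tp_size, dtype) is the initialization entry point used by the referenced transformers PR (verify the exact call against that PR); the model name and config path are placeholders:

```python
import torch
import deepspeed
from transformers import AutoModelForCausalLM

# 1) Load the model normally.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16
)

# 2) Partition it for TP after from_pretrained and before training.
#    Assumption: tp_model_init is the entry point used by the transformers PR above.
model = deepspeed.tp_model_init(model, tp_size=2, dtype=torch.bfloat16)

# 3) Hand the partitioned model to deepspeed.initialize as usual. The DeepSpeed
#    config ("ds_config.json" here) is assumed to carry matching AutoTP settings.
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="ds_config.json",
)
```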
For saving the model:
deepspeed.zero.GatheredParameters is expected to be used mainly for ZeRO-3; it should not work for TP, or for TP combined with ZeRO-1/2 partitioning. GatherReplacedLayerParams is the one for TP.
If you are using the Hugging Face Trainer's save_model, accelerate should automatically call the following functions in deepspeed/runtime/engine.py to save the full model.
When using ZeRO-3:
def _zero3_consolidated_16bit_state_dict(self, exclude_frozen_parameters=False)
When using TP, or TP + ZeRO:
def _replace_module_consolidated_state_dict(self):
If you are using a custom save_model function, you can also refer to the gather implementations in these two functions; see the sketch below.
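A hedged sketch of a custom save path that dispatches to the two internal engine methods above (they are private APIs and may change between releases):

```python
import torch

def save_full_model(engine, path, zero_stage, tp_enabled):
    """Sketch of a custom save_model mirroring the Trainer/accelerate path.
    Relies on the private engine methods named above, which may change."""
    if zero_stage == 3:
        # Gathers ZeRO-3 shards; returns the consolidated dict on rank 0.
        state_dict = engine._zero3_consolidated_16bit_state_dict()
    elif tp_enabled:
        # TP, or TP combined with ZeRO-1/2.
        state_dict = engine._replace_module_consolidated_state_dict()
    else:
        state_dict = engine.module.state_dict()

    if engine.global_rank == 0:
        torch.save(state_dict, path)
```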
Thanks
I don't see get_tensor_model_parallel_group() used in either transformers or accelerate in the context of DeepSpeed (only Megatron in accelerate), so I'm not sure how this should work in the HF ecosystem. Currently I'm getting: AssertionError: Data inconsistency within the TP group. Please check the Dataloader implementation to ensure consistency.
hi, have you solved this problem?
I finally found the cause of the bug: if you have data augmentation operations, make sure the random seed is the same within each TP group (i.e., seed by dp_rank rather than by global rank); see the sketch below.
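A minimal sketch of one way to enforce this, assuming the deepspeed.utils.groups getters mentioned earlier in the thread; base_seed is just an example parameter:

```python
import random
import numpy as np
import torch
import torch.distributed as dist
from deepspeed.utils import groups

def seed_for_tp_consistency(base_seed: int = 42):
    """Seed data augmentation identically for all ranks in the same TP group.

    Ranks inside one TP group share the same dp_rank, so seeding by dp_rank
    gives them identical augmentation streams, while different DP replicas
    still see different randomness.
    """
    dp_rank = dist.get_rank(group=groups.get_data_parallel_group())
    seed = base_seed + dp_rank
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
```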