
[BUG] ZeRO++ | AssertionError: ZeRO parameter intra parallel group is already initialized

Open. dhkim0225 opened this issue (3 comments).

Describe the bug

Hello. I'm an active user of DeepSpeed for multi-node training.

I've always used ZeRO-3, but this time I tried enabling the hpZ (hierarchical partitioning) feature of ZeRO++ for the first time. The problem is that my model class is structured like this:

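import torch.nn as nn
from transformers import GPT2LMHeadModel
# SubModel: a custom pretrained model class (definition not shown).

class MYLM(GPT2LMHeadModel):
    def __init__(self, config, sub_model_name_or_path):
        super().__init__(config)
        # Under ZeRO-3, from_pretrained() enters its own deepspeed.zero.Init() context here.
        self.sub_model = SubModel.from_pretrained(sub_model_name_or_path)
        input_hidden_size = self.sub_model.text_decoder.config.d_model
        self.mm_projector = nn.Linear(input_hidden_size, config.hidden_size)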
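To show why this nesting fails, here is a minimal pure-Python simulation of the pattern. It does not use DeepSpeed. `fake_zero_init`, `_INTRA_GROUP`, `create_intra_group` and `fake_from_pretrained` are hypothetical stand-ins for `deepspeed.zero.Init`, `_ZERO_PARAM_INTRA_PARALLEL_GROUP`, `_create_zero_param_parallel_group` and `from_pretrained()`. The simulation also assumes that MYLM itself is built inside an outer `zero.Init` context, for example via `MYLM.from_pretrained`.

```python
from contextlib import contextmanager

# Hypothetical stand-in for the module-level global in deepspeed/utils/groups.py.
_INTRA_GROUP = None


def create_intra_group(hpz_size):
    """Mimics the guard shown in the traceback: create the group once, assert otherwise."""
    global _INTRA_GROUP
    assert _INTRA_GROUP is None, "ZeRO parameter intra parallel group is already initialized"
    _INTRA_GROUP = f"intra-node group of size {hpz_size}"


@contextmanager
def fake_zero_init(hpz_size):
    """Hypothetical stand-in for deepspeed.zero.Init: creates the group only when hpZ size > 1."""
    if hpz_size > 1:
        create_intra_group(hpz_size)
    yield


def fake_from_pretrained(hpz_size):
    """Like from_pretrained under ZeRO-3: wraps model construction in its own zero.Init."""
    with fake_zero_init(hpz_size):
        return "sub_model"


def build_mylm(hpz_size):
    """Outer context (MYLM being built) plus the nested from_pretrained call in __init__."""
    with fake_zero_init(hpz_size):
        return fake_from_pretrained(hpz_size)  # the nested Init trips the assertion
```

Calling `build_mylm(1)` returns normally because no group is created, which mirrors plain ZeRO-3. Calling `build_mylm(4)` raises the same AssertionError as in the traceback.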

This may not look problematic at first glance. However, sub_model is a pretrained model loaded under DeepSpeed ZeRO-3, so from_pretrained() in the transformers library calls deepspeed.zero.Init() while constructing it:

https://github.com/huggingface/transformers/blob/5d36025ca13d05151b7a0c761e90d429c4644a30/src/transformers/modeling_utils.py#L3479-L3494

However, the hpZ partition group has already been created by that point, so the deepspeed.zero.Init() call inside from_pretrained() raises an AssertionError:

https://github.com/microsoft/DeepSpeed/blob/834272531aa4368f793cc78418612e1e09166094/deepspeed/utils/groups.py#L518

What can I do to resolve this issue? Thank you.
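One possible workaround, offered as an assumption rather than an official fix, is to make group creation idempotent. If the intra-node group already exists, creation is skipped and the existing group is reused. `make_idempotent` and `patch_deepspeed_hpz_group_creation` are hypothetical helpers. The patch relies on the private names `_create_zero_param_parallel_group` and `_ZERO_PARAM_INTRA_PARALLEL_GROUP` from the traceback, and these may change between DeepSpeed versions. It also assumes that every nested `zero.Init` uses the same `zero_hpz_partition_size`.

```python
import functools
import types


def make_idempotent(module, create_name, global_name):
    """Wrap module.<create_name> so it returns early if module.<global_name> is already set."""
    original = getattr(module, create_name)

    @functools.wraps(original)
    def wrapper(*args, **kwargs):
        if getattr(module, global_name) is not None:
            return None  # group already exists: skip re-creation instead of asserting
        return original(*args, **kwargs)

    setattr(module, create_name, wrapper)
    return wrapper


def patch_deepspeed_hpz_group_creation():
    """Hypothetical workaround; relies on private DeepSpeed internals (assumption)."""
    from deepspeed.utils import groups  # imported lazily so this file loads without DeepSpeed

    make_idempotent(groups, "_create_zero_param_parallel_group", "_ZERO_PARAM_INTRA_PARALLEL_GROUP")
```

The patch would be called on every rank before any model is constructed. In the traceback, `partition_parameters.py` calls `groups._create_zero_param_parallel_group(...)` through the module attribute and ignores its return value, which is why replacing that attribute could take effect. Any other call sites are untested.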

Expected behavior

A ZeRO-3 config with hpZ enabled should work.
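For reference, a minimal ZeRO-3 + hpZ configuration of the kind described might look like the dict below. The `zero_hpz_partition_size` key is the one read in the traceback (`zero_config.zero_hpz_partition_size`). Setting it to 4 is an assumption chosen to match the 4 GPUs per node in this setup. The batch size is a placeholder.

```python
# Sketch of a DeepSpeed config enabling ZeRO-3 with hpZ (ZeRO++ hierarchical partitioning).
GPUS_PER_NODE = 4  # assumption: matches the two nodes with 4 V100s each

ds_config = {
    "train_micro_batch_size_per_gpu": 1,  # placeholder value
    "fp16": {"enabled": True},  # V100 (Volta) GPUs lack native bf16 support
    "zero_optimization": {
        "stage": 3,
        # hpZ keeps a secondary partition of the parameters within each node;
        # any value > 1 makes zero.Init create the intra-node process group.
        "zero_hpz_partition_size": GPUS_PER_NODE,
    },
}
```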

ds_report output

[2024-01-05 10:48:07,604] [INFO] [real_accelerator.py:161:get_accelerator] Setting ds_accelerator to cuda (auto detect)
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_lion ............... [NO] ....... [OKAY]
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
fused_lamb ............. [NO] ....... [OKAY]
fused_lion ............. [NO] ....... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.0
 [WARNING]  using untested triton version (2.0.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/home/nsml/.local/lib/python3.8/site-packages/torch']
torch version .................... 2.0.1+cu117
deepspeed install path ........... ['/home/nsml/.local/lib/python3.8/site-packages/deepspeed']
deepspeed info ................... 0.12.6, unknown, unknown
torch cuda version ............... 11.7
torch hip version ................ None
nvcc version ..................... 12.1
deepspeed wheel compiled w. ...... torch 2.0, cuda 11.7
shared memory (/dev/shm) size .... 177.05 GB

Screenshots

/home/nsml/.local/lib/python3.8/site-packages/transformers/modeling_utils.py:3077 in from_pretrained

   3074             import deepspeed
   3075
   3076             logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this mode
❱  3077             init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config())
   3078         elif load_in_8bit or load_in_4bit or low_cpu_mem_usage:
   3079             init_contexts.append(init_empty_weights())
   3080

/home/nsml/.local/lib/python3.8/site-packages/deepspeed/runtime/zero/partition_parameters.py:884 in __init__

    881
    882         self.zero_param_process_group = zero_param_parallel_group
    883         if _ds_config is not None and _ds_config.zero_config.zero_hpz_partition_size > 1
❱   884             groups._create_zero_param_parallel_group(_ds_config.zero_config.zero_hpz_par
    885             self.zero_param_process_group = groups._get_zero_param_intra_parallel_group(
    886
    887         self.num_ranks_in_param_group = self.dp_world_size

/home/nsml/.local/lib/python3.8/site-packages/deepspeed/utils/groups.py:518 in _create_zero_param_parallel_group

    515     assert dist.is_initialized()
    516     global _ZERO_PARAM_INTRA_PARALLEL_GROUP
    517     # Only create group if it does not already exist
❱   518     assert _ZERO_PARAM_INTRA_PARALLEL_GROUP is None, \
    519         'ZeRO parameter intra parallel group is already initialized'
    520
    521     world_size = dist.get_world_size()
AssertionError: ZeRO parameter intra parallel group is already initialized

System info

  • OS: Ubuntu 20.04
  • GPU count and types: two machines, each with 4× V100 GPUs
  • Python version: 3.8.10
  • Any other relevant info about your setup: transformers 4.34.1

Launcher context

Accelerate (from Hugging Face)

Docker context

None.


dhkim0225, Jan 05 '24 10:01