
[AutoTP + DS2] no memory reduction when using autotp_size = 4

Open · jiaqiw09 opened this issue 1 month ago · 4 comments

I trained the Qwen2.5-7B model on GPUs with LLaMA-Factory, using DeepSpeed ZeRO-2 + AutoTP, and there is no obvious memory reduction.

When using ZeRO-2 alone, the average memory across the 8 cards is about 42 GB, while the average with ZeRO-2 + AutoTP (autotp_size = 4) is also nearly 42 GB.

Here is my DeepSpeed config:


{
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "gradient_accumulation_steps": "auto",
  "gradient_clipping": "auto",
  "zero_allow_untested_optimizer": true,
  "fp16": {
    "enabled": "auto",
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "initial_scale_power": 16,
    "hysteresis": 2,
    "min_loss_scale": 1
  },
  "bf16": {
    "enabled": "auto"
  },
  "zero_optimization": {
    "stage": 2,
    "allgather_partitions": true,
    "allgather_bucket_size": 5e8,
    "overlap_comm": false,
    "reduce_scatter": true,
    "reduce_bucket_size": 5e8,
    "contiguous_gradients": true,
    "round_robin_gradients": true
  },
  "tensor_parallel":{
      "autotp_size": 4
  }
}
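
For reference, a minimal per-rank check (just a sketch, assuming the DeepSpeed-wrapped model is available as model inside the training script) that could confirm whether the weights are actually sharded and what each rank allocates:

import torch
import torch.distributed as dist

def report_rank_memory(model):
    # Under autotp_size=4, each rank should own roughly 1/4 of the 7B parameters.
    n_params = sum(p.numel() for p in model.parameters())
    alloc_gb = torch.cuda.memory_allocated() / 1e9
    peak_gb = torch.cuda.max_memory_allocated() / 1e9
    rank = dist.get_rank() if dist.is_initialized() else 0
    print(f"rank {rank}: params={n_params / 1e9:.2f}B, "
          f"allocated={alloc_gb:.1f} GB, peak={peak_gb:.1f} GB")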

jiaqiw09 · Nov 11 '25 03:11

According to https://www.deepspeed.ai/tutorials/automatic-tensor-parallelism/#supported-models, it looks like Qwen2.5 is not in the supported model list.

@delock, FYI

xylian86 · Nov 14 '25 16:11

Just verified that AutoTP supports Qwen2.5-7B; the list should be updated. Will submit a PR.

There must be some other reason why there is no memory reduction; this needs further investigation.
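
As a rough back-of-envelope sketch (assuming ZeRO-2 partitions gradients and optimizer states over the data-parallel group, AutoTP shards the weights over the TP group, and activations are ignored), the per-rank weight memory alone should drop noticeably when TP=4 is active:

def per_rank_gb(params_billion, tp, dp):
    weights = params_billion * 2 / tp        # bf16 weights, sharded by TP
    grads = params_billion * 2 / tp / dp     # bf16 gradients, ZeRO-2 partitioned
    optim = params_billion * 12 / tp / dp    # fp32 master weights + Adam moments
    return weights + grads + optim

print(per_rank_gb(7, tp=1, dp=8))  # ZeRO-2 only on 8 GPUs   -> ~26.25 GB
print(per_rank_gb(7, tp=4, dp=2))  # ZeRO-2 + autotp_size=4  -> ~15.75 GB

If the observed per-rank memory stays at ~42 GB in both cases, the TP sharding may not be taking effect, or activations/fragmentation dominate.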

@inkcherry for comments.

delock · Nov 17 '25 07:11

Hi @jiaqiw09, what is the actual train_batch_size used, and is fp16 or bf16 enabled? I'm trying to reproduce this observation.

delock · Nov 17 '25 08:11

I use bfloat16 on 8 cards, and the per-card batch size (batch_size_per_card) is 1.
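
For context, a sketch of how DeepSpeed typically resolves the "auto" batch settings in this setup (grad_accum = 1 is an assumption here; the actual value comes from the launcher):

# train_batch_size = micro_batch_per_gpu * gradient_accumulation_steps * dp_world_size
micro_batch = 1                      # batch_size_per_card reported above
grad_accum = 1                       # assumed
world_size = 8
tp_size = 4
dp_size = world_size // tp_size      # 2 when autotp_size=4, 8 without TP
print(micro_batch * grad_accum * dp_size)   # -> 2 (vs 8 with pure ZeRO-2)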

jiaqiw09 · Nov 17 '25 13:11