
[BUG] AutoTP fails for Qwen2.5 models when tp size > 2

Open · HollowMan6 opened this issue 8 months ago · 1 comment

Describe the bug

A tensor size mismatch occurs in k_proj when running Qwen/Qwen2.5-3B-Instruct on 8 GPUs with TP=8, DP=1. Error log:

  File "openrlhf/trainer/ray/launcher.py", line 98, in execute_batch
    result = func(**sample_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^
  File "openrlhf/trainer/ray/ppo_actor.py", line 486, in forward
    action_log_probs = self.actor(
                       ^^^^^^^^^^^
  File "torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "openrlhf/models/actor.py", line 195, in forward
    output = self.model(sequences, attention_mask=foward_attention_mask, position_ids=position_ids)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "deepspeed/utils/nvtx.py", line 20, in wrapped_fn
    ret_val = func(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^
  File "deepspeed/runtime/engine.py", line 2054, in forward
    loss = self.module(*inputs, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/nn/modules/module.py", line 1845, in _call_impl
    return inner()
           ^^^^^^^
  File "torch/nn/modules/module.py", line 1793, in inner
    result = forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "transformers/utils/generic.py", line 965, in wrapper
    output = func(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "transformers/utils/deprecation.py", line 172, in wrapped_func
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "transformers/models/qwen2/modeling_qwen2.py", line 824, in forward
    outputs: BaseModelOutputWithPast = self.model(
                                       ^^^^^^^^^^^
  File "torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "transformers/utils/generic.py", line 965, in wrapper
    output = func(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "transformers/models/qwen2/modeling_qwen2.py", line 550, in forward
    layer_outputs = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "transformers/models/qwen2/modeling_qwen2.py", line 263, in forward
    hidden_states, self_attn_weights = self.self_attn(
                                       ^^^^^^^^^^^^^^^
  File "torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "transformers/models/qwen2/modeling_qwen2.py", line 166, in forward
    key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: shape '[1, 1862, -1, 128]' is invalid for input of size 59584
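The numbers in the error can be checked by hand. Below is a minimal sketch, assuming Qwen2.5-3B's config values (hidden_size=2048, 16 attention heads, 2 KV heads, so head_dim=128) and assuming AutoTP splits k_proj's output features evenly across ranks, which the observed 59584 = 1862 * 32 is consistent with:

```python
# Arithmetic behind the RuntimeError above.
# Assumed Qwen2.5-3B config values: hidden_size=2048, 16 attention heads, 2 KV heads.
hidden_size = 2048
num_attention_heads = 16
num_key_value_heads = 2
head_dim = hidden_size // num_attention_heads  # 128, the last dim in the failing view()

batch, seq_len = 1, 1862  # taken from shape '[1, 1862, -1, 128]'


def k_proj_features_per_rank(tp_size: int) -> int:
    """Output features of k_proj on one rank, assuming an even split of out_features."""
    return num_key_value_heads * head_dim // tp_size


for tp_size in (1, 2, 4, 8):
    per_rank = k_proj_features_per_rank(tp_size)
    numel = batch * seq_len * per_rank
    # view(..., -1, head_dim) only works if each rank holds whole 128-wide heads.
    whole_heads = per_rank % head_dim == 0
    print(f"tp={tp_size}: {per_rank} features/rank, numel={numel}, whole heads: {whole_heads}")

# tp=8 gives 32 features/rank and numel=59584, the size reported in the error.
```

Only TP sizes 1 and 2 leave each rank with at least one whole 128-wide head, which matches the title's "tp size > 2".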

To Reproduce

I'm using OpenRLHF with AutoTP, but I don't think this is specific to OpenRLHF's code; it should be reproducible with AutoTP in general.

Expected behavior

No error thrown.

ds_report output

--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-devel package with yum
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_lion ............... [NO] ....... [OKAY]
dc ..................... [NO] ....... [OKAY]
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
 [WARNING]  FP Quantizer is using an untested triton version (3.2.0), only 2.3.(0, 1) and 3.0.0 are known to be compatible with these kernels
fp_quantizer ........... [NO] ....... [NO]
fused_lamb ............. [NO] ....... [OKAY]
fused_lion ............. [NO] ....... [OKAY]
 [WARNING]  gds requires the dev libaio .so object and headers but these were not found.
 [WARNING]  gds: please install the libaio-devel package with yum
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
gds .................... [NO] ....... [NO]
transformer_inference .. [NO] ....... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.6
 [WARNING]  using untested triton version (3.2.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch version .................... 2.6.0+cu124
deepspeed info ................... 0.16.6, unknown, unknown
torch cuda version ............... 12.4
torch hip version ................ None
nvcc version ..................... 12.4
deepspeed wheel compiled w. ...... torch 2.6, cuda 12.4
shared memory (/dev/shm) size .... 251.75 GB

System info:

  • OS: RHEL 9.5
  • 1 machine with 4x A100s
  • transformers==4.51.1
  • flash_attn==2.7.4.post1
  • Python version: 3.12.9
  • OpenRLHF version: 0.7.4.post2

HollowMan6 · Apr 28 '25 16:04

In Qwen2.5-3B, the number of KV heads is 2. If you want to set tp_size > 2, you would need to replicate the KV heads, and for inference the KV cache would be duplicated as well.
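The replication described here can be sketched as a mapping from TP rank to the KV head(s) that rank holds. This is a hypothetical illustration (the helper names are made up and are not DeepSpeed's AutoTP code), assuming query heads are sharded contiguously and that query head h attends to KV head h // (num_q_heads // num_kv_heads), as in grouped-query attention:

```python
def kv_heads_for_rank(rank: int, tp_size: int, num_kv_heads: int) -> list[int]:
    """Indices of the KV heads a TP rank holds (hypothetical helper, not a DeepSpeed API)."""
    if tp_size <= num_kv_heads:
        # Plain sharding: each rank owns a disjoint, equal-sized slice of KV heads.
        if num_kv_heads % tp_size != 0:
            raise ValueError("num_kv_heads must be divisible by tp_size")
        per_rank = num_kv_heads // tp_size
        return list(range(rank * per_rank, (rank + 1) * per_rank))
    # Replication: more ranks than KV heads, so consecutive ranks share one head.
    if tp_size % num_kv_heads != 0:
        raise ValueError("tp_size must be a multiple of num_kv_heads to replicate")
    replicas = tp_size // num_kv_heads  # number of ranks holding a copy of each KV head
    return [rank // replicas]


def query_heads_for_rank(rank: int, tp_size: int, num_q_heads: int) -> list[int]:
    """Query heads are still sharded normally (assumes num_q_heads % tp_size == 0)."""
    per_rank = num_q_heads // tp_size
    return list(range(rank * per_rank, (rank + 1) * per_rank))


# Qwen2.5-3B-like setup: 16 query heads, 2 KV heads (GQA group size 8), TP=8.
num_q_heads, num_kv_heads, tp_size = 16, 2, 8
group = num_q_heads // num_kv_heads
for rank in range(tp_size):
    q = query_heads_for_rank(rank, tp_size, num_q_heads)
    kv = kv_heads_for_rank(rank, tp_size, num_kv_heads)
    # Every local query head must find the KV head its GQA group attends to.
    assert all(h // group in kv for h in q)
    print(f"rank {rank}: q heads {q} -> kv head {kv}")
```

With TP=8, ranks 0 to 3 hold copies of KV head 0 and ranks 4 to 7 hold copies of KV head 1, so each KV head's weights (and, for inference, its KV cache) are stored four times.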

Although a similar replication approach has been implemented for ChatGLM, it seems that Qwen2.5 does not support it yet. For now, I suggest keeping the number of KV heads divisible by tp_size.
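Following that suggestion, a pre-launch check can reject TP sizes that would split a KV head. A minimal sketch; both helpers are hypothetical and not part of DeepSpeed or OpenRLHF:

```python
def valid_tp_sizes(num_kv_heads: int, max_tp: int) -> list[int]:
    """TP sizes in 1..max_tp that keep whole KV heads on every rank."""
    return [tp for tp in range(1, max_tp + 1) if num_kv_heads % tp == 0]


def check_autotp_size(num_kv_heads: int, tp_size: int) -> None:
    """Fail fast before launch if tp_size does not divide the KV head count."""
    if num_kv_heads % tp_size != 0:
        raise ValueError(
            f"tp_size={tp_size} does not divide num_key_value_heads={num_kv_heads}; "
            f"valid choices up to {tp_size}: {valid_tp_sizes(num_kv_heads, tp_size)}"
        )


print(valid_tp_sizes(2, 8))  # Qwen2.5-3B has 2 KV heads, so only TP=1 or TP=2
check_autotp_size(2, 2)      # passes; check_autotp_size(2, 8) would raise ValueError
```

In practice the head count could be read with transformers, e.g. `AutoConfig.from_pretrained("Qwen/Qwen2.5-3B-Instruct").num_key_value_heads`; on 8 GPUs this would suggest a layout such as TP=2, DP=4.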

inkcherry · Apr 29 '25 02:04

Closing this now, since the cause has been explained.

HollowMan6 · Oct 22 '25 22:10