Qwen parallelizer with sequence parallelism
Describe the bug
With sequence_parallel = True, the embed_tokens output is sharded on the sequence dimension (see here). If position_ids is not provided as an explicit argument to forward, position_ids is created from cache_position (https://github.com/huggingface/transformers/blob/e20df45bf676d80bdddb9757eeeafe6c0c81ecfa/src/transformers/models/qwen3/modeling_qwen3.py#L380-L381); however, cache_position's shape is based on the local (sharded) embed_tokens shape, not the full sequence length.
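For illustration, here is a minimal sketch of the failure mode. It paraphrases (rather than quotes) the position_ids derivation at the linked lines; the helper name and shapes below are illustrative only:

```python
import torch

# Simplified paraphrase of how the HF model derives position_ids when none are
# passed: cache_position is built from inputs_embeds.shape[1], then unsqueezed.
def derive_position_ids(inputs_embeds: torch.Tensor, past_seen_tokens: int = 0) -> torch.Tensor:
    cache_position = torch.arange(
        past_seen_tokens,
        past_seen_tokens + inputs_embeds.shape[1],  # local (possibly sharded) sequence length
        device=inputs_embeds.device,
    )
    return cache_position.unsqueeze(0)

full_seq_len, sp_degree, hidden = 16, 2, 8
embeds = torch.randn(1, full_seq_len, hidden)
# Under sequence parallelism, each rank sees only a sequence shard of embed_tokens' output:
local_embeds = torch.chunk(embeds, sp_degree, dim=1)[0]

print(derive_position_ids(embeds).shape)        # torch.Size([1, 16]) -- full sequence
print(derive_position_ids(local_embeds).shape)  # torch.Size([1, 8])  -- what an SP rank gets
```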
Steps/Code to reproduce bug
Please list minimal steps or code snippet for us to be able to reproduce the bug.
A helpful guide on how to craft a minimal bug report: http://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports.
Expected behavior
The forward pass should succeed with sequence_parallel = True even when position_ids is not passed explicitly; position_ids should be derived from the full (unsharded) sequence length rather than from the local embed_tokens shard.
Additional context
In NeMo RL we are seeing that updating from torch 2.7.1 to torch 2.8 causes a failure like this:
File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
return inner()
^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1827, in inner
result = forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/linear.py", line 125, in forward
return F.linear(input, self.weight, self.bias)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/_compile.py", line 53, in inner
return disable_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 929, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/distributed/tensor/_api.py", line 350, in __torch_dispatch__
return DTensor._op_dispatcher.dispatch(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/distributed/tensor/_dispatch.py", line 154, in dispatch
self.sharding_propagator.propagate(op_info)
File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/distributed/tensor/_sharding_prop.py", line 266, in propagate
OutputSharding, self.propagate_op_sharding(op_info.schema)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/distributed/tensor/_sharding_prop.py", line 45, in __call__
return self.cache(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/distributed/tensor/_sharding_prop.py", line 286, in propagate_op_sharding_non_cached
op_strategy = self.op_strategy_funcs[op_schema.op](strategy_schema)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_view_ops.py", line 646, in reshape_strategy
input_tgt_placements, output_placements = propagate_shape_and_sharding(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_view_ops.py", line 596, in propagate_shape_and_sharding
in_dim = get_in_dim_to_shard(cmd)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_view_ops.py", line 532, in get_in_dim_to_shard
raise RuntimeError(
RuntimeError: ('Attempted to flatten sharded dimension 1, ', 'but only the leftmost dim of a Flatten can be sharded.')
Whenever SP is enabled, we see this with both Llama and Qwen.
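For reference, a rough sketch of the reshape pattern that the torch 2.8 sharding propagator rejects here: flattening the (batch, seq) dims of an activation whose sequence dim is sharded, which the F.linear call in the traceback appears to reach through matmul's view decomposition. The 2-rank CPU mesh and shapes below are assumptions, not a verified repro:

```python
# torchrun --nproc_per_node=2 flatten_sharded_dim.py
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

mesh = init_device_mesh("cpu", (2,))
hidden_states = torch.randn(4, 8, 16)                    # [batch, seq, hidden]
dt = distribute_tensor(hidden_states, mesh, [Shard(1)])  # sequence dim sharded, as under SP

# Flattening (batch, seq) while dim 1 is the sharded one is exactly what the
# "Attempted to flatten sharded dimension 1" error above complains about.
flat = dt.reshape(4 * 8, 16)
```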
https://github.com/NVIDIA-NeMo/RL/pull/1557 should fix the issue simply, using the commit in https://github.com/NVIDIA-NeMo/Automodel/pull/804/commits/282aca0927a5f017acf9c7e577075efce6145b26.
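I have not gone through the linked commit, but the workaround implied by the description above is to build position_ids over the full (pre-shard) sequence length and pass it explicitly, so the model never falls back to cache_position. A minimal sketch with a hypothetical helper:

```python
import torch

def full_position_ids(input_ids: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: position ids for the FULL sequence, built from the
    unsharded input_ids before any sequence-parallel sharding of activations."""
    bsz, seq_len = input_ids.shape
    return torch.arange(seq_len, device=input_ids.device).unsqueeze(0).expand(bsz, -1)

# Passing position_ids explicitly keeps the model from deriving it from
# cache_position, whose length reflects only the local embed_tokens shard:
# outputs = model(input_ids=input_ids, position_ids=full_position_ids(input_ids), ...)
```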
The Llama model seems to have a different error message:
File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 237, in forward
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 135, in apply_rotary_pos_emb
q_embed = (q * cos) + (rotate_half(q) * sin)
~~^~~~~
RuntimeError: The size of tensor a (380) must match the size of tensor b (379) at non-singleton dimension 2
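The 380 vs. 379 mismatch is consistent with an odd total sequence length being split unevenly across the SP group, so the local query/key length on one rank disagrees with the length of the rotary cos/sin built from position_ids. A small illustration; the 759/2 numbers are assumptions chosen to match the traceback:

```python
import torch

total_seq_len, sp_degree = 759, 2        # assumed values that reproduce 380 vs. 379
hidden = torch.zeros(1, total_seq_len, 8)

# An odd-length sequence cannot be split evenly across the SP group:
local_lens = [c.shape[1] for c in torch.chunk(hidden, sp_degree, dim=1)]
print(local_lens)                        # [380, 379]

# If q/k on a rank have local length 380 while cos/sin are built from
# position_ids of length 379 (or vice versa), q * cos fails to broadcast on a
# non-singleton dim, raising exactly the error above.
```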
- Sharding annotation mismatch issue: RuntimeError: ('Attempted to flatten sharded dimension 1, ', 'but only the leftmost dim of a Flatten can be sharded.')
  - https://github.com/NVIDIA-NeMo/RL/pull/1557 should address the sharding annotation issue.
- Dynamic sequence lengths in batches, with a potentially uneven sequence-length split: RuntimeError: The size of tensor a (xxx) must match the size of tensor b (xxx) at non-singleton dimension 2
  - This one is more complicated and non-trivial; compare the seq_length_is_variable=True implementation in DeepSpeed:
    - https://huggingface.co/docs/accelerate/v1.12.0/concept_guides/sequence_parallelism#alstulysses-sp-backend-configuration
    - https://github.com/deepspeedai/DeepSpeed/blob/53e91a098d0a0666ac8cb8025a5b36e5af172d08/deepspeed/runtime/sequence_parallel/ulysses_sp.py#L241-L257