fix(hooks): Add padding support to context parallel hooks
What does this PR do?
This PR modifies ContextParallelSplitHook and ContextParallelGatherHook to gracefully handle sequence lengths that are not divisible by the world size.
The main changes are:
- Generic Padding: The ContextParallelSplitHook now pads any input tensor to a divisible length before sharding.
- State Management: It temporarily stores the original sequence length on the module instance itself.
- Generic Trimming: The ContextParallelGatherHook uses this stored length to trim the padding from the final output tensor before returning it.
This keeps the padding completely transparent to the model and the end user, preventing crashes without altering the output shape. The fix is contained entirely within the hooks and requires no changes to the Qwen transformer or any other model.
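For illustration, here is a minimal sketch of the pad-before-shard / trim-after-gather flow described above. The hook names mirror ContextParallelSplitHook and ContextParallelGatherHook, but the helper function and the `_cp_original_seq_len` attribute are illustrative assumptions, not the exact diffusers implementation:

```python
import torch
import torch.nn.functional as F


def _pad_to_divisible(hidden_states: torch.Tensor, world_size: int, dim: int = 1):
    """Pad `hidden_states` along `dim` so its length is divisible by `world_size`."""
    seq_len = hidden_states.shape[dim]
    pad_len = (world_size - seq_len % world_size) % world_size
    if pad_len == 0:
        return hidden_states, seq_len
    # F.pad expects (left, right) pairs starting from the last dimension.
    pad = [0, 0] * (hidden_states.ndim - dim - 1) + [0, pad_len]
    return F.pad(hidden_states, pad), seq_len


# In the split hook, before sharding across ranks:
def split_pre_forward(module, hidden_states, world_size, dim=1):
    hidden_states, original_len = _pad_to_divisible(hidden_states, world_size, dim)
    # Temporarily stash the original length on the module so the gather hook can trim.
    module._cp_original_seq_len = original_len
    return hidden_states  # then shard along `dim` as usual


# In the gather hook, after all-gathering the shards:
def gather_post_forward(module, gathered, dim=1):
    original_len = getattr(module, "_cp_original_seq_len", None)
    if original_len is not None and gathered.shape[dim] != original_len:
        gathered = gathered.narrow(dim, 0, original_len)
    return gathered
```

Stashing the original length on the module instance is what keeps the padding invisible to callers: the gather hook can always recover the exact original shape without any model-side changes.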
I have also added a new unit test in tests/hooks/test_hooks.py that directly exercises the padding and trimming logic.
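As a rough illustration, a test along these lines (the actual test name and structure in tests/hooks/test_hooks.py may differ) checks that padding followed by trimming round-trips the input:

```python
import torch
import torch.nn.functional as F


def test_pad_and_trim_restores_original_shape():
    world_size = 4
    batch, seq_len, dim = 2, 10, 8  # 10 is not divisible by 4
    hidden_states = torch.randn(batch, seq_len, dim)

    # Pad the sequence dimension to a multiple of world_size (split-hook behaviour).
    pad_len = (world_size - seq_len % world_size) % world_size
    padded = F.pad(hidden_states, (0, 0, 0, pad_len))
    assert padded.shape[1] % world_size == 0

    # Trim back to the stored original length (gather-hook behaviour).
    trimmed = padded.narrow(1, 0, seq_len)
    assert trimmed.shape == hidden_states.shape
    torch.testing.assert_close(trimmed, hidden_states)
```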
Fixes #12568
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Did you read the contributor guideline?
- [x] Did you read our philosophy doc (important for complex PRs)?
- [x] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- [x] Did you write any new necessary tests?
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. @sayakpaul @yiyixuxu @DN6
Thanks for the PR! However, we don't want any of this logic to go into the Qwen transformer. Would you be interested in supporting this case (not just Qwen) from the context parallel hooks? https://github.com/huggingface/diffusers/blob/main/src/diffusers/hooks/context_parallel.py#L204
Hi @yiyixuxu, yes, I'd be interested in supporting this.
Hi @yiyixuxu, just wanted to follow up. After looking at the hook implementation as you suggested, I've updated the PR with a new approach that is fully generic and keeps all the logic within the hooks, with no changes to the transformer.
The solution now pads in the ContextParallelSplitHook and trims in the ContextParallelGatherHook, using the module instance to temporarily store the original sequence length. I've also added a new unit test for this logic in test_hooks.py. Thanks, and let me know if any further changes are needed. I've updated the PR description with the full details.
CC @sayakpaul @DN6
Hello @yiyixuxu, I wanted to follow up on this in case you were busy. Thanks.
Hmm, I think we need to wait until https://github.com/huggingface/diffusers/pull/12702 is merged, because it is tackling padding too.
Ok, got it @sayakpaul, thanks for letting me know.