Fix None process group appended in _fuse_input_dist_splits
Summary: We identified this potential bug while debugging an issue reported in https://fb.workplace.com/groups/755371733754414/permalink/833999072850393/
Fixed a bug in _fuse_input_dist_splits where names with no valid process
group (pg=None) were being added to names_per_pg[None]. This would cause
issues downstream when trying to create FusedKJTListSplitsAwaitable with
a None process group.
The issue occurred when:
- A request is of type KJTListSplitsAwaitable
- None of its awaitables are of type KJTSplitsAllToAllMeta
- This leaves pg = None (line 207)
- The name was still appended to names_per_pg[None] (line 213)
The fix adds a check to only append names when pg is not None, ensuring
that only requests with valid process groups are included in the fused
operations.
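The grouping behavior after the fix can be sketched roughly as follows. This is a minimal, self-contained illustration, not the actual torchrec code. AllToAllMeta and ListSplitsRequest are hypothetical stand-ins for KJTSplitsAllToAllMeta and KJTListSplitsAwaitable. Process groups are modeled as plain strings rather than torch.distributed ProcessGroup objects. The sketch also assumes that the first all-to-all meta found in a request determines its process group.

```python
from collections import defaultdict
from typing import DefaultDict, Dict, List, Optional, Sequence


class AllToAllMeta:
    """Hypothetical stand-in for KJTSplitsAllToAllMeta: carries a process group."""

    def __init__(self, pg: str) -> None:
        self.pg = pg


class ListSplitsRequest:
    """Hypothetical stand-in for KJTListSplitsAwaitable: holds child awaitables."""

    def __init__(self, awaitables: Sequence[object]) -> None:
        self.awaitables = list(awaitables)


def group_names_by_pg(
    names: Sequence[str], requests: Sequence[ListSplitsRequest]
) -> Dict[str, List[str]]:
    """Group request names by process group, skipping requests without one."""
    names_per_pg: DefaultDict[str, List[str]] = defaultdict(list)
    for name, request in zip(names, requests):
        pg: Optional[str] = None
        # Assumption: the first all-to-all meta in the request supplies its pg.
        for awaitable in request.awaitables:
            if isinstance(awaitable, AllToAllMeta):
                pg = awaitable.pg
                break
        # The fix: only record names that have a valid process group, so no
        # entry is ever created under names_per_pg[None].
        if pg is not None:
            names_per_pg[pg].append(name)
    return dict(names_per_pg)


if __name__ == "__main__":
    requests = [
        ListSplitsRequest([AllToAllMeta("pg0")]),
        ListSplitsRequest([object()]),  # no all-to-all meta, so pg stays None
        ListSplitsRequest([AllToAllMeta("pg0")]),
    ]
    print(group_names_by_pg(["a", "b", "c"], requests))
```

Here, request "b" has no all-to-all meta and is left out of the grouping. Without the `if pg is not None` check, it would instead be filed under a None key. That None-keyed group is what would later be passed on as pg=None.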
Why this matters:
- Prevents passing pg=None to FusedKJTListSplitsAwaitable (line 232)
- Ensures only valid distributed operations are fused together
- Avoids potential runtime errors or undefined behavior
Differential Revision: D87110878
@YongzhongYang has exported this pull request. If you are a Meta employee, you can view the originating Diff in D87110878.