`all_to_all` operation generates invalid HLO
🐛 Bug
The all_to_all operation generates invalid HLO that fails verification with the error RET_CHECK failure hlo->operand_count() == split_count. The generated HLO all-to-all instruction is missing the required attributes (split_dimension, concat_dimension, split_count), and its operand count does not match the split count.
To Reproduce
Steps to reproduce the behavior:
- Create a simple tensor on XLA device
- Call xm.all_to_all() with split parameters
- Try to execute the tensor (e.g., .cpu())

Minimal reproduction code:
import os
os.environ["PJRT_DEVICE"] = "CPU"
import torch
import torch_xla.core.xla_model as xm
# Create tensor on XLA device
device = xm.xla_device()
value = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.int32, device=device)
# Call all_to_all - this generates invalid HLO
result = xm.all_to_all(
    value,
    split_dimension=0,
    concat_dimension=0,
    split_count=2)
# Force execution triggers the error
print(result.cpu())
Error message:
RET_CHECK failure (external/xla/xla/service/hlo_verifier.cc:566) hlo->operand_count() == split_count
Expected behavior
The all_to_all operation should generate valid HLO that passes verification and executes successfully. The HLO instruction should include proper split_dimension, concat_dimension, and split_count attributes that match the operand structure.
Environment
- Reproducible on XLA backend [CPU/TPU]: CPU/NEURON
- torch_xla version: 2.5.0+
Additional context
This affects both CPU and Neuron backends. The bug seems to be in the HLO generation layer, where TokenHandler::GetInput() modifies the input tensor, causing PyTorch/XLA to create multiple operands without setting the corresponding HLO attributes.
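One way to confirm the malformed instruction without forcing execution is to dump the pending HLO text. A minimal sketch reusing the repro above (the _get_xla_tensors_hlo helper is an internal torch_xla debugging API, so treat its use here as an assumption):

import os
os.environ["PJRT_DEVICE"] = "CPU"

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
value = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.int32, device=device)
result = xm.all_to_all(value, split_dimension=0, concat_dimension=0, split_count=2)

# Print the HLO for the pending computation instead of executing it; the
# all-to-all instruction and its split/concat attributes (or their absence)
# should be visible in this dump.
print(torch_xla._XLAC._get_xla_tensors_hlo([result]))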
Thank you for filing this issue.
- I wonder whether this was caused by an OpenXLA pin update. Were you able to identify when this started happening?
- Did you try with CPU_NUM_DEVICES greater than 1?
cc @bhavya01
I wonder whether this was caused by an OpenXLA pin update. Were you able to identify when this started happening?
Looks unlikely, as I was able to replicate it with pt_xla 2.6.1, and some users saw the issue with all_to_all in 2.5+ as well; I have not narrowed it down to a specific commit yet.
Did you try with CPU_NUM_DEVICES greater than 1?
Yes, it fails with CPU_NUM_DEVICES greater than 1 as well.
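For reference, a multi-device CPU run can be set up roughly as follows (a sketch, assuming CPU_NUM_DEVICES=2 and the standard xmp.spawn entry point; this is not the exact script used above):

import os
os.environ["PJRT_DEVICE"] = "CPU"
os.environ["CPU_NUM_DEVICES"] = "2"  # assumption: two CPU devices/processes

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    device = xm.xla_device()
    value = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.int32, device=device)
    result = xm.all_to_all(value, split_dimension=0, concat_dimension=0, split_count=2)
    # Transferring back to CPU triggers compilation and, with the bug present,
    # the same hlo_verifier RET_CHECK failure.
    print(index, result.cpu())

if __name__ == '__main__':
    xmp.spawn(_mp_fn)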