xla
`torchax` fails on a simple matrix slicing example.
🐛 Bug
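torchax fails on a simple matrix indexing example: `M[torch.arange(M.shape[0]), p]` raises an `AssertionError` when the module is run through `tx.extract_jax` and `jax.jit`.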
To Reproduce
Here is the code to repro:
```python
import torch
import torchax as tx
import torchax.export
import jax
import jax.numpy as jnp
import sys

tx.enable_globally()


def f(M, p):
    # Pick element p[i] from row i of M (integer-array indexing).
    return M[torch.arange(M.shape[0]), p]


class Wrapper(torch.nn.Module):
    def forward(self, M, p):
        return f(M, p)


def main():
    # Reference: eager PyTorch call.
    torch_outputs = Wrapper()(torch.arange(4).reshape([2, 2]), torch.tensor([1, 0]))
    print(f"{torch_outputs=}")

    # Same computation via torchax -> JAX under jax.jit.
    M = jnp.arange(4).reshape([2, 2])
    p = jnp.array([1, 0])
    sample_input = (M, p)
    weights, jfunc = tx.extract_jax(Wrapper())

    def jfunc_inlined(args):
        return jfunc(weights, args)

    jitted = jax.jit(jfunc_inlined)
    jax_outputs = jitted(sample_input)
    print(f"{jax_outputs=}")


if __name__ == "__main__":
    main()
```
If you run it, you'll get:
```
AssertionError: Expect a Tensor or a View but got <class 'torch.Tensor'>; usually this means there is a mixed math between XLATensor and torch.Tensor
```
Expected behavior
`jax_outputs` should be computed without errors and match `torch_outputs`, here `[1, 2]` (that is, `M[0, 1]` and `M[1, 0]`).
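For reference, the expected value can be reproduced with NumPy, which follows the same integer-array indexing rules: pairing `arange(n)` with `p` selects element `p[i]` from row `i`. This is only a sketch for checking the expected output; `row_gather` is a hypothetical helper name, not part of torchax.

```python
import numpy as np


def row_gather(M, p):
    """Return M[i, p[i]] for every row i (integer-array / advanced indexing)."""
    rows = np.arange(M.shape[0])  # row indices 0..n-1
    return M[rows, p]             # pairs rows[i] with p[i]


M = np.arange(4).reshape(2, 2)    # [[0, 1], [2, 3]]
p = np.array([1, 0])
expected = row_gather(M, p)
print(expected)  # [1 2]
```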
Environment
```
einops==0.8.1
filelock==3.19.1
fsspec==2025.9.0
jax==0.7.1
jaxlib==0.7.1
Jinja2==3.1.6
MarkupSafe==3.0.2
ml_dtypes==0.5.3
mpmath==1.3.0
networkx==3.5
numpy==2.3.3
opt_einsum==3.4.0
scipy==1.16.2
setuptools==80.9.0
sympy==1.14.0
torch==2.8.0
torchax==0.0.7
typing_extensions==4.15.0
```
Additional context
No additional context; should be pretty clear.
Using `torch.arange(M.shape[0]).to(M.device)` instead fixes it, but I think this shouldn't be necessary. Hopefully you can patch torchax accordingly. Thanks!
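A minimal sketch of the workaround above, written as a drop-in replacement for `f` in the repro; the name `f_workaround` is hypothetical. It is only checked here in eager PyTorch without torchax enabled; that it avoids the `AssertionError` under `tx.extract_jax` and `jax.jit` is taken from the report, not re-verified.

```python
import torch


def f_workaround(M, p):
    # Reported workaround: move the row-index tensor created by torch.arange
    # onto M's device before indexing, so it is not left as a plain CPU tensor.
    rows = torch.arange(M.shape[0]).to(M.device)
    return M[rows, p]


# Plain eager PyTorch check (torchax is not enabled here).
out = f_workaround(torch.arange(4).reshape([2, 2]), torch.tensor([1, 0]))
print(out)  # tensor([1, 2])
```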
@qihqi Is this expected behavior?
Yes, the expected behavior is valid. This is a bug that we should fix.