
`torchax` fails on a simple matrix slicing example.

Status: Open · joaospinto opened this issue 3 months ago · 3 comments

🐛 Bug

torchax fails on a simple matrix slicing example.

To Reproduce

Here is the code to repro:

import torch
import torchax as tx
import torchax.export

import jax
import jax.numpy as jnp

import sys


tx.enable_globally()


def f(M, p):
    return M[torch.arange(M.shape[0]), p]


class Wrapper(torch.nn.Module):
    def forward(self, M, p):
        return f(M, p)


def main():
    torch_outputs = Wrapper()(torch.arange(4).reshape([2, 2]), torch.tensor([1, 0]))

    print(f"{torch_outputs=}")

    M = jnp.arange(4).reshape([2, 2])
    p = jnp.array([1, 0])
    sample_input = (M, p)

    weights, jfunc = tx.extract_jax(Wrapper())

    def jfunc_inlined(args):
        return jfunc(weights, args)

    jitted = jax.jit(jfunc_inlined)

    jax_outputs = jitted(sample_input)

    print(f"{jax_outputs=}")


if __name__ == "__main__":
    main()

If you run it, you'll get:

AssertionError: Expect a Tensor or a View but got <class 'torch.Tensor'>; usually this means there is a mixed math between XLATensor and torch.Tensor

Expected behavior

jax_outputs should be computed without errors and match the torch_outputs value.
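For reference, the indexing in `f` pairs each row index `i` from `torch.arange(M.shape[0])` with the column index `p[i]`. For the repro inputs, the expected result is therefore `[M[0, 1], M[1, 0]]`, i.e. `[1, 2]`. The following is a minimal pure-Python sketch of those semantics; `gather_rows` is a hypothetical helper name used only for illustration and is not part of torch or torchax.

```python
# Hypothetical pure-Python reference for f(M, p) = M[arange(n), p]:
# advanced indexing pairs row index i with column index p[i].
def gather_rows(M, p):
    """Return [M[i][p[i]] for each row i], mirroring M[torch.arange(len(M)), p]."""
    if len(M) != len(p):
        raise ValueError("p must have one column index per row of M")
    return [row[j] for row, j in zip(M, p)]


# Same inputs as the repro: M = arange(4).reshape(2, 2), p = [1, 0]
M = [[0, 1], [2, 3]]
p = [1, 0]
print(gather_rows(M, p))  # [1, 2]
```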

Environment

einops==0.8.1
filelock==3.19.1
fsspec==2025.9.0
jax==0.7.1
jaxlib==0.7.1
Jinja2==3.1.6
MarkupSafe==3.0.2
ml_dtypes==0.5.3
mpmath==1.3.0
networkx==3.5
numpy==2.3.3
opt_einsum==3.4.0
scipy==1.16.2
setuptools==80.9.0
sympy==1.14.0
torch==2.8.0
torchax==0.0.7
typing_extensions==4.15.0

Additional context

No additional context; should be pretty clear.

(joaospinto, Sep 14 '25 17:09)

Using `torch.arange(M.shape[0]).to(M.device)` instead fixes it, but I don't think this should be necessary. Hopefully you can patch torchax accordingly. Thanks!
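A sketch of the workaround, runnable in plain eager PyTorch. `f_workaround` applies the `.to(M.device)` fix described above. `f_gather` is a hypothetical alternative that uses `torch.gather`, so no separate row-index tensor is created. Neither variant has been checked under torchax here, and whether `f_gather` avoids the assertion is an assumption.

```python
import torch


def f_workaround(M, p):
    # Workaround from the comment above: build the row indices, then move
    # them onto M's device so both index tensors live on the same device.
    rows = torch.arange(M.shape[0]).to(M.device)
    return M[rows, p]


def f_gather(M, p):
    # Alternative (assumption, not verified under torchax): avoid creating a
    # fresh index tensor by gathering along dim 1 with p as column indices.
    return torch.gather(M, 1, p.unsqueeze(1)).squeeze(1)


M = torch.arange(4).reshape([2, 2])
p = torch.tensor([1, 0])
print(f_workaround(M, p))  # tensor([1, 2])
print(f_gather(M, p))      # tensor([1, 2])
```

Both functions compute `M[i, p[i]]` for every row `i`. The gather form only touches tensors that are already on `M`'s device, which is why it may sidestep the mixed-tensor check.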

(joaospinto, Sep 16 '25 01:09)

@qihqi Is this expected behavior?

(ysiraichi, Sep 16 '25 13:09)

Yes, this use case is valid. This is a bug that we should fix.

(qihqi, Sep 22 '25 23:09)