Revert patch for index underflow after location 127 when `mode=JAX`
Describe the issue:
It seems the JAX linker downcasts index constants to `uint8`? `mode=None` and `mode="NUMBA"` work as expected. Declaring an index variable (`i = pt.lscalar('i'); z = x[i]`) also works as expected.
Reproducible code example:
```python
import pytensor
import pytensor.tensor as pt
import numpy as np

x = pt.dvector('x')
z1 = x[127]
z2 = x[128]
f = pytensor.function([x], [z1, z2], mode='JAX')
# out: [Array(127., dtype=float64), Array(0., dtype=float64)]
```
Error message:
No response
PyTensor version information:
PyTensor 2.13.1
Context for the issue:
No response
Now that I didn't expect xD
Seems to be a JAX bug?
```python
import jax
import jax.numpy as jnp
import numpy as np

def subtensor(x):
    return x[np.array(128, dtype="uint8")]

subtensor(jnp.arange(200))  # Array(199, dtype=int32)

def subtensor(x):
    return x[np.array(128, dtype="uint16")]

subtensor(jnp.arange(200))  # Array(128, dtype=int32)
```
Opened an issue:
On our side, we can exclude the rewrite `local_uint_constant_indices` in JAX mode here:
```python
JAX = Mode(
    include=["fast_run", "jax"],
    exclude=["cxx_only", "BlasOpt", "fusion", "inplace", "local_uint_constant_indices"],
```
I think I was wrong, though: it's casting to `int8`, not `uint8` (`int8` only goes up to 127, whereas `uint8` goes up to 255). Wouldn't dropping the minus sign break negative indexes?
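The difference between the two dtypes is easy to see with plain NumPy: casting the constant 128 to `uint8` keeps its value, while casting it to `int8` wraps it around to -128, which then indexes from the end of the array. This only illustrates the integer arithmetic; it is not meant to show what the JAX linker does internally.

```python
import numpy as np

idx = np.array(128, dtype=np.int64)

# uint8 covers 0..255, so 128 survives the cast unchanged.
as_uint8 = idx.astype(np.uint8)
# int8 covers -128..127, so 128 silently wraps around to -128.
as_int8 = idx.astype(np.int8)

print(np.iinfo(np.uint8).max, np.iinfo(np.int8).max)  # 255 127
print(int(as_uint8), int(as_int8))  # 128 -128

x = np.arange(200.0)
# A negative index counts from the end: x[-128] is x[200 - 128], i.e. x[72].
print(x[as_uint8], x[as_int8])  # 128.0 72.0
```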
I stepped through it with the debugger, and it kept the same type as the one reported in the `dprint` of the compiled function. I didn't see any implicit casting, only the explicit one done by the rewrite mentioned above.
```python
import pytensor
import pytensor.tensor as pt

x = pt.dvector('x')
z = x[128]
pytensor.dprint(z, print_type=True)
f = pytensor.function([x], z, mode='JAX')
pytensor.dprint(f, print_type=True)
```

```
Subtensor{i} [id A] <Scalar(float64, shape=())>
 ├─ x [id B] <Vector(float64, shape=(?,))>
 └─ 128 [id C] <int64>
rewriting: rewrite local_uint_constant_indices replaces Subtensor{i}.0 of Subtensor{i}(x, 128) with Subtensor{i}.0 of Subtensor{i}(x, 128)
DeepCopyOp [id A] <Scalar(float64, shape=())> 1
 └─ Subtensor{i} [id B] <Scalar(float64, shape=())> 0
    ├─ x [id C] <Vector(float64, shape=(?,))>
    └─ 128 [id D] <uint8>
```
So I guess I'm asking: why do we rewrite to `uint8` at all? Isn't that needlessly restrictive?
The rewrite picks the smallest type that can still hold the index, because on most systems that is faster. It was introduced in
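The idea of "the smallest type that can still hold the index" can be sketched as follows. This is a minimal illustration, assuming the rewrite only touches non-negative constant indices and uses something like NumPy's `np.min_scalar_type` to find the narrowest dtype; the helper `narrow_index_dtype` is hypothetical and is not PyTensor's actual implementation. Leaving negative constants alone is what would keep negative indexing intact.

```python
import numpy as np


def narrow_index_dtype(index: int) -> np.dtype:
    """Hypothetical helper: pick the narrowest dtype for a constant index.

    Negative indices are left as int64, so narrowing to an unsigned
    dtype never drops their sign.
    """
    if index < 0:
        return np.dtype(np.int64)
    # For a non-negative scalar, np.min_scalar_type returns the narrowest unsigned dtype.
    return np.min_scalar_type(index)


for i in (127, 128, 255, 256, -1):
    print(i, narrow_index_dtype(i))
# 127 uint8
# 128 uint8
# 255 uint8
# 256 uint16
# -1 int64
```

Under this sketch, 128 becomes `uint8`, the same dtype shown in the compiled graph; a backend that misreads `uint8` values of 128 or more as signed would then index from the end of the array.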
This was fixed in JAX, but it may be worth waiting a while longer before reverting the patch.