Revert patch for index underflow after location 127 when `mode=JAX`

Open jessegrabowski opened this issue 1 year ago • 9 comments

Describe the issue:

It seems the JAX linker downcasts index constants to uint8? mode=None and mode="NUMBA" work as expected. Declaring an index variable (i = pt.lscalar('i'); z = x[i]) also works as expected.

Reproducible code example:

import pytensor
import pytensor.tensor as pt
import numpy as np

x = pt.dvector('x')
z1 = x[127]
z2 = x[128]
f = pytensor.function([x], [z1, z2], mode='JAX')

f(np.arange(200))
# out: [Array(127., dtype=float64), Array(0., dtype=float64)]

Error message:

No response

PyTensor version information:

Pytensor 2.13.1

Context for the issue:

No response

jessegrabowski avatar Jul 24 '23 11:07 jessegrabowski

Now that I didn't expect xD

ricardoV94 avatar Jul 24 '23 14:07 ricardoV94

Seems to be a JAX bug?

import jax
import jax.numpy as jnp
import numpy as np

def subtensor(x):
    return x[np.array(128, dtype="uint8")]

subtensor(jnp.arange(200))  # Array(199, dtype=int32)

def subtensor(x):
    return x[np.array(128, dtype="uint16")]

subtensor(jnp.arange(200))  # Array(128, dtype=int32)

Opened an issue: https://github.com/google/jax/issues/16836

ricardoV94 avatar Jul 25 '23 13:07 ricardoV94

On our side we can exclude the rewrite local_uint_constant_indices on JAX mode here: https://github.com/pymc-devs/pytensor/blob/6b189ee3cc3df9a1dda0258da4642c4178f2845b/pytensor/compile/mode.py#L450-L456

JAX = Mode(
    JAXLinker(),
    RewriteDatabaseQuery(
        include=["fast_run", "jax"],
        exclude=["cxx_only", "BlasOpt", "fusion", "inplace", "local_uint_constant_indices"],
    ),
)

ricardoV94 avatar Jul 25 '23 13:07 ricardoV94

I think I was wrong though, it's casting to int8, not uint8 (int8 goes up to 127; uint8 should go up to 255). Wouldn't dropping the minus sign break negative indexes?
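For reference, a minimal sketch of the 8-bit ranges being discussed. It only illustrates where signed and unsigned 8-bit integers stop and what happens if the bits of 128 are reinterpreted as signed; it is not a diagnosis of what JAX does internally. `as_int8` is a hypothetical helper written for this example.

```python
def as_int8(value):
    """Reinterpret the low 8 bits of a non-negative integer as a signed 8-bit value.

    Hypothetical helper for illustration only.
    """
    raw = (value & 0xFF).to_bytes(1, "little", signed=False)
    return int.from_bytes(raw, "little", signed=True)


INT8_MAX = 2**7 - 1   # largest value a signed 8-bit integer can hold
UINT8_MAX = 2**8 - 1  # largest value an unsigned 8-bit integer can hold

print("int8 max:", INT8_MAX)
print("uint8 max:", UINT8_MAX)

# 127 survives the reinterpretation; 128 and above become negative.
for index in (127, 128, 255):
    print(index, "->", as_int8(index))
```

So 128 is the first index that fits in uint8 but not in int8, which matches the position where the results in this issue start to go wrong.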

jessegrabowski avatar Jul 25 '23 13:07 jessegrabowski

I stepped through it in the debugger and the index kept the same type as the one reported in the dprint of the compiled function. I didn't see any implicit casting, only the explicit cast done by the rewrite mentioned above.

ricardoV94 avatar Jul 25 '23 14:07 ricardoV94

import pytensor
import pytensor.tensor as pt

x = pt.dvector('x')
z = x[128]

pytensor.dprint(z, print_type=True)

pytensor.config.optimizer_verbose=True
f = pytensor.function([x], z, mode='JAX')

pytensor.dprint(f, print_type=True)

Subtensor{i} [id A] <Scalar(float64, shape=())>
 ├─ x [id B] <Vector(float64, shape=(?,))>
 └─ 128 [id C] <int64>

rewriting: rewrite local_uint_constant_indices replaces Subtensor{i}.0 of Subtensor{i}(x, 128) with Subtensor{i}.0 of Subtensor{i}(x, 128)

DeepCopyOp [id A] <Scalar(float64, shape=())> 1
 └─ Subtensor{i} [id B] <Scalar(float64, shape=())> 0
    ├─ x [id C] <Vector(float64, shape=(?,))>
    └─ 128 [id D] <uint8>

ricardoV94 avatar Jul 25 '23 14:07 ricardoV94

I guess I'm asking why we rewrite to uint8 then? Isn't it needlessly restrictive?

jessegrabowski avatar Jul 25 '23 14:07 jessegrabowski

The rewrite picks the smallest integer type that can still hold the constant index, because on most systems that is faster. It was introduced in https://github.com/aesara-devs/aesara/pull/1150
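A rough sketch of what "smallest type that still fits" means for a non-negative constant index. This is not the rewrite's actual code: `smallest_uint_dtype` is a hypothetical helper, and restricting it to unsigned types is an assumption based on the rewrite's name.

```python
def smallest_uint_dtype(index):
    """Return the name of the narrowest unsigned dtype that can hold `index`.

    Hypothetical helper for illustration; not PyTensor's implementation.
    """
    if index < 0:
        raise ValueError("negative indices need a signed dtype")
    for bits in (8, 16, 32, 64):
        if index < 2**bits:
            return f"uint{bits}"
    raise OverflowError("index does not fit in 64 bits")


# The constant 128 from the example above lands in uint8,
# which is the type shown in the compiled graph.
for index in (127, 128, 255, 256, 70_000):
    print(index, smallest_uint_dtype(index))
```

NumPy exposes a similar lookup as `np.min_scalar_type`, which also returns `uint8` for 128.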

ricardoV94 avatar Jul 25 '23 14:07 ricardoV94

This was fixed in JAX, but it may be worth waiting a while longer before reverting the patch.

ricardoV94 avatar Aug 24 '23 11:08 ricardoV94