pytensor

Add error message to `NonZero` Op telling users it cannot be jitted in JAX

Open · jessegrabowski opened this issue 5 months ago · 7 comments

Description

JAX mode is missing a dispatch for the `NonZero` Op.

There's a `jnp.nonzero`, so it should be easy to add.

jessegrabowski · Jul 01 '25 03:07

You won't be able to jit it: the output has a dynamic shape.

ricardoV94 · Jul 01 '25 06:07
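To illustrate the dynamic-shape problem, here is a minimal sketch in plain JAX (no PyTensor involved; the array `x` and the helper `jitted_nonzero_fails` are hypothetical names). Eagerly, `jnp.nonzero` works because the input is concrete. Under `jax.jit` the input is an abstract tracer, so the number of non-zero entries, and therefore the output length, is unknown at trace time. Passing a static `size` fixes the output shape by padding with `fill_value`.

```python
import jax
import jax.numpy as jnp

x = jnp.array([0, 3, 0, 5])

# Eager (un-jitted): x is concrete, so the number of non-zeros is known.
eager_idx = jnp.nonzero(x)[0]

def jitted_nonzero_fails(v):
    # Under jit, v is an abstract tracer: the output length depends on the
    # values in v, so JAX cannot trace the call without a static size.
    try:
        jax.jit(jnp.nonzero)(v)
    except jax.errors.ConcretizationTypeError:
        return True
    return False

# A static `size` makes the output shape fixed; unused slots are padded
# with `fill_value`.
padded_idx = jax.jit(lambda v: jnp.nonzero(v, size=3, fill_value=-1))(x)[0]

print(eager_idx, jitted_nonzero_fails(x), padded_idx)
```

The `size` workaround changes the semantics (the caller must supply an upper bound on the number of non-zeros), so it is not a drop-in translation of a `NonZero` Op whose output length is data-dependent.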

You had the same question a few months ago about `where(single_variable)`, btw.

ricardoV94 · Jul 02 '25 06:07

https://github.com/pymc-devs/pytensor/issues/1062

ricardoV94 · Jul 02 '25 06:07
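The single-argument form of `where` hits the same wall, because `jnp.where(condition)` with no `x`/`y` returns the same index arrays as `jnp.nonzero(condition)`. A minimal check, assuming a recent JAX version (`x`, `same_result`, and `jitted_where_fails` are hypothetical names):

```python
import jax
import jax.numpy as jnp

x = jnp.array([0, 3, 0, 5])

# Single-argument where returns the same tuple of index arrays as nonzero.
where_idx = jnp.where(x)
nonzero_idx = jnp.nonzero(x)
same_result = len(where_idx) == len(nonzero_idx) and all(
    bool((a == b).all()) for a, b in zip(where_idx, nonzero_idx)
)

def jitted_where_fails(v):
    # Under jit, where(v) has the same data-dependent output length as
    # nonzero(v), so tracing fails without a static `size`.
    try:
        jax.jit(jnp.where)(v)
    except jax.errors.ConcretizationTypeError:
        return True
    return False

print(same_result, jitted_where_fails(x))
```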

At least I'm consistent

jessegrabowski · Jul 02 '25 06:07

or stubborn xD

ricardoV94 · Jul 02 '25 07:07

We should explicitly dispatch on `Nonzero` to raise a message saying that JAX's nonzero can't be jitted, so it doesn't show up as just a bare `NotImplementedError`, and so you don't come back to this again in 5 months' time.

ricardoV94 · Jul 04 '25 09:07
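The proposal amounts to registering a JAX conversion for the Op whose only job is to raise an informative error. PyTensor's JAX backend is assumed here to convert Ops through a `functools.singledispatch`-style function named `jax_funcify`; the sketch below only mimics that pattern with hypothetical stand-in classes (`Op`, `Nonzero`), a local `jax_funcify`, and a hypothetical helper `conversion_error`, so it runs without PyTensor and does not reproduce PyTensor's actual code or error messages.

```python
from functools import singledispatch


class Op:  # hypothetical stand-in for a PyTensor Op base class
    pass


class Nonzero(Op):  # hypothetical stand-in for PyTensor's nonzero Op
    pass


@singledispatch
def jax_funcify(op, **kwargs):
    # Generic fallback: without a specific registration, this is the only
    # (uninformative) message a user sees for an unsupported Op.
    raise NotImplementedError(f"No JAX conversion for the given Op: {op}")


@jax_funcify.register(Nonzero)
def jax_funcify_nonzero(op, **kwargs):
    # Explicit registration whose only purpose is to explain *why* there is
    # no conversion, instead of falling through to the generic fallback.
    raise NotImplementedError(
        "Nonzero cannot be jitted in JAX: the number of non-zero elements, "
        "and hence the output shape, is only known at runtime."
    )


def conversion_error(op):
    # Hypothetical helper: return the error message raised for `op`, if any.
    try:
        jax_funcify(op)
    except NotImplementedError as err:
        return str(err)
    return None


print(conversion_error(Nonzero()))
print(conversion_error(Op()))
```

Because `singledispatch` picks the most specific registered class, only `Nonzero` gets the tailored message; every other Op still falls back to the generic error.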

I updated the title to reflect your suggestion, since the issue as originally written asks for something impossible.

jessegrabowski · Jul 10 '25 02:07