pytensor
Add error message to `NonZero` Op telling users it cannot be jitted in JAX
Description
The JAX mode is missing a dispatch for the `NonZero` Op.
There's a `jnp.nonzero`, so it should be easy to add.
You won't be able to jit it: the output has a dynamic shape.
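To make the dynamic-shape point concrete, here is a small sketch in plain JAX (not PyTensor code). The array `x` and the helper name `padded_nonzero` are made up for illustration. It shows `jnp.nonzero` working eagerly, failing under `jax.jit`, and working again once a static `size` is supplied.

```python
import jax
import jax.numpy as jnp

x = jnp.array([0, 3, 0, 5, 7])

# Eagerly, the output length depends on the values in x.
(idx,) = jnp.nonzero(x)

# Under jit, every shape must be known at trace time, so this call fails.
try:
    jax.jit(lambda v: jnp.nonzero(v))(x)
    jit_failed = False
except Exception:  # JAX reports a concretization error here
    jit_failed = True

# With a static `size`, the output shape is fixed and jit is possible.
# Unused slots are padded with `fill_value`.
@jax.jit
def padded_nonzero(v):
    return jnp.nonzero(v, size=v.shape[0], fill_value=-1)

(padded_idx,) = padded_nonzero(x)
```

The padded variant only helps if the caller can pick an upper bound on the number of nonzero entries and handle the fill values, which a generic `NonZero` dispatch can't assume.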
You had the same question a few months ago about `where(single_variable)`, btw:
https://github.com/pymc-devs/pytensor/issues/1062
At least I'm consistent
or stubborn xD
We should explicitly dispatch on `Nonzero` to issue a message that JAX's `nonzero` can't be jitted. That way it doesn't show up as a bare `NotImplementedError`, and it saves you from coming back with this again in 5 months' time.
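A rough sketch of that idea, using only the standard library's `functools.singledispatch` as a stand-in for the backend dispatcher. The classes `Op`, `Nonzero` and `Exp` and the functions `funcify_sketch` and `error_message` are hypothetical placeholders, not PyTensor's actual API. The real change would register the handler on PyTensor's own JAX dispatch function instead.

```python
from functools import singledispatch


class Op:
    """Hypothetical stand-in for a graph operation."""


class Nonzero(Op):
    """Hypothetical stand-in for the Nonzero Op."""


class Exp(Op):
    """Hypothetical Op with no registered conversion."""


@singledispatch
def funcify_sketch(op):
    # Generic fallback: an uninformative error for any unregistered Op.
    raise NotImplementedError(f"No JAX conversion for the given Op: {type(op).__name__}")


@funcify_sketch.register(Nonzero)
def _(op):
    # Explicit registration whose only job is to explain *why* it fails.
    raise NotImplementedError(
        "Nonzero cannot be jitted in JAX because its output has a dynamic shape."
    )


def error_message(op):
    """Return the error text produced when converting `op`."""
    try:
        funcify_sketch(op)
    except NotImplementedError as err:
        return str(err)
    return ""
```

With this in place, `error_message(Nonzero())` explains the dynamic-shape limitation, while `error_message(Exp())` still gets the generic fallback text.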
I updated the title to reflect your suggestion, since the issue as originally written is impossible