Option for `EQX_ON_ERROR="off"`
I am using `equinox.error_if` in library code, particularly in `__init__` methods, to check whether parameters take allowed values. It is important that our classes can be instantiated across JIT boundaries (where parameter values may be tracers rather than concrete numbers), so we must check them at runtime.
More broadly, for library code it would be very useful to have `EQX_ON_ERROR="off"`, so that end users could optionally remove every `jax.lax.cond` introduced by these checks from their compiled programs (for GPU performance). Would Equinox accept a PR adding this feature?
So I considered this, but ended up settling on `EQX_ON_ERROR=nan` instead, which struck me as a safer alternative.
I take your point that this may reduce performance: JAX/XLA may not be smart enough to speculatively execute through a `lax.cond`.
I'd be happy to take a PR on this :)
This makes sense. Both this and the `nan` option seem like they would have their use cases!