jax
Make custom VJP bwd shape/type checking optional.
Would it be possible to make the check introduced here optional? We were previously returning values with deliberately mismatched shapes from the bwd rule for logging, as an easy way of extracting information through the gradients. That said, based on the linked PR, the check does seem useful in general.
It seems like removing the check doesn't currently break anything.
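One possible workaround is to give the function an extra dummy "probe" argument whose shape matches the information you want to extract. The bwd rule then returns the logged value as that argument's cotangent, so it passes the shape/dtype check. This is only a sketch, not an official JAX mechanism. `scale_with_probe`, `probe`, and the choice to log the maximum absolute cotangent are hypothetical names and choices made for illustration.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def scale_with_probe(x, probe):
    # Forward pass: `probe` is ignored. It only exists as a slot that
    # carries logged information back out through the gradients.
    return 2.0 * x


def scale_with_probe_fwd(x, probe):
    # No residuals are needed, since the derivative of 2*x is constant.
    return scale_with_probe(x, probe), None


def scale_with_probe_bwd(res, g):
    grad_x = 2.0 * g  # true cotangent for x
    # Logged value: max absolute incoming cotangent. It has shape ()
    # to match `probe`, so it satisfies the bwd shape/dtype check.
    logged = jnp.max(jnp.abs(g))
    return grad_x, logged


scale_with_probe.defvjp(scale_with_probe_fwd, scale_with_probe_bwd)


def loss(x, probe):
    return jnp.sum(scale_with_probe(x, probe) ** 2)


x = jnp.array([1.0, 2.0, 3.0])
probe = jnp.zeros(())  # scalar placeholder; its value is never used
grad_x, logged = jax.grad(loss, argnums=(0, 1))(x, probe)
print(grad_x)  # gradient with respect to x
print(logged)  # info extracted through the probe's "gradient"
```

If the goal is only to print values rather than return them, calling `jax.debug.print` inside the bwd rule avoids routing anything through the cotangents at all.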