`jnp.var` returns nan if `N-ddof <= 0`
Description:
- Updated `jnp.var` to explicitly return `np.nan` when the normalizer `N - ddof` is non-positive
- Added a test
Fixed #21330
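For reviewers, here is a minimal sketch of the intended behavior. It is not the code from this PR: `var_sketch` is a hypothetical helper that only handles a flattened 1-D input, and it assumes the normalizer is known statically as `N - ddof`.

```python
import jax.numpy as jnp


def var_sketch(x, ddof=0):
    """Hypothetical illustration only; not the actual jnp.var implementation."""
    x = jnp.ravel(jnp.asarray(x, dtype=jnp.float32))
    normalizer = x.size - ddof  # N - ddof, a plain Python int here
    if normalizer <= 0:
        # Non-positive normalizer: return nan explicitly instead of dividing.
        return jnp.array(jnp.nan, dtype=x.dtype)
    centered = x - jnp.mean(x)
    return jnp.sum(centered * centered) / normalizer


print(var_sketch([1.0, 2.0, 3.0], ddof=1))  # ordinary sample variance
print(var_sketch([1.0], ddof=1))            # N - ddof == 0
print(var_sketch([1.0, 2.0], ddof=3))       # N - ddof < 0
```

The explicit branch avoids relying on what a division by zero or by a negative count happens to produce.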
Looks good – I suspect std will have the same issue. Fix that here as well?
std calls var internally: https://github.com/google/jax/blob/47420a382583d025c606c6349afd3f71fe571aef/jax/_src/numpy/reductions.py#L511
I agree, we can add a test for std as well.
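As a rough illustration of why a fix in var carries over to std: if std is the square root of the variance, a `nan` variance stays `nan`. The helper `std_from_var` below is hypothetical and stands in for that relationship only; it is not the library's code.

```python
import jax.numpy as jnp


def std_from_var(variance):
    """Hypothetical helper: standard deviation from a precomputed variance."""
    return jnp.sqrt(variance)


nan_std = std_from_var(jnp.array(jnp.nan, dtype=jnp.float32))
ok_std = std_from_var(jnp.array(4.0, dtype=jnp.float32))
print(nan_std, ok_std)
```

A std test could therefore reuse the same `ddof` cases as the var test.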