
Added `correction` arg to jnp.var and jnp.std for array-api compliance

Open • vfdev-5 opened this issue 1 year ago • 6 comments

Description:

  • Added correction arg to jnp.var and jnp.std for array-api compliance
  • Updated tests
  • Addresses https://github.com/google/jax/issues/21088

vfdev-5 • May 16 '24 15:05

Thanks for the review, @jakevdp!

vfdev-5 • May 16 '24 16:05

Tests are currently failing because of a signature mismatch between the new std and var functions and their numpy counterparts. This can be fixed by updating extra_params in lax_numpy_test::testWrappedSignaturesMatch to include the new correction arg for both functions. See: https://github.com/google/jax/blob/5e2710c2c28a6f5bc2d6c89cf7148ea254685c30/tests/lax_numpy_test.py#L5970-L5973

Micky774 • May 17 '24 10:05

@Micky774 yes, after removing "correction" from the unmatched list, I get this: Missing entries:

'std': {'np_params': ['a', 'axis', 'dtype', 'out', 'ddof', 'keepdims', 'where', 'correction'], 'jnp_params': ['a', 'axis', 'dtype', 'out', 'correction', 'keepdims', 'where', 'ddof']}
'var': {'np_params': ['a', 'axis', 'dtype', 'out', 'ddof', 'keepdims', 'where', 'correction'], 'jnp_params': ['a', 'axis', 'dtype', 'out', 'correction', 'keepdims', 'where', 'ddof']}

This means I need to fix the parameter order I got wrong in jnp:

- def std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
-        out: None = None, correction: int | float = 0, keepdims: bool = False, *,
-        where: ArrayLike | None = None, ddof: int | DeprecatedArg = DeprecatedArg()) -> Array:

+ def std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
+        out: None = None, ddof: int | DeprecatedArg = DeprecatedArg(), keepdims: bool = False, *,
+        where: ArrayLike | None = None, correction: int | float = 0) -> Array:
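To see why the reordered signature resolves the "Missing entries" report, here is a minimal stdlib-only sketch that compares parameter order with `inspect.signature`. The stubs `np_std`, `jnp_std_old`, and `jnp_std_new` are hypothetical stand-ins that copy only the parameter lists above (defaults simplified, e.g. `DeprecatedArg()` replaced by `0`); this is not the actual logic of testWrappedSignaturesMatch.

```python
import inspect

# Hypothetical stubs: only their parameter names and order matter here.
def np_std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, *,
           where=None, correction=None): ...

def jnp_std_old(a, axis=None, dtype=None, out=None, correction=0, keepdims=False, *,
                where=None, ddof=0): ...

def jnp_std_new(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, *,
                where=None, correction=0): ...

def param_names(fn):
    """Return parameter names in declaration order (the bare `*` is not listed)."""
    return list(inspect.signature(fn).parameters)

def order_matches(np_fn, jnp_fn):
    """True when both functions declare the same parameters in the same order."""
    return param_names(np_fn) == param_names(jnp_fn)

print(param_names(np_std))
print(order_matches(np_std, jnp_std_old))  # False: ddof and correction are swapped
print(order_matches(np_std, jnp_std_new))  # True
```

The first print reproduces the `np_params` list from the report; only the `+` version lines up with it.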

However, numpy does not deprecate the ddof arg: https://numpy.org/devdocs/reference/generated/numpy.std.html Does it make sense to deprecate it in jnp? Once it is removed, the signature-mismatch test will fail again. Maybe we should just add the new arg ("correction" to jnp.std / jnp.var) alongside ddof, as numpy does?

vfdev-5 • May 17 '24 11:05

Ah, good point. It seems they do intend to maintain both as valid API (discussion), so we ought to do the same. It will still be fully array API compliant, so let's keep ddof.

Micky774 • May 17 '24 11:05

Yes, this approach works, but it looks like a hack. By the way, numpy does not raise an error in this case:

>>> import numpy as np
>>> np.var(np.array([1.0, 2.0]), ddof=0, correction=0)
np.float64(0.25)

I wonder where exactly the boundary for type hints lies between jax.numpy and numpy. For example, instead of introducing _zero = _int(0), we could change ddof: int = 0 to ddof: int | None = None and set it to zero internally by default. The docs would still show what numpy has:
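A minimal sketch of the `ddof: int | None = None` alternative, assuming a hypothetical helper `resolve_ddof_none_default` (not jax code). Because "not passed" is now `None` rather than `0`, an explicit `ddof=0` combined with `correction` becomes detectable and is rejected, which is stricter than the numpy call above.

```python
from __future__ import annotations

# Hypothetical helper: ddof defaults to None, so an explicit ddof=0 can be
# told apart from "ddof not given".
def resolve_ddof_none_default(ddof: int | None = None,
                              correction: int | float | None = None) -> int | float:
    if ddof is not None and correction is not None:
        # Both passed explicitly, even if one of them is 0.
        raise ValueError("ddof and correction can't be provided simultaneously.")
    if correction is not None:
        return correction
    if ddof is not None:
        return ddof
    return 0  # internal default when neither argument is given
```

For example, `resolve_ddof_none_default(ddof=0, correction=0)` raises `ValueError`, whereas numpy returns a result for the equivalent call.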

var(a: 'ArrayLike', axis: 'Axis' = None, dtype: 'DTypeLike | None' = None, out: 'None' = None, ddof: 'int | None' = None, keepdims: 'bool' = False, *, where: 'ArrayLike | None' = None, correction: 'int | float | None' = None) -> 'Array'

    ddof : {int, float}, optional
        Means Delta Degrees of Freedom.  The divisor used in calculations
        is ``N - ddof``, where ``N`` represents the number of elements.
        By default `ddof` is zero. See Notes for details about use of `ddof`.

    correction : {int, float}, optional
        Array API compatible name for the ``ddof`` parameter. Only one of them
        can be provided at the same time.
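For reference, the divisor described in the docstring corresponds to the formula below (with `correction` playing the same role as `ddof`); the worked numbers reproduce the np.var result above.

```latex
\sigma^2 = \frac{1}{N - \mathrm{ddof}} \sum_{i=1}^{N} \left(x_i - \bar{x}\right)^2,
\qquad \bar{x} = \frac{1}{N} \sum_{i=1}^{N} x_i,
\qquad \mathrm{std} = \sqrt{\sigma^2}

% x = [1.0, 2.0]:  N = 2,  \bar{x} = 1.5,  \sum_i (x_i - \bar{x})^2 = 0.25 + 0.25 = 0.5
% ddof = 0:  \sigma^2 = 0.5 / 2 = 0.25
% ddof = 1:  \sigma^2 = 0.5 / 1 = 0.5
```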

Jake, I'm happy to implement your solution with _zero = _int(0) if there is no way to type-hint ddof with None. I'm just trying to figure out the project's conventions and the limits on proposed changes.

vfdev-5 • May 22 '24 22:05

Let's stick with ddof=0 as a default. Simplest is best here, I think.

jakevdp • May 22 '24 22:05


Sounds good. If I understand your comment correctly, we keep the following check:

if correction is None:
  correction = ddof
elif not isinstance(ddof, int) or ddof != 0:
  raise ValueError("ddof and correction can't be provided simultaneously.")

and do not change ddof.
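A minimal sketch of how that check behaves with the ddof=0 default, assuming a hypothetical wrapper `resolve_correction` (not the actual jax.numpy source): the check above is copied verbatim into a function so its cases can be exercised.

```python
from __future__ import annotations

# Hypothetical wrapper around the check quoted above; not jax.numpy code.
def resolve_correction(ddof: int = 0,
                       correction: int | float | None = None) -> int | float:
    if correction is None:
        # correction not given: fall back to ddof (which defaults to 0).
        correction = ddof
    elif not isinstance(ddof, int) or ddof != 0:
        # correction given: only accepted while ddof is still the int 0.
        raise ValueError("ddof and correction can't be provided simultaneously.")
    return correction
```

With this default an explicit `ddof=0` is indistinguishable from the default, so `ddof=0, correction=1` is accepted, consistent with the numpy behavior noted earlier; `ddof=1` or a float `ddof=0.0` combined with `correction` raises `ValueError`.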

I updated the test according to your review.

vfdev-5 • May 24 '24 13:05