jax
Added `correction` arg to jnp.var and jnp.std for array-api compliance
Description:
- Added `correction` arg to jnp.var and jnp.std for array-api compliance
- Updated tests
- Addresses https://github.com/google/jax/issues/21088
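For reference, the array API's `correction` plays the same role as NumPy's `ddof`: the sum of squared deviations is divided by `N - correction`, so `0` gives the population variance and `1` the sample variance. Below is a minimal NumPy sketch of that formula. `var_with_correction` is a hypothetical helper, not part of jax or NumPy, and is checked only against `np.var(..., ddof=...)`:

```python
import numpy as np


def var_with_correction(x, correction=0):
    """Hypothetical helper: variance with divisor N - correction."""
    x = np.asarray(x, dtype=np.float64)
    n = x.size
    # Sum of squared deviations from the mean.
    ss = np.sum((x - x.mean()) ** 2)
    return ss / (n - correction)


x = np.array([1.0, 2.0, 3.0, 4.0])
# correction=0 matches ddof=0 (population variance).
print(var_with_correction(x, 0), np.var(x, ddof=0))
# correction=1 matches ddof=1 (sample variance).
print(var_with_correction(x, 1), np.var(x, ddof=1))
```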
Thanks for the review, @jakevdp!
Tests are currently failing because of a signature mismatch between the new std and var functions and their NumPy counterparts. This can be fixed by updating extra_params in lax_numpy_test::testWrappedSignaturesMatch to include the new correction arg for both functions. See: https://github.com/google/jax/blob/5e2710c2c28a6f5bc2d6c89cf7148ea254685c30/tests/lax_numpy_test.py#L5970-L5973
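To see why parameter order matters here, the following is a rough sketch of a signature comparison loosely modeled on what `testWrappedSignaturesMatch` reports. The real test in `tests/lax_numpy_test.py` may work differently. The stubs `np_var`, `jnp_var`, and `jnp_var_reordered` and the helper `signature_mismatch` are hypothetical: they reproduce only the parameter lists discussed in this thread and compare ordered names from `inspect.signature`.

```python
import inspect


# Hypothetical stubs mirroring the parameter lists discussed in this thread.
def np_var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, *,
           where=None, correction=None):
    pass


def jnp_var(a, axis=None, dtype=None, out=None, correction=0, keepdims=False, *,
            where=None, ddof=0):
    pass


def jnp_var_reordered(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, *,
                      where=None, correction=0):
    pass


def param_names(fn):
    """Parameter names in declaration order, as inspect.signature reports them."""
    return list(inspect.signature(fn).parameters)


def signature_mismatch(np_fn, jnp_fn, extra_params=()):
    """Return None if the ordered names agree after dropping allow-listed extras."""
    np_params = [p for p in param_names(np_fn) if p not in extra_params]
    jnp_params = [p for p in param_names(jnp_fn) if p not in extra_params]
    if np_params == jnp_params:
        return None
    return {"np_params": np_params, "jnp_params": jnp_params}


print(signature_mismatch(np_var, jnp_var))            # same names, different order
print(signature_mismatch(np_var, jnp_var_reordered))  # None
```

Because the comparison is on ordered lists, having the same set of names is not enough: `correction` and `ddof` must also sit in the same positions.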
@Micky774 yes. Right now, after removing "correction" from the unmatched list, I get this: Missing entries:
'std': {'np_params': ['a', 'axis', 'dtype', 'out', 'ddof', 'keepdims', 'where', 'correction'], 'jnp_params': ['a', 'axis', 'dtype', 'out', 'correction', 'keepdims', 'where', 'ddof']}
'var': {'np_params': ['a', 'axis', 'dtype', 'out', 'ddof', 'keepdims', 'where', 'correction'], 'jnp_params': ['a', 'axis', 'dtype', 'out', 'correction', 'keepdims', 'where', 'ddof']}
This means I should fix the parameter order I used in jnp:
```diff
- def std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
-         out: None = None, correction: int | float = 0, keepdims: bool = False, *,
-         where: ArrayLike | None = None, ddof: int | DeprecatedArg = DeprecatedArg()) -> Array:
+ def std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
+         out: None = None, ddof: int | DeprecatedArg = DeprecatedArg(), keepdims: bool = False, *,
+         where: ArrayLike | None = None, correction: int | float = 0) -> Array:
```
However, NumPy does not deprecate the ddof argument (https://numpy.org/devdocs/reference/generated/numpy.std.html). Does it make sense to deprecate it in jnp, especially since the signature-mismatch test will fail again once ddof is removed there? Should we instead just add the new argument (correction) to jnp.std / jnp.var, as NumPy does?
Ah, good point. It seems that they do intend to maintain both as valid API (discussion), so we ought to do the same. It will still be fully array API compliant, so indeed, let's keep ddof.
Yes, this approach works, but it looks like a hack. By the way, NumPy does not raise an error in this case:
```python
>>> import numpy as np
>>> np.var(np.array([1.0, 2.0]), ddof=0, correction=0)
np.float64(0.25)
```
I wonder where exactly the boundary for type hints lies between jax.numpy and numpy. For example, instead of introducing _zero = _int(0), we could change ddof: int = 0 to ddof: int | None = None and set it to zero internally by default. The docs would still show what NumPy has:
```
var(a: 'ArrayLike', axis: 'Axis' = None, dtype: 'DTypeLike | None' = None, out: 'None' = None, ddof: 'int | None' = None, keepdims: 'bool' = False, *, where: 'ArrayLike | None' = None, correction: 'int | float | None' = None) -> 'Array'

ddof : {int, float}, optional
    Means Delta Degrees of Freedom. The divisor used in calculations
    is ``N - ddof``, where ``N`` represents the number of elements.
    By default `ddof` is zero. See Notes for details about use of `ddof`.
correction : {int, float}, optional
    Array API compatible name for the ``ddof`` parameter. Only one of them
    can be provided at the same time.
```
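The following is a minimal sketch of the `ddof: int | None = None` alternative, assuming a NumPy backend. `var_none_default` is hypothetical and is not what was merged. With `None` defaults, checking whether `ddof` was passed becomes a plain `is not None` test. Note that this version rejects `ddof=0, correction=0`, which NumPy accepts.

```python
import numpy as np


def var_none_default(a, axis=None, ddof=None, keepdims=False, *, correction=None):
    """Hypothetical sketch of the ddof=None alternative (not the merged code)."""
    if ddof is not None and correction is not None:
        # Both names were passed explicitly, whatever their values.
        raise ValueError("ddof and correction can't be provided simultaneously.")
    if correction is None:
        # Fall back to ddof, and to zero if neither name was passed.
        correction = 0 if ddof is None else ddof
    return np.var(a, axis=axis, ddof=correction, keepdims=keepdims)


x = np.array([1.0, 2.0])
print(var_none_default(x), var_none_default(x, correction=1))
```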
Jake, I'm happy to implement your solution with _zero = _int(0) if there is no way to type-hint ddof with None. I'm just trying to figure out the project's conventions and the limits on the proposed changes.
Let's stick with ddof=0 as a default. Simplest is best here, I think.
Sounds good. If I understand your comment correctly, we keep the following check:
```python
if correction is None:
    correction = ddof
elif not isinstance(ddof, int) or ddof != 0:
    raise ValueError("ddof and correction can't be provided simultaneously.")
```
and do not change ddof.
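The sketch below places that check inside a complete function. It assumes a NumPy backend in place of jax internals, and this `var` is a hypothetical stand-in, not the code in the PR. `ddof` keeps its default of `0`, and `correction=None` means "not passed".

```python
import numpy as np


def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, *,
        where=None, correction=None):
    """Hypothetical NumPy-backed stand-in illustrating the ddof/correction check."""
    if out is not None:
        raise NotImplementedError("out is not supported in this sketch")
    if correction is None:
        # Only ddof was used (possibly at its default of 0).
        correction = ddof
    elif not isinstance(ddof, int) or ddof != 0:
        # correction was passed together with a non-default ddof.
        raise ValueError("ddof and correction can't be provided simultaneously.")
    # Forward `where` only when given, since NumPy's default is not None.
    kwargs = {} if where is None else {"where": where}
    return np.var(a, axis=axis, dtype=dtype, ddof=correction,
                  keepdims=keepdims, **kwargs)


x = np.array([1.0, 2.0])
print(var(x), var(x, ddof=0, correction=0), var(x, correction=1))
```

As in NumPy, passing `ddof=0` together with `correction` is accepted, while any non-default `ddof` passed alongside `correction` raises.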
I updated the test according to your review.