jax
Improve typing of `jax.jit`
- Fix for #23719
Thanks for the contribution! See #14688 for a past attempt at this, the things that went wrong, and some discussion of the considerations around a change like this. In particular, pytype was a blocker in the past, and we'll have to check whether that's still the case.
Thanks @jakevdp for the prompt response. After this change, the MyPy type checker in the pre-commit hook now becomes aware of typing issues that were already present in the codebase. Would this PR normally be expected to include fixes for these typing issues as well? Or should that be done in a separate PR?
Yeah, newly discovered issues need to be fixed or silenced, because otherwise merging the PR will break our CI.
Could you give an example of the issue you discovered, please?
I just rebased the PR to only improve the typing of jax.jit. I'll make another PR to add type hints to jax.eval_shape, JitWrapped.eval_shape and JitWrapped.trace.
Here are the typing issues that the MyPy pre-commit hook now sees within the codebase, now that the signature of the callable isn't dropped by jax.jit:
jax/_src/numpy/ufunc_api.py:177: error: Argument 1 to "__call__" of "Wrapped" has incompatible type "*tuple[Array | ndarray[Any, Any] | bool_ | number[Any] | bool | int | float | complex, ...]"; expected "ufunc" [arg-type]
jax/_src/numpy/array_methods.py:171: error: Argument "axis" to "__call__" of "Wrapped" has incompatible type "int | Sequence[int] | None"; expected "int | None" [arg-type]
jax/_src/numpy/array_methods.py:179: error: Argument "axis" to "__call__" of "Wrapped" has incompatible type "int | Sequence[int] | None"; expected "int | None" [arg-type]
jax/_src/lax/linalg.py:1665: error: Argument 1 to "__call__" of "Wrapped" has incompatible type "Array | ndarray[Any, Any] | bool_ | number[Any] | bool | int | float | complex"; expected "Array" [arg-type]
jax/_src/lax/linalg.py:1665: error: Argument 2 to "__call__" of "Wrapped" has incompatible type "Array | ndarray[Any, Any] | bool_ | number[Any] | bool | int | float | complex"; expected "Array" [arg-type]
jax/_src/lax/linalg.py:1665: error: Argument 3 to "__call__" of "Wrapped" has incompatible type "Array | ndarray[Any, Any] | bool_ | number[Any] | bool | int | float | complex"; expected "Array" [arg-type]
jax/_src/numpy/lax_numpy.py:1127: error: Argument 2 to "__call__" of "Wrapped" has incompatible type "int | Sequence[int] | None"; expected "int | tuple[int, ...] | None" [arg-type]
jax/_src/numpy/lax_numpy.py:1985: error: Argument 2 to "__call__" of "Wrapped" has incompatible type "tuple[int, ...] | None"; expected "tuple[int, ...]" [arg-type]
jax/_src/numpy/lax_numpy.py:7228: error: Unused "type: ignore[arg-type, operator]" comment [unused-ignore]
jax/_src/third_party/scipy/special.py:275: error: Incompatible types in assignment (expression has type "Array", variable has type "float") [assignment]
jax/_src/third_party/scipy/special.py:276: error: Incompatible types in assignment (expression has type "Array", variable has type "float") [assignment]
jax/_src/scipy/linalg.py:222: error: Overloaded function implementation does not accept all possible arguments of signature 1 [misc]
jax/_src/scipy/linalg.py:222: error: Overloaded function implementation does not accept all possible arguments of signature 2 [misc]
jax/_src/scipy/linalg.py:222: error: Overloaded function implementation does not accept all possible arguments of signature 3 [misc]
jax/_src/scipy/linalg.py:548: error: Argument 1 to "__call__" of "Wrapped" has incompatible type "Array | ndarray[Any, Any] | bool_ | number[Any] | bool | int | float | complex"; expected "Array" [arg-type]
jax/_src/scipy/cluster/vq.py:73: error: No overload variant of "getitem" matches argument types "Array", "Array" [call-overload]
jax/_src/scipy/cluster/vq.py:73: note: Possible overload variants:
jax/_src/scipy/cluster/vq.py:73: note: def [_T] getitem(Sequence[_T], slice, /) -> Sequence[_T]
jax/_src/scipy/cluster/vq.py:73: note: def [_K, _V] getitem(SupportsGetItem[_K, _V], _K, /) -> _V
jax/_src/scipy/special.py:1808: error: Argument 5 to "__call__" of "Wrapped" has incompatible type "int | None"; expected "int" [arg-type]
Found 17 errors in 8 files (checked 511 source files)
What is the typical way that such typing errors are silenced in this project? Do you prefer # type: ignore comments or typing.cast?
Also, should I leave a comment referencing a new issue for the typing errors that I silence?
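For reference, the two options asked about above differ roughly as follows. This is a minimal sketch, not code from the PR: `_double`, `double_ignore` and `double_cast` are hypothetical names, and the `# type: ignore[arg-type]` comment only has something to silence under a stricter, signature-preserving `jit` annotation like the one proposed here. Under a looser annotation, mypy would instead report it as an unused ignore, like the `unused-ignore` error in the log above.

```python
from typing import cast

import jax
import jax.numpy as jnp
from jax.typing import ArrayLike


@jax.jit
def _double(x: jax.Array) -> jax.Array:
    return 2 * x


def double_ignore(x: ArrayLike) -> jax.Array:
    # Option 1: suppress only the arg-type error on this line; other error codes still surface.
    return _double(x)  # type: ignore[arg-type]


def double_cast(x: ArrayLike) -> jax.Array:
    # Option 2: assert the static type instead. cast() does no conversion at runtime;
    # jit still converts x to an Array when _double is called.
    return _double(cast(jax.Array, x))
```

Both functions behave identically at runtime. The difference is only in what the type checker is told: the ignore comment hides a reported error, while `cast` changes the type the checker believes `x` has.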
I am pretty sure this will break a lot of other targets internally too making this change very difficult to land.
Yeah, #14688 was eventually blocked by the fact that pytype doesn't properly support ParamSpec. I'm not sure whether that's changed in the meantime.
Some of the errors reflect that the jit annotation in this PR is not correct. For example, this one:
jax/_src/lax/linalg.py:1665: error: Argument 1 to "__call__" of "Wrapped" has incompatible type "Array | ndarray[Any, Any] | bool_ | number[Any] | bool | int | float | complex"; expected "Array" [arg-type]
It basically comes from something that looks like this:
from jax import Array, jit
from jax.typing import ArrayLike

@jit
def _lu_solve(x: Array):
  ...

def lu_solve(x: ArrayLike):
  return _lu_solve(x)  # <- type error, because ArrayLike is not Array
However, when you wrap a function with jit, all ArrayLike inputs are implicitly converted to Array before being passed to the wrapped function. So in some sense this annotation is correct, and the mypy error is a false positive caused by the new jit annotation being stricter than it needs to be.
What do you think?
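The implicit conversion described above can be observed at runtime. In this small sketch (the names `double` and `seen_inside` are illustrative), the Python side effect inside the jitted function only runs while tracing. It records that the traced argument is a `jax.Array` even when the caller passed a NumPy array or a Python float. This shows runtime behavior only; it says nothing about what a static checker infers.

```python
import jax
import numpy as np

seen_inside: list[bool] = []


@jax.jit
def double(x):
    # Runs only during tracing: record whether the traced argument is a jax.Array.
    seen_inside.append(isinstance(x, jax.Array))
    return x * 2


out = double(np.arange(3.0))  # NumPy input
out2 = double(1.5)            # Python scalar input
```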
re: @superbobry, @yashk2810 - I silenced the new typing errors that mypy raised in the pre-commit hook.
re: @jakevdp
However, when you wrap a function with jit, all ArrayLike inputs are implicitly converted to Array before being passed to the wrapped function. So in some sense this annotation is correct, and the mypy error is a false positive caused by the new jit annotation being stricter than it needs to be.
I agree. In my view, this only encourages internal jax source to be more explicit, by not depending on this implicit conversion from ArrayLike to Array.
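Concretely, being explicit here would mean converting at the call site so that the value's static type already matches the jitted signature. This is a minimal sketch, assuming a hypothetical `_halve` that stands in for a private jitted helper such as `_lu_solve`. `jnp.asarray` is annotated to return a `jax.Array`, so a checker has nothing to flag, and the runtime result is the same as when relying on jit's implicit conversion.

```python
import jax
import jax.numpy as jnp
from jax.typing import ArrayLike


@jax.jit
def _halve(x: jax.Array) -> jax.Array:
    # Hypothetical private helper whose parameter is annotated as Array, not ArrayLike.
    return x / 2


def halve(x: ArrayLike) -> jax.Array:
    # Explicit conversion: the argument passed to the jitted helper is statically a jax.Array.
    return _halve(jnp.asarray(x))
```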
re: @yashk2810
I am pretty sure this will break a lot of other targets internally too making this change very difficult to land.
Are you saying that this tiny little PR would also improve other downstream projects at Google? :star_struck: :stuck_out_tongue:
Haha, I wouldn't say improve, but it will break a lot of stuff and I don't think we have the bandwidth to fix all those projects. Hence landing this is very hard IRL.
I think, setting aside caveats about how hard this might be to land, this is generally a change we want, and one we've been hoping to add for a long time. Initially we were blocked by the lack of ParamSpec in the type system, then we were blocked by the lack of support for ParamSpec in the mypy implementation, and then in the pytype implementation. If the pytype blocker is now fixed, we can do the work to land this (basically adding # ignore statements in any place that it breaks). But I think pytype may still be a blocker, as it was for #14688 six months ago.
The relevant issue is https://github.com/google/pytype/issues/1471, which is still open.
re: @yashk2810 I guess it's a matter of perspective. In my view, revealing typing errors / encouraging code to be more explicit is an improvement.
re: @yashk2810 @jakevdp I'd be very happy to help and put in the time required in order to fix (or at the very least silence) any such typing errors that are revealed as a result of this PR in other projects. Do you by chance have a kind of list of these public-facing, jax-based projects that could potentially have their CI fail as a result of this change? If not, I can also try to gather such a list myself.
Just to be clear, the pytype blocker is not about revealing existing errors, it's about the fact that pytype fails loudly and completely when it sees covariant=True. If that hasn't changed, then I'm afraid we can't do much else here.
With respect to https://github.com/google/pytype/issues/1471, would this change be easier to merge if the output TypeVar were not marked as covariant? :thinking:
Edit: I'll double-check, but I think that the output var being invariant might cause other issues (which would then be due to the annotation not being 100% correct)
Yes, if this didn't use covariant typevars, it would be easier to merge. But my understanding from #14688 was that covariant typevars are required in order to correctly annotate jit.
@jakevdp pytype treats all type variables as covariant IIRC, so maybe we can just suppress the warning for that particular type var?
Re @superbobry @jakevdp :
I added some # pytype: disable=not-supported-yet over the typevar definitions. If my understanding of pytype is correct, it will now simply drop the covariant arg, and treat those as regular typevars.
@lebrice : I'd be very happy to help and put in the time required in order to fix (or at the very least silence) any such typing errors that are revealed as a result of this PR in other projects. Do you by chance have a kind of list of these public-facing, jax-based projects that could potentially have their CI fail as a result of this change? If not, I can also try to gather such a list myself.
I'll start with Flax, since that seems like the most obvious downstream project from my perspective. I'm able to get their Pytype-related CI steps to run without error, at least locally.
Pulling in to run internal pytype tests
I'm curious: If this ends up getting merged, will the Google-ML-Automation bot include my github username in the final commit? Or would someone on the inside need to add a Co-authored-by: Fabrice Normandin <[email protected]> in the commit message?
If this is merged, your actual unmodified commit would be added to the JAX source tree.
Hi @superbobry, would it help if I rebased this onto the most recent version of master?
There might be new places where # type: ignore comments should be added. What do you think?
Apologies for the silence. We can try, but it looks like pytype might need some fixing to accept the declarations in this PR.
Hi @jakevdp @superbobry , I just rebased this again. Is there by any chance a Pytype GitHub Issue that I could take a look at, if I wanted to take a shot at fixing something, just out of curiosity? In the meantime I guess I'll just keep rebasing this once in a while.
Thanks @lebrice! The relevant issue is google/pytype#1471, as Jake pointed out above. I think it is possible to fix, but it will be a massive change, since internal Google code definitely relies on this pytype behavior (even though it is obviously wrong).