Improve typing of `jax.jit`

Open lebrice opened this issue 1 year ago • 23 comments

  • Fix for #23719

lebrice avatar Sep 18 '24 13:09 lebrice

Thanks for the contribution! See #14688 for a past attempt at this, the things that went wrong, and some discussion of the considerations around a change like this. In particular, PyType was a blocker in the past, and we'll have to check whether that's still the case.

jakevdp avatar Sep 18 '24 13:09 jakevdp

Thanks @jakevdp for the prompt response. After this change, the MyPy type checker in the pre-commit hook becomes aware of typing issues that were already present in the codebase. Would this PR normally be expected to include fixes for these typing issues as well, or should that be done in a separate PR?

lebrice avatar Sep 18 '24 14:09 lebrice

Yeah, newly discovered issues need to be fixed or silenced, because otherwise merging the PR will break our CI.

Could you give an example of the issue you discovered, please?

superbobry avatar Sep 18 '24 15:09 superbobry

I just rebased the PR to only improve the typing of jax.jit. I'll make another PR to add type hints to jax.eval_shape, JitWrapped.eval_shape and JitWrapped.trace.

Here are the typing issues that the MyPy pre-commit hook now sees within the codebase, now that the signature of the callable isn't dropped by jax.jit:

jax/_src/numpy/ufunc_api.py:177: error: Argument 1 to "__call__" of "Wrapped" has incompatible type "*tuple[Array | ndarray[Any, Any] | bool_ | number[Any] | bool | int | float | complex, ...]"; expected "ufunc"  [arg-type]
jax/_src/numpy/array_methods.py:171: error: Argument "axis" to "__call__" of "Wrapped" has incompatible type "int | Sequence[int] | None"; expected "int | None"  [arg-type]
jax/_src/numpy/array_methods.py:179: error: Argument "axis" to "__call__" of "Wrapped" has incompatible type "int | Sequence[int] | None"; expected "int | None"  [arg-type]
jax/_src/lax/linalg.py:1665: error: Argument 1 to "__call__" of "Wrapped" has incompatible type "Array | ndarray[Any, Any] | bool_ | number[Any] | bool | int | float | complex"; expected "Array"  [arg-type]
jax/_src/lax/linalg.py:1665: error: Argument 2 to "__call__" of "Wrapped" has incompatible type "Array | ndarray[Any, Any] | bool_ | number[Any] | bool | int | float | complex"; expected "Array"  [arg-type]
jax/_src/lax/linalg.py:1665: error: Argument 3 to "__call__" of "Wrapped" has incompatible type "Array | ndarray[Any, Any] | bool_ | number[Any] | bool | int | float | complex"; expected "Array"  [arg-type]
jax/_src/numpy/lax_numpy.py:1127: error: Argument 2 to "__call__" of "Wrapped" has incompatible type "int | Sequence[int] | None"; expected "int | tuple[int, ...] | None"  [arg-type]
jax/_src/numpy/lax_numpy.py:1985: error: Argument 2 to "__call__" of "Wrapped" has incompatible type "tuple[int, ...] | None"; expected "tuple[int, ...]"  [arg-type]
jax/_src/numpy/lax_numpy.py:7228: error: Unused "type: ignore[arg-type, operator]" comment  [unused-ignore]
jax/_src/third_party/scipy/special.py:275: error: Incompatible types in assignment (expression has type "Array", variable has type "float")  [assignment]
jax/_src/third_party/scipy/special.py:276: error: Incompatible types in assignment (expression has type "Array", variable has type "float")  [assignment]
jax/_src/scipy/linalg.py:222: error: Overloaded function implementation does not accept all possible arguments of signature 1  [misc]
jax/_src/scipy/linalg.py:222: error: Overloaded function implementation does not accept all possible arguments of signature 2  [misc]
jax/_src/scipy/linalg.py:222: error: Overloaded function implementation does not accept all possible arguments of signature 3  [misc]
jax/_src/scipy/linalg.py:548: error: Argument 1 to "__call__" of "Wrapped" has incompatible type "Array | ndarray[Any, Any] | bool_ | number[Any] | bool | int | float | complex"; expected "Array"  [arg-type]
jax/_src/scipy/cluster/vq.py:73: error: No overload variant of "getitem" matches argument types "Array", "Array"  [call-overload]
jax/_src/scipy/cluster/vq.py:73: note: Possible overload variants:
jax/_src/scipy/cluster/vq.py:73: note:     def [_T] getitem(Sequence[_T], slice, /) -> Sequence[_T]
jax/_src/scipy/cluster/vq.py:73: note:     def [_K, _V] getitem(SupportsGetItem[_K, _V], _K, /) -> _V
jax/_src/scipy/special.py:1808: error: Argument 5 to "__call__" of "Wrapped" has incompatible type "int | None"; expected "int"  [arg-type]
Found 17 errors in 8 files (checked 511 source files)

What is the typical way that such typing errors are silenced in this project? Do you prefer # type: ignore comments or typing.cast ? Also, should I leave a comment referencing a new issue for the typing errors that I silence?
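For reference, the options mentioned in this question can be compared side by side. This is only a sketch with hypothetical function names (`reduce_along`, `via_ignore`, `via_cast`, `via_narrowing`); it is not code from the JAX codebase, and it says nothing about which style the maintainers prefer.

```python
from typing import Optional, cast


def reduce_along(axis: int) -> str:
    """Hypothetical callee that only accepts a concrete int."""
    return f"reducing along axis {axis}"


def via_ignore(axis: Optional[int]) -> str:
    # Option 1: silence the checker on this line only, naming the error code.
    return reduce_along(axis)  # type: ignore[arg-type]


def via_cast(axis: Optional[int]) -> str:
    # Option 2: tell the checker to treat the value as int.
    # cast() does nothing at runtime, so a None would still get through.
    return reduce_along(cast(int, axis))


def via_narrowing(axis: Optional[int]) -> str:
    # Option 3: actually fix the issue by narrowing the type before the call.
    if axis is None:
        raise ValueError("axis must not be None")
    return reduce_along(axis)
```

The first two options only hide the error from the checker, while the third changes runtime behavior, which is why it counts as a fix rather than a suppression.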

lebrice avatar Sep 18 '24 15:09 lebrice

I am pretty sure this will break a lot of other targets internally too making this change very difficult to land.

yashk2810 avatar Sep 18 '24 15:09 yashk2810

Yeah, #14688 was eventually blocked by the fact that pytype doesn't properly support ParamSpec. I'm not sure whether that's changed in the meantime.

jakevdp avatar Sep 18 '24 16:09 jakevdp

Some of the errors reflect that the jit annotation in this PR is not correct. For example this one:

jax/_src/lax/linalg.py:1665: error: Argument 1 to "__call__" of "Wrapped" has incompatible type "Array | ndarray[Any, Any] | bool_ | number[Any] | bool | int | float | complex"; expected "Array"  [arg-type]

It basically comes from something that looks like this:

@jit
def _lu_solve(x: Array):
  ...

def lu_solve(x: ArrayLike):
  return _lu_solve(x)  # <- type error, because ArrayLike is not Array

However, when you wrap a function with jit, all ArrayLike inputs are implicitly converted to Array before being passed to the wrapped function. So in some senses this annotation is correct, and the mypy error is a false-positive due to the new jit annotation being stricter than it needs to be.

What do you think?

jakevdp avatar Sep 18 '24 16:09 jakevdp

re: @superbobry, @yashk2810 - I silenced the new typing errors that mypy raised in the pre-commit hook.

re: @jakevdp

> However, when you wrap a function with jit, all ArrayLike inputs are implicitly converted to Array before being passed to the wrapped function. So in some senses this annotation is correct, and the mypy error is a false-positive due to the new jit annotation being stricter than it needs to be.

I agree. In my view, this only encourages internal jax source to be more explicit, by not depending on this implicit conversion from ArrayLike to Array.

lebrice avatar Sep 18 '24 17:09 lebrice

re: @yashk2810

> I am pretty sure this will break a lot of other targets internally too making this change very difficult to land.

Are you saying that this tiny little PR would also improve other downstream projects at Google? :star_struck: :stuck_out_tongue:

lebrice avatar Sep 18 '24 17:09 lebrice

Haha, I wouldn't say improve but it will break a lot of stuff and I don't think we have the bandwidth to fix all those projects. Hence landing this is very hard IRL.

yashk2810 avatar Sep 18 '24 17:09 yashk2810

I think, setting aside caveats about how hard this might be to land, this is generally a change we want, and one we've been hoping to add for a long time. Initially we were blocked by the lack of ParamSpec in the type system, then we were blocked by the lack of support for ParamSpec in the mypy implementation, and then in the pytype implementation. If the pytype blocker is now fixed, we can do the work to land this (basically adding # ignore statements in any place that it breaks). But I think pytype may still be a blocker, as it was for #14688 six months ago.

The relevant issue is https://github.com/google/pytype/issues/1471, which is still open.

jakevdp avatar Sep 18 '24 17:09 jakevdp

re: @yashk2810 I guess it's a matter of perspective. In my view, revealing typing errors / encouraging code to be more explicit is an improvement.

re: @yashk2810 @jakevdp I'd be very happy to help and put in the time required in order to fix (or at the very least silence) any such typing errors that are revealed as a result of this PR in other projects. Do you by chance have a kind of list of these public-facing, jax-based projects that could potentially have their CI fail as a result of this change? If not, I can also try to gather such a list myself.

lebrice avatar Sep 18 '24 17:09 lebrice

Just to be clear, the pytype blocker is not about revealing existing errors, it's about the fact that pytype fails loudly and completely when it sees covariant=True. If that hasn't changed, then I'm afraid we can't do much else here.

jakevdp avatar Sep 18 '24 17:09 jakevdp

With respect to https://github.com/google/pytype/issues/1471, would this change be easier to merge if the output TypeVar were not marked as covariant? :thinking:

Edit: I'll double-check, but I think that the output var being invariant might cause other issues (which would then be due to the annotation not being 100% correct)

lebrice avatar Sep 18 '24 17:09 lebrice

Yes, if this didn't use covariant typevars, it would be easier to merge. But my understanding from #14688 was that covariant typevars are required in order to correctly annotate jit.

jakevdp avatar Sep 18 '24 18:09 jakevdp

@jakevdp pytype treats all type variables as covariant IIRC, so maybe we can just suppress the warning for that particular type var?

superbobry avatar Sep 18 '24 19:09 superbobry

Re @superbobry @jakevdp : I added some # pytype: disable=not-supported-yet over the typevar definitions. If my understanding of pytype is correct, it will now simply drop the covariant arg, and treat those as regular typevars.

lebrice avatar Sep 18 '24 19:09 lebrice

> @lebrice: I'd be very happy to help and put in the time required in order to fix (or at the very least silence) any such typing errors that are revealed as a result of this PR in other projects. Do you by chance have a kind of list of these public-facing, jax-based projects that could potentially have their CI fail as a result of this change? If not, I can also try to gather such a list myself.

I'll start with Flax, since that seems like the most obvious downstream project from my perspective. I'm able to get their Pytype-related CI steps to run without error, at least locally.

lebrice avatar Sep 18 '24 19:09 lebrice

Pulling in to run internal pytype tests

jakevdp avatar Sep 20 '24 18:09 jakevdp

I'm curious: If this ends up getting merged, will the Google-ML-Automation bot include my github username in the final commit? Or would someone on the inside need to add a Co-authored-by: Fabrice Normandin <[email protected]> in the commit message?

lebrice avatar Sep 21 '24 18:09 lebrice

> I'm curious: If this ends up getting merged, will the Google-ML-Automation bot include my github username in the final commit? Or would someone on the inside need to add a Co-authored-by: Fabrice Normandin <[email protected]> in the commit message?

If this is merged, your actual unmodified commit would be added to the JAX source tree.

jakevdp avatar Sep 21 '24 20:09 jakevdp

Hi @superbobry, would it help if I rebased this off of the most recent version of master? There might be new places where # type: ignore comments should be added. What do you think?

lebrice avatar Oct 11 '24 16:10 lebrice

Apologies for the silence. We can try, but it looks like pytype might need some fixing to accept the declarations in this PR.

superbobry avatar Oct 12 '24 16:10 superbobry

Hi @jakevdp @superbobry , I just rebased this again. Is there by any chance a Pytype GitHub Issue that I could take a look at, if I wanted to take a shot at fixing something, just out of curiosity? In the meantime I guess I'll just keep rebasing this once in a while.

lebrice avatar Nov 27 '24 17:11 lebrice

Thanks @lebrice! The relevant issue is google/pytype#1471, as Jake pointed out above. I think it is possible to fix, but it will be a massive change, since internal Google code definitely relies on this pytype behavior (even though it is obviously wrong).

superbobry avatar Nov 29 '24 10:11 superbobry