[typing] Use ParamSpec in JIT annotation
This pull request would be a huge improvement for Jax users who use type checkers like MyPy or Pyright, which now support ParamSpec.
Consider:
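from jax import Array, jit
from typing_extensions import reveal_type  # needed only at runtime; type checkers know reveal_type
@jit
def f(x: Array, y: Array) -> Array:
    return x + y
reveal_type(f.__call__)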
Before this PR, users who decorate a function with jit lose all of its annotations:
Pyright: Type of "f.__call__" is "(*args: Unknown, **kwargs: Unknown) -> Unknown"
MyPy: Revealed type is "def (*args: Any, **kwargs: Any) -> Any"
After this PR, we get:
Pyright: Type of "f.__call__" is "(x: Array, y: Array) -> Array"
MyPy: Revealed type is "def (x: Array, y: Array) -> Array"
This unblocks a lot of type checking.
Part of https://github.com/google/jax/issues/12049 cc: @jakevdp
I'm not sure what's going on with the type-checking CI job, but perhaps it's not running MyPy 1.0.1?
Also, I'm still working on fixing this for applying jit to methods, which currently works at runtime but does not type check correctly.
After testing this extensively with my code, I think MyPy is still not ready to check code with complex ParamSpec usage (although Pyright is).
If we check something like this in, it may trip up MyPy quite a bit. On the other hand, it has the amazing effect of exposing plenty of Jax type annotations that were previously hidden: many functions in jax.numpy are stripped of perfectly good annotations by @jit.
Also, there are some limitations in Python typing with respect to method decorators; I've posted about this on python/typing. People using jit on methods will get errors, and there doesn't seem to be a universal fix.
What I will do is add typed decorators to my tjax library. Please let me know if you'd like to get this into Jax sooner or if we should wait.
Thanks for looking at this! ParamSpec seems like it could be interesting; that said, it still looks relatively unstable, and the remarks here probably still apply.
Hasn't this already been discussed here: https://github.com/google/jax/issues/10311?
@JesseFarebro My mistake, I didn't find that when I searched! MyPy has progressed since then, but maybe still not enough.
@jakevdp Yes, fair enough. ParamSpec seems to work for more cases than before, but it still causes some errors. I'll check back in a few months to see how support is coming along.
Great news: thanks to the recently merged https://github.com/python/mypy/pull/15837, it appears that the main MyPy error with this pull request may have been solved: https://github.com/python/mypy/issues/12169. Several other MyPy issues that could have affected its usage may also have been solved: https://github.com/python/mypy/issues/11846, https://github.com/python/mypy/issues/12986, https://github.com/python/mypy/issues/14802.
We've updated mypy to 1.4.1 in the meantime, and #17147 bumps it to 1.5.0 – can you sync your PR to the current main branch?
@jakevdp Done. FYI the MyPy pull I linked is merged, but it's not in any released MyPy yet.
@jakevdp MyPy 1.6 is out today, and it may now support using ParamSpec as in this pull request (see the section on ParamSpec improvements).
Is there a process or plan for upgrading Jax to use MyPy 1.6? Looks like you guys have been upgrading regularly.
As soon as the new mypy version is mirrored at https://github.com/pre-commit/mirrors-mypy, we can bump the version in the pre-commit configuration here: https://github.com/google/jax/blob/5ed692809da29ff4fd9a89faee9af6d4b41c7056/.pre-commit-config.yaml#L30
It looks like the mirror update will happen automatically in about 12.5 hours: https://github.com/pre-commit/mirrors-mypy/blob/08cbc46b6e135adec84911b20e98e5bc52032152/.github/workflows/main.yml#L6
#18066 updates mypy to v1.6.0
Okay, I've rebased this to take advantage of the new MyPy version.
Just FYI, if you want to get rid of the typing_extensions dependency, one option would be to follow SPEC 0 and drop Python 3.9 (since you already support Python 3.12). Python 3.10 has ParamSpec in the standard typing module.
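Until Python 3.9 is dropped, a common way to keep the backport confined to old interpreters is to gate the import on the interpreter version. This is a minimal sketch of that pattern, not how JAX actually arranges its imports; type checkers such as MyPy and Pyright understand sys.version_info checks, so both branches type check.

```python
import sys

# typing.ParamSpec exists from Python 3.10; older interpreters need the backport.
if sys.version_info >= (3, 10):
    from typing import ParamSpec
else:
    from typing_extensions import ParamSpec

P = ParamSpec("P")  # usable the same way on either branch
```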
We'll drop Python 3.9 sometime after April 2024, following NEP 29 (see https://jax.readthedocs.io/en/latest/deprecation.html).
Fair enough, but isn't NEP 29 now superseded by SPEC 0?
Not in JAX... our policy still explicitly points to NEP 29. We could have a discussion about changing that, but the discussion would have to happen.
I'm proposing switching to SPEC 0 in #18072
Well, it looks like the result of the SPEC 0 discussion is that we're going to support old Python releases for even longer than before: if the current proposal goes in, we won't drop Python 3.9 until after July 2024. But for the purpose of the current PR, I believe we can use typing_extensions until then.
Hi @NeilGirdhar – what's the status of this? Do you want to keep pushing on this approach?
@jakevdp Yes, I'm really looking forward to something like this making it in. I'll need to look at it tomorrow since I have plans tonight. Were you able to update to the new MyPy?
Yes, we currently run the CI on mypy 1.6.1: https://github.com/google/jax/blob/953f4670d88d2a1c168a4ad0b44ed940f6c58829/.pre-commit-config.yaml#L29-L30
Running MyPy 1.6.1, I'm still getting many spurious errors:
jax/experimental/sparse/linalg.py:102: error: Argument 2 to "__call__" of "Wrapped" has incompatible type "Array"; expected "P.args" [arg-type]
jax/experimental/sparse/linalg.py:102: error: Argument 3 to "__call__" of "Wrapped" has incompatible type "int"; expected "P.args" [arg-type]
jax/experimental/sparse/linalg.py:102: error: Argument 4 to "__call__" of "Wrapped" has incompatible type "Array | float | None"; expected "P.args" [arg-type]
These look like more MyPy bugs; Pyright seems okay with the code. I guess we'll have to wait longer? And should someone submit this bug to MyPy?
Seems like maybe a bad interaction with functools.partial?
That makes sense. I'll look at it as soon as I have more time.
Filed a MyPy bug https://github.com/python/mypy/issues/16404
@jakevdp I think we should wait for one more MyPy release, which should fix the above errors. After that, although this change will not help MyPy users (the inferred type of a jit-decorated function will just be Callable), it won't hurt them either, and it will help Pyright users. What do you think?
Sounds good
@jakevdp I've rebased this now. Is there a way to get the tests to run? https://github.com/python/mypy/issues/1484 is still unsolved, so the gains for MyPy users may be modest.
Pytype is failing with this message:
File ".../jax/_src/api.py", line 100, in <module>: argument "covariant" to TypeVar not supported yet [not-supported-yet]
Seems this is a known issue (https://github.com/google/pytype/issues/1471). Is there any way to do this kind of improvement without using covariant TypeVars?
@jakevdp
I'm going to approve and then pull it in to run internal pytype tests.
Thanks for taking the time!
The return type does need to be covariant (MyPy gives an error without it). One option is to do this in two steps, starting by keeping just the parameter specification annotation and removing the return type annotation.
It's funny, because in Python 3.12 we won't need these variance markers at all, since the type checker can infer variance, although I think we'd have to use the new PEP 695 syntax.
Should I remove the annotations on the return type?