
[typing] Use ParamSpec in JIT annotation

Open NeilGirdhar opened this issue 1 year ago • 33 comments

This pull request would be a huge improvement for Jax users who use type checkers like MyPy or Pyright, which now support ParamSpec.

Consider:

from jax import Array, jit

@jit
def f(x: Array, y: Array) -> Array:
    return x + y

reveal_type(f.__call__)

Before this PR, decorating a function with jit discards all of the function's annotations:

Pyright: Type of "f.__call__" is "(*args: Unknown, **kwargs: Unknown) -> Unknown"
MyPy: Revealed type is "def (*args: Any, **kwargs: Any) -> Any"

After this PR, we get:

Pyright: Type of "f.__call__" is "(x: Array, y: Array) -> Array"
MyPy: Revealed type is "def (x: Array, y: Array) -> Array"

This unblocks a lot of type checking.

Part of https://github.com/google/jax/issues/12049 cc: @jakevdp

NeilGirdhar avatar Feb 26 '23 02:02 NeilGirdhar

Not sure what's going on with the type checking, but perhaps it's not running MyPy 1.0.1?

Also, I'm still working on the case of applying jit to methods, which currently works at runtime but does not type-check correctly.

NeilGirdhar avatar Feb 26 '23 03:02 NeilGirdhar

After testing this a lot with my code, I think MyPy is still not ready to check code with complex ParamSpec usage (although Pyright is).

If we check something like this in, it may trip up MyPy quite a bit. On the other hand, it has the amazing effect of exposing plenty of Jax type annotations that were previously hidden. Plenty of functions in jax.numpy are stripped of perfectly good annotations by @jit.

Also, there are some limitations in Python typing wrt method decorators. I've posted about this on python/typing. If people are using jit on methods, then they will get errors, and there doesn't seem to be a universal fix.

What I will do is add typed decorators to my tjax library. Please let me know if you'd like to get this into Jax sooner or if we should wait on it.

NeilGirdhar avatar Feb 26 '23 12:02 NeilGirdhar

Thanks for looking at this! ParamSpec seems like it could be interesting; that said, it still looks relatively unstable, and the remarks here probably still apply.

jakevdp avatar Feb 26 '23 13:02 jakevdp

Hasn't this already been discussed here: https://github.com/google/jax/issues/10311?

JesseFarebro avatar Feb 26 '23 16:02 JesseFarebro

@JesseFarebro My mistake, I didn't find that when I searched! MyPy has progressed since then, but maybe still not enough.

@jakevdp Yes, fair enough. ParamSpec seems to work for more cases than before, but it seems to be causing some errors still. I'll check back in a few months to see how support is coming.

NeilGirdhar avatar Feb 26 '23 19:02 NeilGirdhar

Great news: thanks to the recently merged https://github.com/python/mypy/pull/15837, it appears that the main MyPy error with this pull request may have been solved: https://github.com/python/mypy/issues/12169. Several other MyPy issues that could have affected ParamSpec usage may also have been solved: https://github.com/python/mypy/issues/11846, https://github.com/python/mypy/issues/12986, https://github.com/python/mypy/issues/14802.

NeilGirdhar avatar Aug 16 '23 21:08 NeilGirdhar

We've updated mypy to 1.4.1 in the meantime, and #17147 bumps it to 1.5.0 – can you sync your PR to the current main branch?

jakevdp avatar Aug 16 '23 21:08 jakevdp

@jakevdp Done. FYI the MyPy pull I linked is merged, but it's not in any released MyPy yet.

NeilGirdhar avatar Aug 16 '23 21:08 NeilGirdhar

@jakevdp MyPy 1.6 is out today, and it may now support using ParamSpec as in this pull request. (See the section on ParamSpec improvements.)

Is there a process or plan for upgrading Jax to use MyPy 1.6? Looks like you guys have been upgrading regularly.

NeilGirdhar avatar Oct 10 '23 19:10 NeilGirdhar

As soon as the new mypy version is mirrored at https://github.com/pre-commit/mirrors-mypy, we can bump the version in the pre-commit configuration here: https://github.com/google/jax/blob/5ed692809da29ff4fd9a89faee9af6d4b41c7056/.pre-commit-config.yaml#L30

It looks like the mirror update will happen automatically in about 12.5 hours: https://github.com/pre-commit/mirrors-mypy/blob/08cbc46b6e135adec84911b20e98e5bc52032152/.github/workflows/main.yml#L6

jakevdp avatar Oct 10 '23 19:10 jakevdp

#18066 updates mypy to v1.6.0

jakevdp avatar Oct 11 '23 19:10 jakevdp

Okay, I've rebased this to take advantage of the new MyPy version

NeilGirdhar avatar Oct 11 '23 21:10 NeilGirdhar

Just FYI if you want to get rid of the typing_extensions dependency, one option would be to follow SPEC 0, and drop Python 3.9 (since you already support Python 3.12). Python 3.10 has ParamSpec.
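For context, the usual way to support both Python 3.9 and newer versions is a version-guarded import, which type checkers understand. This is a generic sketch of that pattern, not the code in this pull request; the fallback branch assumes typing_extensions is installed on Python 3.9.

```python
import sys

if sys.version_info >= (3, 10):
    # ParamSpec is in the standard library from Python 3.10 onward.
    from typing import ParamSpec
else:
    # Assumption: typing_extensions is available as a dependency on 3.9.
    from typing_extensions import ParamSpec

P = ParamSpec("P")
```

Once Python 3.9 is dropped, the guard collapses to the plain typing import and the extra dependency goes away.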

NeilGirdhar avatar Oct 11 '23 21:10 NeilGirdhar

We'll drop Python 3.9 sometime after April 2024, following NEP 29 (see https://jax.readthedocs.io/en/latest/deprecation.html).

jakevdp avatar Oct 11 '23 21:10 jakevdp

sometime after April 2024, following NEP 29 (see

Fair enough, but isn't NEP 29 now superseded by SPEC 0?

NeilGirdhar avatar Oct 11 '23 22:10 NeilGirdhar

Fair enough, but isn't NEP 29 now superseded by SPEC 0?

Not in JAX... our policy still explicitly points to NEP 29. We could have a discussion about changing that, but the discussion would have to happen.

jakevdp avatar Oct 11 '23 22:10 jakevdp

I'm proposing switching to SPEC 0 in #18072

jakevdp avatar Oct 11 '23 23:10 jakevdp

Well, it looks like the result of the SPEC-0 discussion is that we're going to support old Python releases for even longer than we did previously. If the current proposal goes in, we won't drop Python 3.9 until after July 2024. But for the purpose of the current PR, I believe we can use typing_extensions until then.

jakevdp avatar Oct 13 '23 21:10 jakevdp

Hi @NeilGirdhar – what's the status of this? Do you want to keep pushing on this approach?

jakevdp avatar Nov 03 '23 22:11 jakevdp

@jakevdp Yes, I'm really looking forward to something like this making it in. I'll need to look at it tomorrow since I have plans tonight. Were you able to update to the new MyPy?

NeilGirdhar avatar Nov 03 '23 22:11 NeilGirdhar

Yes, we currently run the CI on mypy 1.6.1: https://github.com/google/jax/blob/953f4670d88d2a1c168a4ad0b44ed940f6c58829/.pre-commit-config.yaml#L29-L30

jakevdp avatar Nov 03 '23 22:11 jakevdp

Running MyPy 1.6.1, I'm still getting a lot of bad errors:


jax/experimental/sparse/linalg.py:102: error: Argument 2 to "__call__" of "Wrapped" has incompatible type "Array"; expected "P.args"  [arg-type]
jax/experimental/sparse/linalg.py:102: error: Argument 3 to "__call__" of "Wrapped" has incompatible type "int"; expected "P.args"  [arg-type]
jax/experimental/sparse/linalg.py:102: error: Argument 4 to "__call__" of "Wrapped" has incompatible type "Array | float | None"; expected "P.args"  [arg-type]

These look like more MyPy bugs; Pyright accepts the code. I guess we'll have to wait longer? Someone should also consider reporting this to MyPy.

NeilGirdhar avatar Nov 03 '23 22:11 NeilGirdhar

Seems like maybe a bad interaction with functools.partial?

jakevdp avatar Nov 03 '23 22:11 jakevdp

Seems like maybe a bad interaction with functools.partial?

That makes sense. I'll look at it as soon as I have more time.

NeilGirdhar avatar Nov 03 '23 22:11 NeilGirdhar

Filed a MyPy bug https://github.com/python/mypy/issues/16404

NeilGirdhar avatar Nov 04 '23 11:11 NeilGirdhar

@jakevdp I think we should wait one more release of MyPy, which should fix the above errors. After that, although this change will not help MyPy users (the inferred type of a jit-decorated function will just be Callable), this change won't hurt, and it will help Pyright users. What do you think?

NeilGirdhar avatar Nov 04 '23 20:11 NeilGirdhar

@jakevdp I think we should wait one more release of MyPy

Sounds good

jakevdp avatar Nov 06 '23 20:11 jakevdp

@jakevdp I've rebased this now. Is there a way to get the tests to run? https://github.com/python/mypy/issues/1484 is still unsolved, so there may only be modest gains in MyPy.

NeilGirdhar avatar Mar 13 '24 23:03 NeilGirdhar

Pytype is failing with this message:

File ".../jax/_src/api.py", line 100, in <module>: argument "covariant" to TypeVar not supported yet [not-supported-yet]

Seems this is a known issue (https://github.com/google/pytype/issues/1471). Is there any way to do this kind of improvement without using covariant TypeVars?

jakevdp avatar Mar 14 '24 00:03 jakevdp

@jakevdp

I'm going to approve and then pull it in to run internal pytype tests

Thanks for taking the time!

Is there any way to do this kind of improvement without using covariant TypeVars?

The return type does need to be covariant (MyPy gives an error without it). One option is to do this in two steps: for now, keep only the ParamSpec annotation on the parameters and remove the return type annotation.

It's funny because in Python 3.12, we won't need these markers at all as the type checker can infer them, although I think we'd have to use the new PEP 695 syntax.

Should I remove the annotations on the return type?

NeilGirdhar avatar Mar 14 '24 00:03 NeilGirdhar