ty icon indicating copy to clipboard operation
ty copied to clipboard

functool partial and jax jit decorator interfere with argument types of function

Open chris-RNG opened this issue 2 months ago • 1 comments

Summary

I'm not sure if this is a bug, but since a few ty updates ago, I've started encountering this type error when combining functools.partial with the jax.jit decorator on a function.

from functools import partial

import jax
import jax.numpy as jnp
from jax import Array


@jax.jit
def sum_jit(a:Array, scale:float=1.0)->Array:
    return jnp.sum(a)*scale

@partial(jax.jit, static_argnames=["scale"])
def sum_partial_jit(a:Array, scale:float=1.0)->Array:
    return jnp.sum(a)*scale

def main():
    n=12
    a= jnp.ones(n)
    sum1 = sum_jit(a)
    sum2 = sum_partial_jit(a)
    print(f"sum1: {sum1}")
    print(f"sum1: {sum2}")

if __name__ == "__main__":
    main()
$ uvx ty check
error[invalid-argument-type]: Argument is incorrect
  --> main.py:20:27
   |
18 |     a= jnp.ones(n)
19 |     sum1 = sum_jit(a)
20 |     sum2= sum_partial_jit(a)
   |                           ^ Expected `(...) -> Unknown`, found `Array`
21 |     print(f"sum1: {sum1}")
22 |     print(f"sum1: {sum2}")
   |
info: Union variant `((...) -> Unknown, /) -> JitWrapped` is incompatible with this call site
info: Attempted to call union type `JitWrapped | (((...) -> Unknown, /) -> JitWrapped)`
info: rule `invalid-argument-type` is enabled by default

Found 1 diagnostic

Version

ty 0.0.3

chris-RNG avatar Dec 18 '25 13:12 chris-RNG

Thanks for the report!

carljm avatar Dec 23 '25 21:12 carljm