[BUG] Some type-annotated functions fail when passed directly to `@qjit`
For example, compiling `jax.scipy.linalg.expm` directly fails:

```python
>>> qjit(jax.scipy.linalg.expm)(x)
TypeError: Argument 'ArrayLike' of type <class 'str'> is not a valid JAX type
```
The same JAX function works as expected when it is called inside a user-defined function that is decorated with `@qjit`.
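The failure appears to come from the type annotations in the function's signature. Judging by the error, the `ArrayLike` annotation on `A` reaches Catalyst as a string rather than as a type:

```python
def expm(A: ArrayLike, *, upper_triangular: bool = False, max_squarings: int = 16) -> Array:
```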
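The sketch below reproduces what the error message suggests, using only the standard library. It assumes that Catalyst reads argument annotations through something like `inspect.signature`, which is an assumption and has not been checked against Catalyst's source. `expm_like` is a hypothetical stand-in with the same signature as `expm`. Its `ArrayLike` and `Array` annotations are written as plain strings, which is also how they appear when a module uses `from __future__ import annotations` (PEP 563). The sketch also indicates why the wrapped-function workaround succeeds: a wrapper without annotations has nothing for the compiler to misinterpret. In Catalyst this would presumably correspond to `qjit(lambda A: jax.scipy.linalg.expm(A))`, although that exact call has not been verified here.

```python
import functools
import inspect


# Hypothetical stand-in for jax.scipy.linalg.expm. String annotations mimic what
# `from __future__ import annotations` (PEP 563) produces: they are never evaluated.
def expm_like(A: "ArrayLike", *, upper_triangular: bool = False, max_squarings: int = 16) -> "Array":
    return A


def first_param_annotation(fn):
    """Return the annotation of fn's first parameter as reported by inspect.signature."""
    params = list(inspect.signature(fn).parameters.values())
    return params[0].annotation


# Passing the function directly: the annotation is the plain string 'ArrayLike',
# matching the name reported in the TypeError.
direct = first_param_annotation(expm_like)

# A lambda wrapper has no annotations, so inspect.signature reports Parameter.empty.
via_lambda = first_param_annotation(lambda A: expm_like(A))


# functools.wraps sets __wrapped__, and inspect.signature follows it by default,
# so this style of wrapper still exposes the original string annotation.
@functools.wraps(expm_like)
def via_wraps(*args, **kwargs):
    return expm_like(*args, **kwargs)


wrapped = first_param_annotation(via_wraps)

print(repr(direct))                            # 'ArrayLike'
print(via_lambda is inspect.Parameter.empty)   # True
print(repr(wrapped))                           # 'ArrayLike'
```

Note that a wrapper built with `functools.wraps` probably would not help as a workaround, because it deliberately reports the wrapped function's signature, annotations included.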