`Array` type doesn't deal nicely with `Tuple`s
```python
from typing import Tuple
from jaxtyping import Array

def test(a: Tuple[Array, ...], b: Array) -> Array:
    return a + b
```
Clearly, `a + b` is not an `Array` (I'd expect a `Tuple[Array, ...]`), yet the LSP doesn't flag the mismatch when the type involved is `Array`.
Replacing `Array` with any other type does produce an error, so it's definitely due to the semantics of the `Array` implementation.
I don't think this code works at runtime:
```python
>>> import jax.numpy as jnp
>>> (jnp.arange(3),) + jnp.arange(3)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File ".../site-packages/jax/_src/numpy/array_methods.py", line 587, in deferring_binary_op
    raise TypeError(f"unsupported operand type(s) for {opchar}: "
TypeError: unsupported operand type(s) for +: 'tuple' and 'ArrayImpl'
```
I'm not sure what behaviour you were expecting?
Ah, sorry, I was in a hurry, so the issue is badly worded. The problem is that the LSP (basedpyright, a fork of pyright) doesn't catch this rather obvious type error, which incidentally caused trouble for me at runtime.
I noticed this only happens when the type is `Array`, so it's more likely an implementation-specific issue than an LSP one 🙂
Ah, right! :D So `jaxtyping.Array` is actually just a re-export of `jax.Array`. If there is an issue here, it'll be either in the type checker or in JAX; I'm not sure which!
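One possible explanation, which nobody in this thread has confirmed: Python evaluates `tuple + x` by falling back to the reflected method `x.__radd__` when `tuple.__add__` cannot handle `x`, which is presumably why the JAX traceback ends inside `deferring_binary_op`. If `jax.Array`'s type stub declares `__radd__` with an operand typed as `Any` (an assumption, not checked here), a type checker would take the same fallback, infer the result as `Array`, and report nothing. The sketch below models this with the standard library only; `FakeArray`, `PlainBox`, `with_fake_array` and `with_plain_box` are hypothetical names, not JAX or jaxtyping APIs.

```python
from typing import Any, Tuple


class PlainBox:
    """Hypothetical type with no __radd__, like most ordinary types."""


class FakeArray:
    """Hypothetical stand-in for jax.Array; NOT the real JAX class."""

    def __init__(self) -> None:
        self.radd_calls = 0  # counts how often the reflected method ran

    # ASSUMPTION: modelled on a stub whose __radd__ accepts any left operand.
    # Because `other` is Any, a type checker accepts `tuple + FakeArray`
    # and infers the result as FakeArray.
    def __radd__(self, other: Any) -> "FakeArray":
        self.radd_calls += 1
        if isinstance(other, tuple):
            # Mirror the TypeError seen in the JAX traceback: reject tuples.
            raise TypeError(
                "unsupported operand type(s) for +: 'tuple' and 'FakeArray'"
            )
        return FakeArray()


def with_fake_array(a: Tuple[FakeArray, ...], b: FakeArray) -> FakeArray:
    # tuple.__add__ rejects b, so Python (and a type checker) try b.__radd__(a).
    return a + b


def with_plain_box(a: Tuple[PlainBox, ...], b: PlainBox) -> PlainBox:
    # No __radd__ to fall back on, so a type checker should flag this `+`.
    return a + b
```

Under that assumption, running pyright or basedpyright on this file should flag the `+` in `with_plain_box` but not the one in `with_fake_array`, even though both raise `TypeError` at runtime.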