mypy doesn't recognize type conversions performed by `eqx.field(converter=...)`
Hi @patrick-kidger,
I've run into a problem: mypy does not understand that a converter may change the type of a field:
import equinox as eqx
import jax
from jaxtyping import Array
class MyModule(eqx.Module):
    foo: Array = eqx.field(converter=jax.numpy.asarray)
mymodule = MyModule(1.0)
assert isinstance(mymodule.foo, jax.Array)
Running `mypy thisexample.py` reports:
error: Argument 1 to "MyModule" has incompatible type "float"; expected "Array"  [arg-type]
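For context, here is a minimal stdlib illustration of where the expected `Array` comes from. It is not Equinox-specific, and the class `Point` is hypothetical. `eqx.Module` subclasses are dataclasses, and a dataclass's generated `__init__` takes each parameter with exactly the field's annotated type. A type checker that does not model converters builds the same signature, so for `foo` it expects `Array`, not the converter's input type.

```python
import dataclasses
import typing


@dataclasses.dataclass(frozen=True)
class Point:
    # Hypothetical field: its annotation is the only type information
    # the generated __init__ receives.
    x: float


# dataclasses synthesizes __init__(self, x: float) from the field annotations.
init_hints = typing.get_type_hints(Point.__init__)
print(init_hints["x"])  # <class 'float'>
```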
If you instead annotate `foo` with `ArrayLike` (the input type of the converter function), mypy accepts the constructor call but complains wherever the attribute is used afterwards, e.g.:
import equinox as eqx
import jax
from jaxtyping import ArrayLike
class MyModule(eqx.Module):
    foo: ArrayLike = eqx.field(converter=jax.numpy.asarray)
mymodule = MyModule(1.0)
_ = mymodule.foo.shape
Running `mypy thisexample.py` now reports:
error: Item "int" of "Array | ndarray[Any, Any] | numpy.bool[builtins.bool] | number[Any, int | float | complex] | int | float | complex" has no attribute "shape" [union-attr]
error: Item "float" of "Array | ndarray[Any, Any] | numpy.bool[builtins.bool] | number[Any, int | float | complex] | int | float | complex" has no attribute "shape" [union-attr]
error: Item "complex" of "Array | ndarray[Any, Any] | numpy.bool[builtins.bool] | number[Any, int | float | complex] | int | float | complex" has no attribute "shape" [union-attr]
Is there a way to automatically propagate the type information from converters for type checkers?
I think mypy doesn't support this yet: https://github.com/python/mypy/issues/17547
You could try using pyright instead, which does.
Thanks @patrick-kidger, I can confirm that these errors are not there with pyright!
(feel free to close this issue whenever you like)