jaxtyping
Runtime type checking via `typeguard` causes `TypeError` due to arrays having type `DeviceArray`.
I'm trying to use jaxtyping with runtime type checking via typeguard, as described here. Here's my code:
import jax.numpy as jnp
from jaxtyping import Array, Float, jaxtyped
from typeguard import typechecked as typechecker

@jaxtyped
@typechecker
def foo(
    x: Float[Array, "n"],
    y: Float[Array, "n"],
) -> Float[Array, "n"]:
    return x + y


print(foo(jnp.arange(10), jnp.arange(10)))
However, when I run the above script, I get the following error:
Traceback (most recent call last):
  File "/Users/jay/playground/myscript.py", line 14, in <module>
    print(foo(jnp.arange(10), jnp.arange(10)))
  File "/Users/jay/playground/.venv/lib/python3.9/site-packages/jaxtyping/decorator.py", line 41, in __call__
    return self.fn(*args, **kwargs)
  File "/Users/jay/playground/.venv/lib/python3.9/site-packages/typeguard/__init__.py", line 1032, in wrapper
    check_argument_types(memo)
  File "/Users/jay/playground/.venv/lib/python3.9/site-packages/typeguard/__init__.py", line 875, in check_argument_types
    raise TypeError(*exc.args) from None
TypeError: type of argument "x" must be jaxtyping.Float[ndarray, 'n']; got jaxlib.xla_extension.DeviceArray instead
Steps to reproduce my Python environment (note: I'm running this on an M1 MacBook Pro with macOS Monterey 12.2 (21D49)):
$ python -V
Python 3.9.10
$ python -m venv .venv
$ source .venv/bin/activate
$ python -m pip install --upgrade pip
$ python -m pip install "jax[cpu]==0.3.17" "jaxtyping==0.2.5"
Ah, so I'm realizing it's because jnp.arange returns an integer array by default. If I change the call to print(foo(jnp.arange(10)*1.0, jnp.arange(10)*1.0)), I no longer get an error. I'm wondering whether the error message could be more descriptive, or whether this quirk is documented somewhere? The error message is a bit misleading, since it points at the array type (DeviceArray) rather than at the dtype.
Right; something similar came up in #6. It would indeed be great if the error message included more information, but it's the type checker (in this case typeguard) that raises the error, not jaxtyping. (All jaxtyping does is provide the types themselves.)
FWIW, my usual approach to debugging this is to rerun under the debugger, so that I can check for myself what types were passed. This can be done with either of:
python -m pdb -c continue your_script.py
ipython your_script.py --pdb
(I'd like a better solution to this too.)
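One possible stopgap short of the debugger, sketched here rather than offered as a jaxtyping or typeguard feature: wrap the checked function in a small decorator that, when a TypeError escapes, re-raises it with the type, dtype, and shape of every argument appended. The name with_arg_summary is hypothetical; the sketch uses only the standard library and assumes arrays expose .dtype and .shape, as JAX and NumPy arrays do. It catches TypeError because that is what typeguard raises in the traceback above; other typeguard versions may raise a different exception class, in which case the except clause would need adjusting.

```python
import functools


def with_arg_summary(fn):
    """Hypothetical debugging helper: if fn raises TypeError, re-raise it with
    the type, dtype and shape of every argument appended to the message."""

    def describe(value):
        # JAX/NumPy arrays expose .dtype and .shape; other objects may not.
        dtype = getattr(value, "dtype", None)
        shape = getattr(value, "shape", None)
        parts = [type(value).__name__]
        if dtype is not None:
            parts.append(f"dtype={dtype}")
        if shape is not None:
            parts.append(f"shape={shape}")
        return ", ".join(parts)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except TypeError as exc:
            lines = [f"  arg {i}: {describe(a)}" for i, a in enumerate(args)]
            lines += [f"  {k}: {describe(v)}" for k, v in kwargs.items()]
            summary = "\n".join(lines)
            # Chain with "from exc" so the original traceback is preserved.
            raise TypeError(f"{exc}\nArguments received:\n{summary}") from exc

    return wrapper
```

It would be stacked outermost, above @jaxtyped and @typechecker on foo, so that with the integer inputs from the original report the message lists the integer dtype next to typeguard's complaint. It also catches TypeErrors raised inside the function body itself, which is fine while debugging but probably not something to leave in place permanently.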
Yeah, that's the workaround I'm using as well to check the shapes and types when an error comes up. Maybe it's worth documenting this in API.md? I missed #6 in my search for a solution (which is on me, tbh), but it might be useful for the next person who will inevitably come across this without thoroughly checking the issues on GitHub.