jaxtyping

Runtime type checking via `typeguard` causes `TypeError` due to arrays having type `DeviceArray`

Open • jaymody opened this issue 2 years ago • 3 comments

I'm trying to use jaxtyping with runtime type checking via typeguard as described here. Here's my code:

```python
import jax.numpy as jnp
from jaxtyping import Array, Float, jaxtyped
from typeguard import typechecked as typechecker


@jaxtyped
@typechecker
def foo(
    x: Float[Array, "n"],
    y: Float[Array, "n"],
) -> Float[Array, "n"]:
    return x + y

print(foo(jnp.arange(10), jnp.arange(10)))
```

However when I run the above script, I get the following error:

```
Traceback (most recent call last):
  File "/Users/jay/playground/myscript.py", line 14, in <module>
    print(foo(jnp.arange(10), jnp.arange(10)))
  File "/Users/jay/playground/.venv/lib/python3.9/site-packages/jaxtyping/decorator.py", line 41, in __call__
    return self.fn(*args, **kwargs)
  File "/Users/jay/playground/.venv/lib/python3.9/site-packages/typeguard/__init__.py", line 1032, in wrapper
    check_argument_types(memo)
  File "/Users/jay/playground/.venv/lib/python3.9/site-packages/typeguard/__init__.py", line 875, in check_argument_types
    raise TypeError(*exc.args) from None
TypeError: type of argument "x" must be jaxtyping.Float[ndarray, 'n']; got jaxlib.xla_extension.DeviceArray instead
```

Steps to reproduce my Python environment (note: I'm running this on an M1 MacBook Pro with macOS Monterey 12.2 (21D49)):

```
$ python -V
Python 3.9.10

$ python -m venv .venv

$ source .venv/bin/activate

$ python -m pip install --upgrade pip

$ python -m pip install "jax[cpu]==0.3.17" "jaxtyping==0.2.5"
```

jaymody, Sep 25 '22 01:09

Ah, I realize now that it's because `jnp.arange(10)` returns an integer array by default. If I change the call to `print(foo(jnp.arange(10)*1.0, jnp.arange(10)*1.0))`, I no longer get an error. Could the error message be more descriptive, or is this quirk documented somewhere? The current message is misleading: it blames the array type (`DeviceArray`) when the actual problem is the integer dtype.

jaymody, Sep 25 '22 01:09

Right; something similar came up in #6. It would indeed be great if the error message included more information, but the error is raised by the typechecker (in this case typeguard), not by jaxtyping. (All jaxtyping does is provide the types themselves.)
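Because jaxtyping only supplies the annotations, a richer message has to be added around the typechecker. Below is a minimal, checker-agnostic sketch: `explain_type_errors` and `_describe` are hypothetical helpers (not part of jaxtyping or typeguard) that re-raise a `TypeError` with each argument's Python type, dtype and shape appended. It assumes the checker raises `TypeError`, as typeguard 2.x does in the traceback above; newer typeguard releases may raise a different exception class, which this sketch would not catch. It also catches `TypeError`s raised by the function body itself.

```python
import functools


def _describe(value):
    """Summarise a value's Python type, plus dtype/shape when it has them."""
    parts = [type(value).__name__]
    if hasattr(value, "dtype"):
        parts.append(f"dtype={value.dtype}")
    if hasattr(value, "shape"):
        parts.append(f"shape={tuple(value.shape)}")
    return ", ".join(parts)


def explain_type_errors(fn):
    """Hypothetical decorator: re-raise TypeError with a summary of the arguments.

    Note: this also catches TypeErrors raised by the function body itself,
    not only those raised by the runtime typechecker.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except TypeError as exc:
            # Label positional arguments by index and keyword arguments by name.
            labelled = [(f"args[{i}]", a) for i, a in enumerate(args)]
            labelled += list(kwargs.items())
            summary = "\n".join(f"  {name}: {_describe(v)}" for name, v in labelled)
            raise TypeError(f"{exc}\nArguments received:\n{summary}") from exc

    return wrapper
```

To use it, `@explain_type_errors` would go above `@jaxtyped` on `foo`, so that it sees the error after the typechecker has raised it.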

FWIW, my usual approach to debugging this is to rerun under the debugger, so that I can check which types were actually passed. This can be done with either of:

```
python -m pdb -c continue your_script.py
ipython your_script.py --pdb
```

(I'd like a better solution to this too.)

patrick-kidger, Sep 25 '22 01:09
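The `python -m pdb -c continue` approach can also be approximated from inside a script or notebook with the standard library's `pdb.post_mortem`. A minimal sketch; `run_with_post_mortem` is a hypothetical helper name, and because the debugger session is interactive it is only suited to local debugging.

```python
import pdb
import sys
import traceback


def run_with_post_mortem(fn, *args, **kwargs):
    """Call fn; if it raises, print the traceback and open pdb at the failing frame."""
    try:
        return fn(*args, **kwargs)
    except Exception:
        traceback.print_exc()
        # Starts at the frame where the exception was raised; `up` moves toward
        # the callers, where the passed arguments can be inspected with `p`.
        pdb.post_mortem(sys.exc_info()[2])
        raise  # re-raise the original exception once the debugger exits
```

For the original report this would be `run_with_post_mortem(foo, jnp.arange(10), jnp.arange(10))`; after the traceback prints, `up` moves toward the wrapper frames that hold the arguments.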

Yeah, that's the workaround I'm using as well to check shapes and types when an error comes up. Maybe it's worth documenting this in API.md? I missed #6 when searching for a solution (which is on me, tbh), but it might help the next person who inevitably runs into this without thoroughly checking the GitHub issues.

jaymody, Sep 25 '22 02:09