jaxtyping problem in "RWKV, explained"

Open · yuxi-liu-wired opened this issue 1 year ago · 0 comments

In the "RWKV, explained" notebook, the following cell raises `AttributeError: module 'jax.core' has no attribute 'is_opaque_dtype'`:

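```python
from beartype import beartype, roar

# `to_onehot` and `N_VOCAB` are defined in earlier cells of the notebook.
to_onehot = beartype(to_onehot)  # wrap with runtime type checking

# A string violates the type hints, so beartype should reject the call.
try:
    print(to_onehot("hey"))
    assert False, "the code in this blog post is wrong!"
except roar.BeartypeCallHintException:
    print("🐻 rawr! that input type is not allowed")

# A valid token index should pass the check; this is the call that fails.
try:
    print(to_onehot(N_VOCAB - 1))
except roar.BeartypeCallHintException:
    assert False, "the code in this blog post is wrong!"
```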
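Why is the invalid call rejected cleanly while the valid one crashes? A runtime checker like beartype tests each argument and the return value against its type hint, and for a class-style hint that comes down to `isinstance`, which Python routes through the hint's metaclass `__instancecheck__`. jaxtyping's array hints work this way, which is why the traceback below ends in `__instancecheck__` inside `jaxtyping/_array_types.py`. The stdlib-only toy below illustrates the idea; it is not beartype's or jaxtyping's implementation, and `check_hints`, `ArrayHintMeta`, `IntScalar`, and `FakeArray` are hypothetical names. Objects without a `dtype` are rejected before any dtype logic runs, which is presumably why `"hey"` never reached the failing line, while a real array (most likely the return value of `to_onehot`, assuming it returns a JAX array) does.

```python
import functools
import inspect


class CallHintViolation(TypeError):
    """Raised by the toy checker when a value fails its type hint."""


class ArrayHintMeta(type):
    """Metaclass with custom isinstance logic, in the style of jaxtyping hints."""

    def __instancecheck__(cls, obj):
        # Early exit: non-array objects never reach the dtype inspection below.
        if not hasattr(obj, "dtype"):
            return False
        return obj.dtype == cls.dtype


class IntScalar(metaclass=ArrayHintMeta):
    dtype = "int32"  # hypothetical hint meaning "an int32 array"


class FakeArray:
    """Hypothetical stand-in for an array: it only carries a dtype string."""

    def __init__(self, dtype):
        self.dtype = dtype


def check_hints(fn):
    """Toy runtime checker: isinstance() each annotated argument and the return."""
    sig = inspect.signature(fn)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        for name, value in bound.arguments.items():
            hint = sig.parameters[name].annotation
            if hint is not inspect.Parameter.empty and not isinstance(value, hint):
                raise CallHintViolation(f"{name}={value!r} violates {hint.__name__}")
        result = fn(*args, **kwargs)
        ret = sig.return_annotation
        if ret is not inspect.Signature.empty and not isinstance(result, ret):
            raise CallHintViolation(f"return value violates {ret.__name__}")
        return result

    return wrapper


@check_hints
def identity(x: IntScalar) -> IntScalar:
    return x


try:
    identity("hey")
except CallHintViolation as err:
    print("rejected:", err)  # rejected: x='hey' violates IntScalar

print(identity(FakeArray("int32")).dtype)  # int32
```

In the real jaxtyping frame, the dtype branch is where `jax.core.is_opaque_dtype` is called, so only array-like values trigger the error.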

The same error appears in later cells.
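To confirm the mismatch in another environment, it may help to print the installed versions of the libraries involved and check whether `jax.core` still exposes the function jaxtyping calls. The sketch below uses only `importlib.metadata` and `hasattr`, so it also runs when a package is missing; `installed_version` and `core_has` are hypothetical helper names. If the check reports `False`, the likely remedies are upgrading jaxtyping to a release that no longer calls `jax.core.is_opaque_dtype`, or pinning an older `jax`; the exact version boundary is not stated here and should be checked against each project's changelog.

```python
import importlib
import importlib.metadata


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return importlib.metadata.version(dist_name)
    except importlib.metadata.PackageNotFoundError:
        return None


def core_has(attr, module_name="jax.core"):
    """Report whether `module_name` exposes `attr`; None if it cannot be imported."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    # hasattr() returns False when the module's __getattr__ raises AttributeError.
    return hasattr(module, attr)


for dist in ("jax", "jaxlib", "jaxtyping", "beartype"):
    print(f"{dist:10s} {installed_version(dist)}")

# False here matches the precondition for the AttributeError in this issue.
print("jax.core.is_opaque_dtype present:", core_has("is_opaque_dtype"))
```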

Full error traceback:

```
🐻 rawr! that input type is not allowed
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-39-ed99d1beb7a7> in <cell line: 11>()
     10 
     11 try:
---> 12     print(to_onehot(N_VOCAB - 1))
     13 except roar.BeartypeCallHintException:
     14     assert False, "the code in this blog post is wrong!"

<@beartype(__main__.to_onehot) at 0x793ae75051b0> in to_onehot(__beartype_func, __beartype_conf, __beartype_get_violation, __beartype_object_93932183760576, *args, **kwargs)

1 frames
/usr/local/lib/python3.10/dist-packages/jaxtyping/_array_types.py in __instancecheck__(cls, obj)
    143             return False
    144 
--> 145         if has_jax and jax.core.is_opaque_dtype(obj.dtype):
    146             dtype = str(obj.dtype)
    147         elif hasattr(obj.dtype, "type") and hasattr(obj.dtype.type, "__name__"):

/usr/local/lib/python3.10/dist-packages/jax/_src/deprecations.py in getattr(name)
     51       warnings.warn(message, DeprecationWarning, stacklevel=2)
     52       return fn
---> 53     raise AttributeError(f"module {module!r} has no attribute {name!r}")
     54 
     55   return getattr

AttributeError: module 'jax.core' has no attribute 'is_opaque_dtype'
```
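The last two frames of the traceback show the mechanism: `jax.core` resolves missing names through a module-level `__getattr__` (PEP 562) from `jax/_src/deprecations.py`. Names still in a deprecation table are returned with a `DeprecationWarning`; anything else, including the removed `is_opaque_dtype`, raises `AttributeError`. The sketch below re-creates that pattern on a throwaway module using only the standard library. It is modeled on the frame shown in the traceback, not copied from JAX, and `make_module_getattr`, `fake_core`, `old_helper`, and `removed_helper` are hypothetical names.

```python
import types
import warnings


def make_module_getattr(module_name, deprecations):
    """Build a PEP 562 module-level __getattr__ resembling the traceback frame.

    `deprecations` maps a name to (message, obj): those names still resolve,
    but emit a DeprecationWarning. Any other missing name raises AttributeError.
    """

    def module_getattr(name):
        if name in deprecations:
            message, obj = deprecations[name]
            warnings.warn(message, DeprecationWarning, stacklevel=2)
            return obj
        raise AttributeError(f"module {module_name!r} has no attribute {name!r}")

    return module_getattr


# Hypothetical module: `old_helper` is deprecated but still reachable,
# while `removed_helper` is absent from the table entirely.
fake_core = types.ModuleType("fake_core")
fake_core.__getattr__ = make_module_getattr(
    "fake_core",
    {"old_helper": ("fake_core.old_helper is deprecated", lambda: "still works")},
)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    print(fake_core.old_helper())  # still works
print(caught[0].category.__name__)  # DeprecationWarning

try:
    fake_core.removed_helper
except AttributeError as err:
    print(err)  # module 'fake_core' has no attribute 'removed_helper'

# Feature detection: getattr with a default and hasattr absorb the AttributeError.
print(getattr(fake_core, "removed_helper", None))  # None
print(hasattr(fake_core, "removed_helper"))  # False
```

The last two lines show why feature detection (`getattr(module, name, None)` or `hasattr`) is a safer way for a library to call an API that a dependency may have removed.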

yuxi-liu-wired · Feb 05 '24 23:02