jax
Fix Pyright issue with the type of NotMapped
The newest Pyright release is not able to figure out that NotMapped
is a type alias. Adding the correct type annotation forces it to understand.
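
For context, here is a generic sketch of the difference between an implicit alias, which the checker has to infer, and an explicitly annotated one (PEP 613). The names `ImplicitAxis`, `ExplicitAxis`, and `describe` are hypothetical and not taken from jax, and this is not necessarily the annotation used in this change.

```python
from typing import Optional

try:
    from typing import TypeAlias  # available in Python 3.10+
except ImportError:  # older interpreters need the backport
    from typing_extensions import TypeAlias

# Hypothetical names for illustration only; not taken from jax.

# Implicit alias: the checker must infer that this assignment defines a type.
ImplicitAxis = Optional[int]

# Explicit alias (PEP 613): the annotation states that this is a type alias.
ExplicitAxis: TypeAlias = Optional[int]


def describe(axis: ExplicitAxis) -> str:
    """Return a label for a batch axis, where None means 'not mapped'."""
    return "not mapped" if axis is None else f"mapped over axis {axis}"
```

Both spellings behave identically at runtime; the annotation only changes what a static checker is told.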
This can be reproduced, and the fix tested, by running Pyright on:
from jax.interpreters.batching import NotMapped
x: NotMapped
I discovered this while updating Pyright for Equinox.