penzai
Fix dtype type annotations to use `jax.typing.DTypeLike`.
The `named_axis` functions `full()`, `zeros()` and `ones()` annotate their `dtype` parameters with `np.DTypeLike`, which does not exist (NumPy exposes that alias as `numpy.typing.DTypeLike`). This change replaces those annotations with `jax.typing.DTypeLike | None` so they match the `dtype` parameter of the functions they wrap.
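A minimal sketch of the corrected annotation style. It assumes a plain wrapper that forwards `dtype` to `jnp.zeros`; the name `zeros_sketch` is hypothetical and is not penzai's `named_axis` API. It only shows that `jax.typing.DTypeLike | None` accepts the same values the wrapped JAX function does: a dtype, a scalar type, a dtype string, or `None` for the default.

```python
from __future__ import annotations  # keeps `X | None` annotations valid on older Pythons

import jax
import jax.numpy as jnp
from jax.typing import DTypeLike


def zeros_sketch(
    shape: tuple[int, ...],
    dtype: DTypeLike | None = None,  # same annotation style as jnp.zeros' dtype
) -> jax.Array:
  """Hypothetical wrapper: passes `dtype` through unchanged to `jnp.zeros`."""
  return jnp.zeros(shape, dtype=dtype)


print(zeros_sketch((2, 3), jnp.int32).dtype)  # int32
print(zeros_sketch((2,), "float32").dtype)    # float32
print(zeros_sketch((4,)).shape)               # (4,), default dtype
```

Because the wrapper's annotation matches the wrapped function's, a type checker accepts exactly the `dtype` values that `jnp.zeros` itself accepts.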