Patrick Kidger
For anyone else coming across this: use `jaxtyping.PRNGKeyArray` and you should be good to go.
I think you're the first one to ask for them! :) I'd be happy to take a pull request that adds these.
Do you know how this operation is defined (mathematically) in other frameworks, e.g. PyTorch?
If this is given by a transpose, could this be implemented using `jax.linear_transpose`, without requiring any changes in JAX?
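For context, a small sketch of how `jax.linear_transpose` behaves. The `sum_pool` function here is a hypothetical stand-in (non-overlapping 1D sum pooling), not the operation under discussion. Because it is linear, JAX can transpose it, and the transpose spreads each pooled value back across its window, i.e. an "unpooling". Note that `jax.linear_transpose` only applies to functions that are linear in their inputs.

```python
import jax
import jax.numpy as jnp


def sum_pool(x):
    # Hypothetical linear op: sum non-overlapping windows of size 2, (4,) -> (2,).
    return x.reshape(-1, 2).sum(axis=1)


# The primal argument only fixes the input shape/dtype; its values are unused.
x = jnp.zeros(4, dtype=jnp.float32)
unpool = jax.linear_transpose(sum_pool, x)

# The transposed function returns a tuple, one entry per primal input.
(y,) = unpool(jnp.array([1.0, 2.0], dtype=jnp.float32))
# Each pooled value is copied back to every position in its window.
```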
No opinions, as I've not really used it, I'm afraid :) Why would transposing be less efficient? XLA actually has a dedicated op (https://openxla.org/xla/operation_semantics#selectandscatter) for specifically the operation you're trying...
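To see that select-and-scatter pattern from the JAX side, here is a small sketch: a hypothetical 1D max pool built with `lax.reduce_window`, differentiated with `jax.vjp`. The cotangent for each window is routed to the position of that window's maximum. My understanding (an assumption, worth confirming by inspecting the lowered HLO) is that JAX lowers this VJP to XLA's SelectAndScatter.

```python
import jax
import jax.numpy as jnp
from jax import lax


def max_pool_1d(x):
    # Hypothetical max pool: window 2, stride 2, no padding, (4,) -> (2,).
    # -inf is the identity for max, so JAX can treat this as a max reduction.
    return lax.reduce_window(x, -jnp.inf, lax.max, (2,), (2,), "VALID")


x = jnp.array([1.0, 3.0, 4.0, 2.0])
y, vjp_fn = jax.vjp(max_pool_1d, x)

# Each window's cotangent lands on that window's argmax; other slots get zero.
(g,) = vjp_fn(jnp.array([10.0, 20.0]))
```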
Thanks for the question! This is actually the reason the version number had a breaking bump with the v0.3.0 release. This should now be described without the `size.` prefix: ```python @jaxtyped(typechecker=beartype)...
Hehe, for completeness I should in turn acknowledge @knyazer, who actually put our IPython extension together! I just sat back and approved PRs 😁
Ah yeah. FWIW I believe the impact of this bug is that debuggers (ipdb) will not work correctly whilst a magic is loaded.
> jaxtyping shapes/dtypes too

Does it work for checking consistency across multiple jaxtyping shapes/dtypes? (I.e. the bit that the `@jaxtyped` decorator is responsible for enforcing?)
So I think this is just a choice that has been made in the implementation of the LSP (since you're using VSCode, I'm assuming it's probably Pylance) to...