Jake Vanderplas
Overall comment here: the approach in this PR is one I've tried in other contexts and often ended up having to roll back: namely, you're trying to fix type check...
Also, I think some of this PR does not make sense separate from the `jit` annotation change, while some of it does make sense to be separate. A lot of...
> FYI this change is on top of https://github.com/google/jax/pull/18465, so I made some of your changes there and rebased over it. Sorry, I should have realized that.
Thanks! I think all this looks pretty reasonable now. We'll need to do `from __future__ import annotations` in a few places to fix the Python 3.9 errors, and there are...
Let's see what the tests say 😀
The test failures are real; it looks like the `out_indices` change is modifying the logic.
Hi - the `lstm` function accepts a number of arguments that must be static (e.g. array sizes, boolean flags, etc). If you are wrapping it in `jit`, then you should...
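To illustrate the point about static arguments, here is a minimal sketch using `jax.jit` with `static_argnames`. The function `toy_layer` and its parameters `hidden_size` and `bidirectional` are hypothetical stand-ins, not the real `lstm` signature; they only show the general pattern of marking sizes and boolean flags as static.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical stand-in for a function that takes a size and a boolean flag;
# this is NOT the real `lstm` signature.
@partial(jax.jit, static_argnames=("hidden_size", "bidirectional"))
def toy_layer(x, hidden_size, bidirectional):
    # `hidden_size` determines an array shape, so it must be a concrete
    # Python int at trace time.
    w = jnp.ones((x.shape[-1], hidden_size))
    y = x @ w
    # `bidirectional` drives Python control flow, so it must be a concrete bool.
    if bidirectional:
        y = jnp.concatenate([y, y[::-1]], axis=-1)
    return y


out = toy_layer(jnp.ones((3, 4)), hidden_size=2, bidirectional=True)
```

Without the `static_argnames` entry, these arguments would be traced as abstract values, and using them as a shape or in a Python `if` would raise an error. Note that each distinct combination of static values triggers a fresh compilation.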
Thanks for the clear reproduction! That looks like a bug. In general we would not expect the results to be bitwise equivalent (`jit` rearranges floating point operations, so floating point...
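As a small aside on why bitwise equality is not expected when operations are reordered, the sketch below uses plain Python floats rather than JAX. It only shows that floating point addition is not associative; it does not reproduce the reported bug.

```python
import math

a, b, c = 0.1, 0.2, 0.3

# The same three numbers, summed in two different orders.
left = (a + b) + c
right = a + (b + c)

# The two results differ in the last bits...
bitwise_equal = left == right
# ...but agree to within a tight relative tolerance.
close = math.isclose(left, right, rel_tol=1e-12)
```

This is why comparisons between jitted and non-jitted results are usually made with a tolerance rather than with exact equality.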
Thanks! Assigning @hawkinsp to review this info.
It seems like this has been answered. Feel free to open another issue if you still have questions!