Roy Frostig
A rough update on where the implementation is at present:

* Batching, forward-mode AD (e.g. `jvp`, `jacfwd`), and compilation are all supported.
* Reverse-mode AD (specifically linearization by partial evaluation,...
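For readers less familiar with the transformations listed above, here is a minimal sketch of forward-mode AD, batching, and compilation composing on an ordinary array function. The function `f` is a hypothetical example, and this sketch does not exercise the implementation under discussion.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical elementwise function, used only for illustration.
    return jnp.sin(x) * x

x = jnp.arange(3.0)
v = jnp.ones(3)

# Forward-mode AD: f(x) together with the Jacobian-vector product J(x) @ v.
y, y_dot = jax.jvp(f, (x,), (v,))

# Full Jacobian, built with forward mode (one JVP per input dimension).
J = jax.jacfwd(f)(x)

# Batching (vmap) and compilation (jit) compose with forward-mode AD:
# one 3x3 Jacobian per row of the stacked (2, 3) input.
batched_jac = jax.jit(jax.vmap(jax.jacfwd(f)))(jnp.stack([x, x + 1.0]))
```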
Thanks! Would you mind [squashing your commits](https://jax.readthedocs.io/en/latest/contributing.html#single-change-commits-and-pull-requests) to just one commit?
What GPU do you have, and what is the output of `nvidia-smi`?
Indeed, the move away from a pytree is what's blocking us. Aside from that, we also have some operations to consider implementing, as categorized nicely in @LenaMartens' description for #8381....
> Just curious, and didn't know where to ask. Is there a reason that `k.split()` does something completely different than `split(k)` if `k` is a new-style `KeyArray`? Could this be...
> It seems like split should prepend rather than append the dimension for consistency?

`split` prepends in correspondence with the very common unpacking usage, e.g.:

```python
key1, key2, key3 =...
```
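A minimal sketch of why the leading axis matters, assuming a recent JAX that provides both raw `jax.random.PRNGKey` keys and typed `jax.random.key` keys: `split` stacks the new keys along the first axis, and Python unpacking iterates over an array's first axis.

```python
import jax

key = jax.random.PRNGKey(0)        # raw uint32 key with the default impl, shape (2,)
keys = jax.random.split(key, 3)    # new keys stacked along a leading axis: (3, 2)
key1, key2, key3 = keys            # unpacking iterates over that leading axis

typed = jax.random.key(0)          # new-style typed key array, shape ()
typed_keys = jax.random.split(typed, 3)  # shape (3,)
```

If the new axis were appended instead, the raw-key result would have shape `(2, 3)` and the three-way unpacking above would fail.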
I see. Thanks for spelling that out. Although key arrays are not entirely numpy-like, they are still partly so. For example, they support transposition, reshaping, and a few others, and...
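As a small illustration of the array-like operations mentioned above, here is a sketch assuming a JAX version with typed keys (`jax.random.key`, `jax.random.key_data`). It exercises only reshaping and transposition and is not an exhaustive list of what key arrays support.

```python
import jax

key = jax.random.key(0)             # typed key array, shape ()
keys = jax.random.split(key, 6)     # shape (6,)

grid = keys.reshape(2, 3)           # reshaping works as with numpy arrays
flipped = grid.transpose()          # so does transposition: shape (3, 2)

# The underlying uint32 data carries an extra trailing axis whose size
# depends on the PRNG implementation (2 for the default threefry).
raw = jax.random.key_data(flipped)
```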
> But I assume only affects pre-allocation, not freeing the memory afterwards?

Device memory for an array ought to be freed once all Python references to it drop, i.e. upon...
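A minimal sketch of the two ways an array's device buffer is released, assuming a recent JAX where `jax.Array` exposes `delete()` and `is_deleted()`: dropping the last Python reference, or deleting explicitly. With preallocation enabled (the default on GPU), freed buffers go back to JAX's own memory pool rather than to the device, so `nvidia-smi` will not necessarily show a drop.

```python
import jax.numpy as jnp

x = jnp.ones((1024, 1024))  # device buffer owned by x
alias = x                   # a second Python reference to the same array

del x                       # buffer stays alive: alias still refers to it
assert not alias.is_deleted()

alias.delete()              # explicitly release the device buffer now
# Any further use of alias's data would now raise an error.
```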
@dfm recently put together a nice, comprehensive [tutorial repository](https://github.com/dfm/extending-jax) on writing custom ops. I recommend taking a look!

> Thanks for making Jax, it's awesome.

Thank you for the kind...
xref: #1870