Tom Hennigan
@LenaMartens you know JAX's AD machinery much better than I do. Any thoughts here? > Speaking of meta_vjp, I noticed that vmap has a related TODO: "Allow configuration of params/state/rng...
Hi @samuela, our tests seem to pass fine on my workstation, on GitHub Actions and on our internal CI, so I wonder if this is something specific to your setup (e.g. not...
Internally we're testing with JAX/XLA at HEAD so I'm fairly confident they pass with the latest stable release too. I'll bump the versions we're using on GHA regardless in #370...
Apologies for the delay. I believe the majority of our internal testing is done on Intel CPUs; however, on GitHub I'm not really sure (we use whatever the actions runner...
Thanks for sharing; that's the first time I've seen those. We don't have any plans to do this at the moment, as we have quite a few people using Haiku +...
This issue seems related to deepmind/dm-haiku#83. Perhaps something has changed recently?
https://groups.google.com/a/tensorflow.org/g/discuss/c/TiWgve-KERo/m/NgUohfTiAgAJ ^ I think this thread in TF is relevant too.
Hi @samuela, our tests are passing at head with {jax,jaxlib} == 0.3.5 (https://github.com/deepmind/dm-haiku/blob/main/requirements-jax.txt).
Hi Trevor, `hk.cond` passes all Haiku state in and out of the `cond` to allow module parameters and state to be created/updated inside the branches. In this case, it looks...
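To illustrate the general idea of passing state in and out of a cond (this is a sketch in plain JAX, not Haiku's actual implementation of `hk.cond`), the state dict below is passed to `jax.lax.cond` as an operand, and each branch returns its output together with the updated state. The helper `stateful_cond` and the branch functions are hypothetical names chosen for this example.

```python
import jax
import jax.numpy as jnp


def stateful_cond(pred, true_fun, false_fun, state, x):
    # Hypothetical helper: the state pytree goes in as an operand and each
    # branch returns (output, new_state), so updates made inside whichever
    # branch runs are carried out of the cond.
    return jax.lax.cond(pred, true_fun, false_fun, state, x)


def double_and_count(state, x):
    # Both branches must return outputs and state with identical
    # structure, shapes and dtypes, as jax.lax.cond requires.
    return x * 2.0, {"calls": state["calls"] + 1}


def negate_and_count(state, x):
    return -x, {"calls": state["calls"] + 1}


state = {"calls": jnp.array(0)}
out, state = stateful_cond(True, double_and_count, negate_and_count, state, jnp.array(3.0))
out2, state = stateful_cond(False, double_and_count, negate_and_count, state, jnp.array(3.0))
```

Because the state is threaded explicitly, the counter increments regardless of which branch is selected, and both calls see the state produced by the previous one.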
That is odd. If the issue is not Haiku-specific, you may find better answers on the JAX repo (there are more people looking there). One suggestion I would have...