Owen L
Are there any updates on this? I previously worked with SPSA in TF (https://github.com/tensorflow/quantum/pull/653) and would be interested in working on this, but I don't want to duplicate effort.
Very similar thread in flax: https://github.com/google/flax/issues/2577. However, their solution of manually transposing the weights doesn't work here: `cnn_t = eqx.tree_at(lambda x: x.weight, cnn_t, jnp.flip(jnp.array(cnn.weight), (0, 1)).swapaxes(-2, -1))` still fails. And the approach of...
I forgot that flax stores its weights in a different axis order. `cnn_t = eqx.tree_at(lambda x: x.weight, cnn_t, jnp.flip(jnp.array(cnn.weight), (2, 3)).swapaxes(0, 1))` actually does work. Maybe the tooling for this should be documented somewhere? Or we...
Makes sense; these signatures could quickly become overwhelming. I'll add it to the documentation. I definitely did not forget about `jax.linear_transpose`, because I didn't even know it existed! So...
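For reference, `jax.linear_transpose` takes a linear function plus example inputs (used only for shapes/dtypes) and returns the transposed linear map; a small sketch:

```python
import jax
import jax.numpy as jnp

# A linear function of x.
def f(x):
    return jnp.stack([2.0 * x.sum(), x[0] - x[1]])

x = jnp.arange(3.0)
f_t = jax.linear_transpose(f, x)        # example primal fixes shapes/dtypes
(ct_x,) = f_t(jnp.array([1.0, 1.0]))    # output cotangent -> input cotangent
# ct_x is J.T @ [1, 1], where J = [[2, 2, 2], [1, -1, 0]]
```

Note that `f` must actually be linear; `linear_transpose` raises an error on nonlinear primitives rather than silently linearizing.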
It's probably not the most elegant solution, but when I hit this problem before I also ran into strange issues defining VJPs for self/class methods. My solution was just to...
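For context on the self/class-method problem: `jax.custom_vjp` can't be applied to a bound method directly, so one common workaround is to define the VJP on a free function and have the method delegate to it. `_scaled_square` below is a hypothetical example, not the code from this issue:

```python
import jax
import jax.numpy as jnp

# jax.custom_vjp on a free function; the class method just delegates to it.
@jax.custom_vjp
def _scaled_square(scale, x):
    return scale * x**2

def _fwd(scale, x):
    # Return the primal output plus residuals for the backward pass.
    return scale * x**2, (scale, x)

def _bwd(res, g):
    scale, x = res
    # Cotangents for (scale, x).
    return (jnp.sum(g * x**2), 2.0 * scale * x * g)

_scaled_square.defvjp(_fwd, _bwd)

class ScaledSquare:
    def __init__(self, scale):
        self.scale = scale

    def __call__(self, x):
        # The bound method itself never carries the custom_vjp decorator.
        return _scaled_square(self.scale, x)
```
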
I doubt it will work in any consistent manner. Here is a simple example where trying to transform with it fails:

```python
import equinox as eqx
from jax import numpy
```
...
Use `eqx.filter_value_and_grad`
I just meant: instead of `jax.value_and_grad(loss)`, do the filtered version. This worked for me. The basic idea is that functions are not a JAX type, so you can differentiate with respect...
What's the status of this PR? I'm running into some equinox batch norm trouble as I switch some code over from haiku, and I think it's related to some of...
Makes sense, I'll see about taking a stab at a small-ish change that adds an option aligning it with other packages. Tbh I haven't really followed the lore...