How to use nnx.custom_vjp with non-class arguments? Example needed
Hi @cgarciae
I was wondering whether it's possible to apply nnx.custom_vjp to a function like this:
```python
@nnx.custom_vjp
def linear(m: MyLinear, x: jax.Array) -> jax.Array:
    y = x @ m.kernel + m.bias
    return y
```
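For reference, `MyLinear` isn't anything special; assume a minimal module holding `kernel` and `bias` Params, roughly like this (the exact definition isn't important, and `nnx.Linear` exposes the same `kernel`/`bias` attributes):

```python
import jax
import jax.numpy as jnp
from flax import nnx

class MyLinear(nnx.Module):
    """A minimal linear module with explicit kernel and bias Params."""

    def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
        self.kernel = nnx.Param(jax.random.normal(rngs.params(), (din, dout)))
        self.bias = nnx.Param(jnp.zeros((dout,)))
```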
But I'm not sure what `linear_bwd` should return. I tried this:
```python
def linear_fwd(m: nnx.Linear, x: jax.Array):
    return linear(m, x), (m, x)

def linear_bwd(res, g):
    m, x = res
    inputs_g, outputs_g = g
    kernel_grad = outputs_g[None, :] * x[:, None]
    bias_grad = outputs_g
    x_grad = m.kernel @ outputs_g
    assert x_grad.shape == x.shape, 'Shape mismatch for x'
    assert m.kernel.value.shape == kernel_grad.shape, 'Shape mismatch for kernel'
    assert m.bias.value.shape == bias_grad.shape, 'Shape mismatch for bias'
    m_g = nnx.State(dict(kernel=kernel_grad, bias=bias_grad))
    x_g = nnx.State((x_grad,))
    # x_g = nnx.State(dict(x=x_grad))  # also tried this
    return (m_g, x_g)

linear.defvjp(linear_fwd, linear_bwd)
```
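For comparison, here is the shape of the same thing with plain `jax.custom_vjp`, where the kernel and bias are explicit array arguments and the backward rule simply returns one cotangent per primal input, in order (just an illustrative sketch; `linear_ref` and friends are not from the notebook):

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def linear_ref(kernel: jax.Array, bias: jax.Array, x: jax.Array) -> jax.Array:
    return x @ kernel + bias

def linear_ref_fwd(kernel, bias, x):
    # Save the residuals the backward rule needs.
    return linear_ref(kernel, bias, x), (kernel, x)

def linear_ref_bwd(res, g):
    kernel, x = res
    # One cotangent per primal input, in the same order as the primals.
    kernel_g = x[:, None] * g[None, :]
    bias_g = g
    x_g = kernel @ g
    return kernel_g, bias_g, x_g

linear_ref.defvjp(linear_ref_fwd, linear_ref_bwd)

# Sanity check against jax.grad on a scalar loss.
kernel = jnp.ones((3, 2))
bias = jnp.zeros((2,))
x = jnp.arange(3.0)
grads = jax.grad(lambda k, b, x: linear_ref(k, b, x).sum(), argnums=(0, 1, 2))(kernel, bias, x)
```

What I can't figure out is what the equivalent return value should look like when one of the primal inputs is a Module.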
The notebook is here
Ultimately, I want an nnx.Module whose `__call__` method causes the module's parameters to be updated during the backward pass. Any guidance would be greatly appreciated!
Best,
Henry
Hey @hrbigelow, it seems we have a bug. Thanks for reporting this!
The tangents for Modules should be State objects; however, you are running into something else.
custom_vjp is very new and one of the trickiest transforms, so I'd expect a few hiccups along the way.
Hi @cgarciae,
Thanks for looking into this. I can see this would be very tricky. I think it could be extremely useful, opening up the opportunity to express different learning algorithms with minimal code. I look forward to using it.
By the way, for context, I opened a JAX discussion before I was aware of Flax NNX (I had been using Haiku).
@cgarciae feel free to close this if you like. I've been using the fix and it works great.
Hi @hrbigelow, could you share how you resolved your issue?