Custom JVP/VJP definition within Module?
Hello,
Is it possible to define custom jvp/vjp rules for a Module?
For example, suppose I have the following module that scales inputs before passing into another module:
class ScaledModel(eqx.Module):
    sub_model: eqx.Module
    scale: float = eqx.field(static=True)

    def __call__(self, x):
        scaled_x = x / self.scale
        return self.sub_model(scaled_x)
However, suppose that I want the Jacobian of ScaledModel, as computed by jax.jacfwd/jax.jacrev, to be the Jacobian of the sub_model call with respect to the scaled inputs, rather than with respect to the inputs of the ScaledModel instance.
For example, if I have
scaled_model = ScaledModel(eqx.nn.Linear(2, 1, key=jax.random.PRNGKey(0)), 2.0)
I want the output of
jax.jacfwd(scaled_model)(jnp.ones(2))
to be equivalent to calling
jax.jacfwd(scaled_model.sub_model)(jnp.ones(2) / scaled_model.scale)
and likewise for jax.jacrev.
For the JVP, I'm trying something like the following:
import equinox as eqx
import jax
import jax.numpy as jnp


class ScaledModel1(eqx.Module):
    sub_model: eqx.Module
    scale: float = eqx.field(static=True)

    def __call__(self, x):
        scaled_x = x / self.scale
        return self.sub_model(scaled_x)


class ScaledModel2(eqx.Module):
    sub_model: eqx.Module
    scale: float

    @jax.custom_jvp
    def __call__(self, x):
        scaled_x = x / self.scale
        return self.sub_model(scaled_x)

    @__call__.defjvp
    def __call_jvp(self, primals, tangents):
        primal_out = self.sub_model(primals / self.scale)
        _, tangent_out = jax.jvp(
            self.sub_model, (primals / self.scale,), (tangents / self.scale,)
        )
        return primal_out, tangent_out


scaled_model1 = ScaledModel1(
    sub_model=eqx.nn.Linear(2, 1, key=jax.random.PRNGKey(0)), scale=2.0
)
scaled_model2 = ScaledModel2(
    sub_model=eqx.nn.Linear(2, 1, key=jax.random.PRNGKey(0)), scale=2.0
)
print(jax.jacfwd(scaled_model1)(jnp.ones(2)))
print(jax.jacfwd(scaled_model2)(jnp.ones(2)))
but this gives me TypeError: missing a required argument: 'x' for the second print statement.