
Custom JVP/VJP definition within Module?

Open • leonard-gleyzer opened this issue • 3 comments

Hello,

Is it possible to define custom JVP/VJP rules for an `eqx.Module`?

For example, suppose I have the following module, which scales its inputs before passing them to another module:

```python
import equinox as eqx


class ScaledModel(eqx.Module):
    sub_model: eqx.Module
    scale: float = eqx.field(static=True)

    def __call__(self, x):
        scaled_x = x / self.scale
        return self.sub_model(scaled_x)
```

However, I want `jax.jacfwd`/`jax.jacrev` applied to a `ScaledModel` to return the Jacobian of the `sub_model` call with respect to its own (already scaled) input, rather than with respect to the input of the `ScaledModel` instance.
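
In symbols, writing $f$ for `sub_model`, $s$ for `scale`, and $J_f$ for the Jacobian of $f$ (notation introduced here only for illustration), plain autodiff applies the chain rule through the division:

$$
\frac{\partial}{\partial x} f\left(\frac{x}{s}\right) = \frac{1}{s}\, J_f\left(\frac{x}{s}\right),
\qquad
\text{JVP: } \dot{x} \mapsto J_f\left(\frac{x}{s}\right) \frac{\dot{x}}{s},
$$

whereas the goal is $J_f(x/s)$ itself, i.e. the JVP $\dot{x} \mapsto J_f(x/s)\,\dot{x}$ with the input tangent left unscaled.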

For example, if I have

```python
scaled_model = ScaledModel(eqx.nn.Linear(2, 1, key=jax.random.PRNGKey(0)), 2.0)
```

I want the output of

```python
jax.jacfwd(scaled_model)(jnp.ones(2))
```

to be equivalent to calling

```python
jax.jacfwd(scaled_model.sub_model)(jnp.ones(2) / scaled_model.scale)
```

and likewise for jax.jacrev.

For the JVP, I'm trying something like the following:

```python
import equinox as eqx
import jax
import jax.numpy as jnp


class ScaledModel1(eqx.Module):
    sub_model: eqx.Module
    scale: float = eqx.field(static=True)

    def __call__(self, x):
        scaled_x = x / self.scale
        return self.sub_model(scaled_x)


class ScaledModel2(eqx.Module):
    sub_model: eqx.Module
    scale: float

    @jax.custom_jvp
    def __call__(self, x):
        scaled_x = x / self.scale
        return self.sub_model(scaled_x)

    @__call__.defjvp
    def __call_jvp(self, primals, tangents):
        primal_out = self.sub_model(primals / self.scale)
        _, tangent_out = jax.jvp(
            self.sub_model, (primals / self.scale,), (tangents / self.scale,)
        )
        return primal_out, tangent_out


scaled_model1 = ScaledModel1(
    sub_model=eqx.nn.Linear(2, 1, key=jax.random.PRNGKey(0)), scale=2.0
)
scaled_model2 = ScaledModel2(
    sub_model=eqx.nn.Linear(2, 1, key=jax.random.PRNGKey(0)), scale=2.0
)

print(jax.jacfwd(scaled_model1)(jnp.ones(2)))
print(jax.jacfwd(scaled_model2)(jnp.ones(2)))
```

but the second print statement raises `TypeError: missing a required argument: 'x'`.

leonard-gleyzer, May 17 '24 02:05