equivariant-MLP
Breaking equivariance when using the haiku.Module class
Hello, it's a great project!
I tried to use EMLP with dm-haiku and wrote the code in two different ways. The first uses emlp.nn.haiku directly, and the second wraps it in a haiku.Module subclass. I found that the second version breaks the equivariance of the network, even though, as far as I can tell, the two versions have identical architectures. Could you tell me if something is wrong with my way of using EMLP?
The code is attached. The first version is:
import emlp.nn.haiku as ehk
import haiku as hk
from emlp.reps import V
from emlp.groups import SO
from jax import random
import jax.numpy as jnp
n = 10
dim = 3
G = SO(dim)
rep_in = n*V(G)
rep_out = n*V(G)
model = ehk.EMLP(rep_in, rep_out, group=G, num_layers=2, ch=256)
net = hk.without_apply_rng(hk.transform(model))
key = random.PRNGKey(0)
x = random.normal(key, (n*dim,))
params = net.init(random.PRNGKey(42), x)
v = net.apply(params, x)
g = G.sample()
x_1 = rep_in.rho(g)@x
v_1 = net.apply(params, x_1)
v_2 = rep_out.rho(g)@v
print(f"v(𝜌(g)x) =\n{v_1}")
print(f"𝜌(g)v(x) =\n{v_2}")
assert jnp.allclose(v_1, v_2)
and the second version is:
import emlp.nn.haiku as ehk
from emlp.reps import V
from emlp.groups import SO
import haiku as hk
from jax import random
import jax.numpy as jnp
class test_EMLP(hk.Module):
    def __init__(self, n, dim, group, num_layers, ch, name=None):
        super().__init__(name=name)
        self.n = n
        self.dim = dim
        self.group = group(dim)
        self.rep_in = self.n*V(self.group)
        self.rep_out = self.n*V(self.group)
        self.num_layers = num_layers
        self.ch = ch
        self.e_mlp = self.e_mlp()

    def e_mlp(self):
        return ehk.EMLP(self.rep_in,
                        self.rep_out,
                        group=self.group,
                        num_layers=self.num_layers,
                        ch=self.ch)

    def __call__(self, x):
        return self.e_mlp(x)

def forward_fn(x):
    model = test_EMLP(n=10, dim=3, group=SO, num_layers=2, ch=256)
    return model(x)
net = hk.without_apply_rng(hk.transform(forward_fn))
n = 10
dim = 3
G = SO(dim)
rep_in = n*V(G)
rep_out = n*V(G)
key = random.PRNGKey(1)
x = random.normal(key, (n*dim,))
params = net.init(random.PRNGKey(42), x)
v = net.apply(params, x)
g = G.sample()
x_1 = rep_in.rho(g)@x
v_1 = net.apply(params, x_1)
v_2 = rep_out.rho(g)@v
print(f"v(𝜌(g)x) =\n{v_1}")
print(f"𝜌(g)v(x) =\n{v_2}")
assert jnp.allclose(v_1, v_2)
Yeah, I believe this is because in Haiku the module's __init__ gets called on every forward pass of the transformed function.
In EMLP's BiLinear layer, a random subset of all possible bilinear interactions is chosen to limit the size and computational cost of the layer. However, this random subset will be different on each instantiation of the model (unless the random seed is held fixed), so constructing the EMLP this way with Haiku actually ends up evaluating slightly different models.
You can check this by calling net.apply(params, x) on your test_EMLP multiple times; the outputs should differ between calls.
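For example, something like the following quick diagnostic (reusing the variables from the second script above) should expose the problem; the exact outcome depends on how the random subset is drawn, so treat it as a sanity check rather than a guaranteed result:

# Apply the same params twice; if each trace re-samples the bilinear subset,
# the two outputs will disagree.
out_a = net.apply(params, x)
out_b = net.apply(params, x)
print(jnp.allclose(out_a, out_b))  # expected: False for the hk.Module version, True for plain ehk.EMLP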
If you look in emlp/nn/haiku.py, the way I get around this is to have the relevant constructors be standard functions that return the input-output mapping through a Haiku module. This way the precomputation of the equivariant bases (and the choice of random subsets) is only performed once. (You may also notice that your test_EMLP is much slower than ehk.EMLP because it must compute the equivariant basis multiple times, although the caching may ameliorate this somewhat.)
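To illustrate the idea, here is a simplified, hypothetical sketch of that pattern (not the actual emlp.nn.haiku source; EquivariantLayer, dim_in, dim_out, and the placeholder basis are made up for illustration):

import haiku as hk
import jax.numpy as jnp
import numpy as np

def EquivariantLayer(dim_in, dim_out, seed=0):
    # Expensive / random precomputation happens ONCE here, at Python
    # construction time, outside of hk.transform tracing.
    rng = np.random.default_rng(seed)
    basis = rng.standard_normal((dim_out, dim_in))  # stand-in for solving the equivariant basis

    class _Apply(hk.Module):
        def __call__(self, x):
            # Only parameter creation and the cheap application run under tracing,
            # so every init/apply sees the same precomputed state.
            w = hk.get_parameter("w", shape=(dim_out,), init=jnp.zeros)
            return basis @ x + w

    # Return a plain function closing over the precomputed state; it can be
    # called inside hk.transform just like ehk.EMLP in the first script.
    return lambda x: _Apply()(x)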
If you want to use it inside another Haiku module, my advice would be one of the following:
- fix the random seed (still not ideal, because it will end up repeating some fixed computations)
- write the module constructor as a stateless function, as with ehk.EMLP (see the sketch after this list)
- possibly there is some Haiku-specific mechanism for storing this precomputed state, which one could explore
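As a rough illustration of the second option, one untested sketch is to construct the ehk.EMLP once outside hk.transform and close over it from the wrapping module; this assumes the object returned by ehk.EMLP can be called from inside another module the same way it is called from hk.transform in the first script:

import emlp.nn.haiku as ehk
import haiku as hk
from emlp.reps import V
from emlp.groups import SO

n, dim = 10, 3
G = SO(dim)
# Build the EMLP once, outside hk.transform, so the equivariant bases and
# random bilinear subsets are fixed for the lifetime of the program.
emlp_block = ehk.EMLP(n*V(G), n*V(G), group=G, num_layers=2, ch=256)

class WrappedEMLP(hk.Module):
    # The wrapper only closes over the prebuilt block; nothing random is
    # re-created inside init/apply.
    def __call__(self, x):
        return emlp_block(x)

def forward_fn(x):
    return WrappedEMLP()(x)

net = hk.without_apply_rng(hk.transform(forward_fn))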
Thanks for your quick reply! I will try it soon.