flax
Wrong parameter names when nesting Modules within Flax transformations
Hi, I have a complex case where I nest different submodules inside each other (within an `nn.vmap` transformation), which results in what I believe are wrong parameter names.
Minimal working example (MWE):
import jax
import jax.numpy as jnp
import flax
import flax.linen as nn
# This network should store its parameters as
# {'subnet': {...}, 'some_pars': ()}
class Net(nn.Module):
    """Wrap ``subnet`` and add a single scalar parameter to its output.

    The scalar is registered under the parameter name 'some_pars',
    zero-initialised, with shape () and dtype ``float``.
    """

    subnet: nn.Module

    def setup(self):
        # The string 'some_pars' (not the attribute name) is the key that
        # appears in the params tree.
        self.local_pars = self.param('some_pars', nn.initializers.zeros, (), float)

    def __call__(self, x):
        sub_out = self.subnet(x)
        return sub_out + self.local_pars
# I expect this network to store parameters as
# {'vnet': {subnet structure...}}
class VNet(nn.Module):
    """Apply ``subneta`` vectorised with ``nn.vmap``, constructed from ``some_args``.

    ``subneta`` is passed a Module *class* (e.g. ``Net``) despite the
    ``nn.Module`` annotation. It is wrapped with ``nn.vmap`` so that its
    'params' collection gets a leading mapped axis (variable_axes={'params': 0})
    with per-slice parameter RNGs (split_rngs={'params': True}). It is then
    instantiated as ``self.vnet`` using ``**self.some_args`` as keyword arguments.

    NOTE(review): this variant produces an unexpected params layout. The Dense
    module passed inside ``some_args`` has its parameters stored under a
    top-level 'some_args_subnet' key instead of under 'vnet' -> 'subnet'.
    Presumably Flax registers Module instances found in this module's dataclass
    fields (including those nested inside the ``some_args`` dict) as children
    of VNet itself, named after the field path. TODO confirm against Flax's
    submodule-adoption logic.
    """
    subneta : nn.Module  # Module class (not instance) to be vmapped
    some_args : dict  # constructor kwargs for subneta; a FrozenDict in the usage shown
    def setup(self):
        # Build a vmapped version of the subneta class; params gain a leading axis.
        cstrctor = nn.vmap(self.subneta, variable_axes={'params':0}, split_rngs={'params':True}, in_axes=0, out_axes=0)
        # Instantiate it from the stored kwargs; this is where the stray
        # 'some_args_subnet' naming shows up.
        self.vnet = cstrctor(**self.some_args)
    def __call__(self, x):
        # Map over the leading axis of x (in_axes=0) and stack outputs on axis 0.
        return self.vnet(x)
# Build the vmapped network, passing the Dense submodule through `some_args`.
s = jnp.ones((3, 4))
k = jax.random.key(1)
net = VNet(subneta=Net, some_args=flax.core.freeze({'subnet': nn.Dense(features=1)}))
v_pars = net.init(k, s)
# `jax.tree_map` is deprecated (and removed in recent JAX releases);
# `jax.tree_util.tree_map` is the stable spelling. Print so the result is
# visible when run as a script, not only in a notebook.
print(jax.tree_util.tree_map(lambda x: x.shape, v_pars))
# {'params': {'some_args_subnet': {'bias': (3, 1), 'kernel': (3, 4, 1)},
# 'vnet': {'some_pars': (3,)}}}
I would expect the network parameters to be stored as a nested dictionary that mirrors the submodule structure, as follows:
{'params': {'vnet': {'some_pars': (3,), 'subnet': {'bias': (3, 1), 'kernel': (3, 4, 1)}}}}
Instead, the parameters are split into two blocks. The `Dense` parameters appear under a top-level `'some_args_subnet'` entry, while only `'some_pars'` remains under `'vnet'`.
We noticed that the bug disappears if the `some_args` dictionary is removed. In that case the submodule is passed directly as a keyword argument when the vmapped module is constructed inside `setup`:
class VNet(nn.Module):
    """Apply ``subneta`` vectorised with ``nn.vmap``, using an inline Dense submodule.

    The ``nn.Dense(features=1)`` submodule is created inside ``setup`` and
    handed straight to the vmapped constructor as the ``subnet`` keyword. It is
    not stored in a dataclass field of this module.
    """

    subneta: nn.Module

    def setup(self):
        # Vmapped version of the subneta class: params get a leading axis and
        # each slice receives its own 'params' RNG.
        vmapped_cls = nn.vmap(
            self.subneta,
            variable_axes={'params': 0},
            split_rngs={'params': True},
            in_axes=0,
            out_axes=0,
        )
        self.vnet = vmapped_cls(subnet=nn.Dense(features=1))

    def __call__(self, x):
        return self.vnet(x)
# Same run as before, now with the Dense submodule created inside setup.
s = jnp.ones((3, 4))
k = jax.random.key(1)
net = VNet(subneta=Net)
v_pars = net.init(k, s)
# `jax.tree_map` is deprecated (and removed in recent JAX releases);
# `jax.tree_util.tree_map` is the stable spelling.
print(jax.tree_util.tree_map(lambda x: x.shape, v_pars))
# {'params': {'vnet': {'some_pars': (3,),
# 'subnet': {'bias': (3, 1), 'kernel': (3, 4, 1)}}}}
cc @adrien-kahn