flax icon indicating copy to clipboard operation
flax copied to clipboard

Wrong parameter names when nesting Modules within flax transformations

Open PhilipVinc opened this issue 4 months ago • 3 comments

Hi, I have a complex case where I nest different submodules inside each other, which results in what I think is a wrong parameter name.

MWE:

import jax
import jax.numpy as jnp
import flax
import flax.linen as nn

# This network should be storing the parameters as
# {'subnet': {....}, 'some_pars': ()}
# (the local parameter is registered under the name 'some_pars' in setup()).
class Net(nn.Module):
    """Apply a wrapped submodule and add one local scalar parameter to its output."""
    # Submodule that __call__ applies to the input.
    subnet : nn.Module
    def setup(self):
        """Register the local parameter 'some_pars' (shape (), zeros initializer, float)."""
        self.local_pars = self.param('some_pars', nn.initializers.zeros, (), float)
    def __call__(self, x):
        """Return ``self.subnet(x)`` plus the local parameter."""
        return self.subnet(x) + self.local_pars

# I expect this network to store parameters as
# {'vnet': {subnet structure...}}
class VNet(nn.Module):
    """Build a vmapped version of ``subneta`` in setup() and apply it to the input."""
    # Module to lift with nn.vmap. Annotated as nn.Module, but the MWE
    # passes the class ``Net`` itself (not an instance).
    subneta : nn.Module
    # Keyword arguments forwarded to the vmapped constructor in setup().
    # NOTE(review): in the MWE this dict contains an nn.Dense instance;
    # presumably flax adopts it as a submodule of VNet (named after this
    # field) before it reaches the vmapped module, which would explain the
    # 'some_args_subnet' entry in the reported output -- confirm against flax.
    some_args : dict
    
    def setup(self):
        """Create the vmapped module from ``subneta`` and store it as ``self.vnet``."""
        # Lift the module with nn.vmap: the 'params' collection is mapped over
        # axis 0, the 'params' RNG is split, and inputs/outputs use axis 0.
        cstrctor = nn.vmap(self.subneta, variable_axes={'params':0}, split_rngs={'params':True}, in_axes=0, out_axes=0)
        # Instantiate the lifted module with the user-supplied keyword arguments.
        self.vnet = cstrctor(**self.some_args)
    def __call__(self, x):
        """Apply the vmapped module to ``x``."""
        return self.vnet(x)

# Driver for the failing case: the Dense submodule is handed to VNet inside
# the frozen `some_args` dict and only later forwarded to the vmapped Net.
s = jnp.ones((3, 4))
k = jax.random.key(1)
net = VNet(subneta=Net, some_args=flax.core.freeze({'subnet': nn.Dense(features=1)}))
v_pars = net.init(k, s)
v_pars['params']
# `jax.tree_map` is a deprecated top-level alias; `jax.tree_util.tree_map`
# is the same function and is available in every JAX release.
jax.tree_util.tree_map(lambda x: x.shape, v_pars)
# {'params': {'some_args_subnet': {'bias': (3, 1), 'kernel': (3, 4, 1)},
#  'vnet': {'some_pars': (3,)}}}

I would expect the network parameters to be stored as a single dictionary mirroring the subnetwork's structure, as follows:

{'params': {'vnet': {'some_pars': (3,), 'subnet': {'bias': (3, 1), 'kernel': (3, 4, 1)}}}}

but instead the parameters of the subnetwork are split into two blocks.

We noticed that the bug disappears if the some_args dictionary is removed, and the keyword arguments are passed directly.

class VNet(nn.Module):
    """Variant of VNet that passes the vmapped module's keyword arguments directly in setup()."""
    # Module to lift with nn.vmap. Annotated as nn.Module, but the MWE
    # passes the class ``Net`` itself (not an instance).
    subneta : nn.Module
    
    def setup(self):
        """Create the vmapped module from ``subneta`` and store it as ``self.vnet``."""
        # Lift the module with nn.vmap: the 'params' collection is mapped over
        # axis 0, the 'params' RNG is split, and inputs/outputs use axis 0.
        cstrctor = nn.vmap(self.subneta, variable_axes={'params':0}, split_rngs={'params':True}, in_axes=0, out_axes=0)
        # The Dense submodule is created here, inside setup(), rather than
        # being stored in a dataclass field of VNet first.
        self.vnet = cstrctor(**{'subnet': nn.Dense(features=1)})

    def __call__(self, x):
        """Apply the vmapped module to ``x``."""
        return self.vnet(x)

# Driver for the working case: VNet builds the Dense submodule itself, so no
# module instance is passed through a VNet field.
s = jnp.ones((3, 4))
k = jax.random.key(1)
net = VNet(subneta=Net)
v_pars = net.init(k, s)
# `jax.tree_map` is a deprecated top-level alias; `jax.tree_util.tree_map`
# is the same function and is available in every JAX release.
jax.tree_util.tree_map(lambda x: x.shape, v_pars)
# {'params': {'vnet': {'some_pars': (3,),
#    'subnet': {'bias': (3, 1), 'kernel': (3, 4, 1)}}}}

cc @adrien-kahn

PhilipVinc avatar Mar 11 '24 18:03 PhilipVinc