`Module.init` does not contain key-value pairs for modules without parameters
System information
- Ubuntu 18.04.6 LTS
- Flax, jax, jaxlib versions (obtained with pip show flax jax jaxlib):
Name: jax
Version: 0.3.25
---
Name: jaxlib
Version: 0.3.25+cuda11.cudnn805
---
Name: flax
Version: 0.6.3
- Python version: Python 3.8.16
Problem you have encountered:
The dictionary of parameters generated at Module initialisation does not include keys for modules that do not contain any parameters.
What you expected to happen:
That modules without parameters are represented as empty dictionaries, rather than being omitted completely.
Steps to reproduce:
import jax
import flax.linen as nn

class Identity(nn.Module):
    def __call__(self, x):
        return x

net = nn.Sequential([Identity(), nn.Dense(2)])
params = net.init(jax.random.PRNGKey(0), jax.numpy.ones((2, 2)))
print(params)
Current output:
# FrozenDict({
#     params: {
#         layers_1: {
#             kernel: ...,
#             bias: ...,
#         },
#     },
# })
Expected output:
# FrozenDict({
#     params: {
+#         layers_0: {},
#         layers_1: {
#             kernel: ...,
#             bias: ...,
#         },
#     },
# })
Just wondering: if there are no params, is there a particular reason for having an empty dict for it? No params means the module will not look into the layers_0 key for any data.
I see the point. It is actually an issue when you try to access the parameters directly. It's just counterintuitive that the parameters are not there.
My use case (a minimal example) is the following:
import jax.numpy as jnp

class Net(nn.Module):
    a: nn.Module
    b: nn.Module = Identity()

    @nn.compact
    def __call__(self, x, *args, **kwargs):
        return self.a(self.b(x))

    def b_wrapped(self, params, x):
        # Apply only submodule `b`, indexing its parameters directly
        return self.b.apply({"params": params["params"]["b"]}, x)

a = nn.Dense(2)
b = nn.Dense(2)
x = jnp.ones((2,))

# `b` has parameters, so the "b" key exists
net = Net(a, b)
params = net.init(jax.random.PRNGKey(0), x)
net.b_wrapped(params, x)

# `b` defaults to Identity(), which has no parameters
net = Net(a)
params = net.init(jax.random.PRNGKey(0), x)
net.b_wrapped(params, x)  # KeyError: 'b'
In my specific case, `Net` would be a compound model such as an RL agent, and `a` and `b` its components, such as a representation network and a critic head (a real use case here that includes the workaround with a key check).
Actually, at some point we had the reverse, where collections did contain the empty dictionaries, but this led to problems and hard-to-read variable structures. For example, if you have a "batch_stats" collection, you wouldn't want a conv layer to have "batch_stats": {"conv0": {}}. I think the best approach here is to treat the existence of empty dictionaries in the variable tree as an implementation detail that shouldn't be relied upon.
I prefer the opposite, but I understand the design decision.
Perhaps the underlying issue here is that I couldn't find a better pattern to call only one part of a compound model (e.g., `self.b(params, x)`) without resorting to `self.b.apply` and handling the parameters explicitly.
Happy if you want to close this, in case you don't want to go that direction, and also happy to hear any suggestion on the underlying issue.