`Module.init` does not contain KV-pair of modules without parameters

Open epignatelli opened this issue 3 years ago • 4 comments

System information

  • Ubuntu 18.04.6 LTS
  • Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib):
Name: jax
Version: 0.3.25
---
Name: jaxlib
Version: 0.3.25+cuda11.cudnn805
---
Name: flax
Version: 0.6.3
  • Python version: Python 3.8.16

Problem you have encountered:

The dictionary of parameters generated at Module initialisation does not include keys for modules that do not contain any parameters.

What you expected to happen:

That modules without parameters are represented as empty dictionaries, rather than being omitted completely.

Steps to reproduce:

Whenever possible, please provide a minimal example. Please consider submitting it as a Colab link.

import jax
import flax.linen as nn


class Identity(nn.Module):
    def __call__(self, x):
        return x


net = nn.Sequential([Identity(), nn.Dense(2)])

params = net.init(jax.random.PRNGKey(0), jax.numpy.ones((2, 2)))
print(params)

Current output:

# FrozenDict({
#     params: {
#         layers_1: {
#             kernel: ...,
#             bias: ...,
#         },
#     },
# })

Expected output:

# FrozenDict({
#     params: {
+#         layers_0: {},
#         layers_1: {
#             kernel: ...,
#             bias: ...,
#         },
#     },
# })

epignatelli avatar Jan 08 '23 09:01 epignatelli

Just wondering: if there are no params, is there a particular reason for having an empty dict for the module? With no params, the module will not look into the layers_0 key for any data.

IvyZX avatar Jan 09 '23 18:01 IvyZX

I see the point. It actually is an issue when you try to access the parameters directly; beyond that, it's just counterintuitive that the entry is not there.

My use case (a minimal example) is the following:

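import jax
import jax.numpy as jnp
import flax.linen as nn


# Identity is the parameter-free module defined in the first snippet.
class Net(nn.Module):
    a: nn.Module
    b: nn.Module = Identity()

    @nn.compact
    def __call__(self, x, *args, **kwargs):
        return self.a(self.b(x))

    def b_wrapped(self, params, x):
        # Apply only submodule b, slicing its parameters out by key.
        return self.b.apply({"params": params["params"]["b"]}, x)


a = nn.Dense(2)
b = nn.Dense(2)
x = jnp.ones((2,))

# b is a Dense layer, so params["params"]["b"] exists.
net = Net(a, b)
params = net.init(jax.random.PRNGKey(0), x)
net.b_wrapped(params, x)

# b defaults to Identity, which has no parameters, so the key is missing.
net = Net(a)
params = net.init(jax.random.PRNGKey(0), x)
net.b_wrapped(params, x)  # KeyError: 'b'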

In my specific case, Net would be a compound model such as an RL agent, with a and b as its components, e.g. a representation network and a critic head (a real use case here that includes the workaround with a key check).

epignatelli avatar Jan 10 '23 15:01 epignatelli

Actually, at some point we had the reverse, where collections did contain the empty dictionaries, but this led to problems and hard-to-read variable structures. For example, if you have a "batch_stats" collection, you wouldn't want a conv layer to have "batch_stats": {"conv0": {}}. I think the best approach here is to treat the existence of empty dictionaries in the variable tree as an implementation detail that shouldn't be relied upon.

jheek avatar Jan 11 '23 08:01 jheek

I prefer the opposite, but I understand the design decision.

Perhaps the underlying issue here is that I couldn't find a better pattern to call only one part of a compound model (e.g., self.b(params, x)) without resorting to the self.b.apply and handling the parameters explicitly.

Happy if you want to close this, in case you don't want to go that direction, and also happy to hear any suggestion on the underlying issue.

epignatelli avatar Jan 11 '23 11:01 epignatelli