dm-haiku

Improve `params_dict()` support

Open • trevorcai opened this issue 5 years ago • 1 comment

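`module.params_dict()` can behave in surprising ways. What it returns depends on whether it is called during `init` or `apply`, and on whether the module has been called yet: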

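```python
import haiku as hk
import jax
import numpy as np


def f(x):
  mod = hk.Linear(8)
  print(mod.params_dict())  # empty during init, full during apply
  sequential = hk.Sequential([mod])
  print(sequential.params_dict())  # always empty
  out = sequential(x)
  print(sequential.params_dict())  # no longer empty
  return out


# hk.without_apply_rng lets apply be called as apply(params, x); recent
# Haiku versions otherwise expect an rng argument: apply(params, rng, x).
net = hk.without_apply_rng(hk.transform(f))
p = net.init(jax.random.PRNGKey(428), np.zeros((2, 3)))
net.apply(p, np.zeros((2, 3)))
```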

Prints (the first three lines come from `init`, the last three from `apply`):

```
{}
{}
{...}
{...}
{}
{...}
```

We should clean up and clearly define the intended semantics of `params_dict()`.

trevorcai • Mar 12 '20 13:03
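The printed pattern can be reproduced with a small pure-Python toy model. This is a sketch under stated assumptions, not Haiku's implementation. The names `ToyModule`, `ToyLinear`, `ToySequential`, `Frame` and `frame` are hypothetical. The model encodes two rules that are inferred from the output in the issue rather than taken from Haiku's source. First, a module only sees parameters that already exist in the current pass. Second, a parent only learns about a child once the child is called inside the parent's `__call__`.

```python
# Hypothetical toy model of the params_dict() behaviour in this issue.
# Not Haiku's code: both rules below are inferred from the printed output.
#   1. params_dict() only sees parameters that exist in the current pass.
#   2. A parent learns about a child only when the child is called inside it.
from contextlib import contextmanager

_FRAMES = []  # stack of active passes (one per init or apply)


class Frame:
  def __init__(self, params=None):
    self.params = dict(params or {})  # "module/param" -> value
    self.call_stack = []  # modules whose __call__ is currently running


@contextmanager
def frame(params=None):
  f = Frame(params)
  _FRAMES.append(f)
  try:
    yield f
  finally:
    _FRAMES.pop()


class ToyModule:
  def __init__(self, name):
    self.name = name
    self.submodules = set()  # filled in lazily by rule 2

  def _owners(self):
    names = {self.name}
    for child in self.submodules:
      names |= child._owners()
    return names

  def params_dict(self):
    owners = self._owners()
    params = _FRAMES[-1].params
    return {k: v for k, v in params.items() if k.split("/")[0] in owners}

  def __call__(self, x):
    f = _FRAMES[-1]
    if f.call_stack:  # rule 2: register with the calling module
      f.call_stack[-1].submodules.add(self)
    f.call_stack.append(self)
    try:
      return self.forward(x)
    finally:
      f.call_stack.pop()


class ToyLinear(ToyModule):
  def forward(self, x):
    params = _FRAMES[-1].params
    key = self.name + "/w"
    if key not in params:  # created on first use, as during init
      params[key] = 1.0
    return x * params[key]


class ToySequential(ToyModule):
  def __init__(self, name, layers):
    super().__init__(name)
    self.layers = layers

  def forward(self, x):
    for layer in self.layers:
      x = layer(x)
    return x


def f(x, log):
  mod = ToyLinear("linear")
  log.append(mod.params_dict())
  seq = ToySequential("sequential", [mod])
  log.append(seq.params_dict())
  out = seq(x)
  log.append(seq.params_dict())
  return out


init_log, apply_log = [], []
with frame() as init_frame:  # init: parameters are created as modules run
  f(2.0, init_log)
with frame(init_frame.params):  # apply: every parameter exists up front
  f(2.0, apply_log)

print(init_log)   # [{}, {}, {'linear/w': 1.0}]
print(apply_log)  # [{'linear/w': 1.0}, {}, {'linear/w': 1.0}]
```

The two logs follow the same `{}`, `{}`, `{...}` then `{...}`, `{}`, `{...}` pattern as the Haiku snippet. In this toy, the container's answer depends on call history because ownership is discovered lazily.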

One specific desideratum: any call to `params_dict()` should either raise an exception or return the same result regardless of when it is called.

girving • Mar 12 '20 13:03
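The following is a minimal sketch of one way the desideratum could be met. It is a hypothetical pure-Python toy, not a proposed or existing Haiku API. `StrictModule`, `StrictLinear`, `StrictSequential`, `ParamsNotReadyError` and the explicit `params`/`initializing` arguments are all invented for illustration. The toy gets stable answers in two ways. It raises while parameters may still be created. It also fixes each container's children at construction time, instead of discovering them on the first call.

```python
# Hypothetical sketch of the desideratum: params_dict() either raises or
# returns the same mapping whenever it is called. Not a Haiku API.
from types import MappingProxyType


class ParamsNotReadyError(RuntimeError):
  """Raised instead of returning a partial result while params may be created."""


class StrictModule:
  def __init__(self, name, children=()):
    self.name = name
    self.children = tuple(children)  # ownership fixed at construction time

  def owners(self):
    names = {self.name}
    for child in self.children:
      names |= child.owners()
    return names

  def params_dict(self, params, initializing):
    if initializing:
      raise ParamsNotReadyError(
          f"{self.name}.params_dict() is unavailable while initializing")
    owners = self.owners()
    # Read-only view over a fresh dict, so the result cannot be edited in place.
    return MappingProxyType(
        {k: v for k, v in params.items() if k.split("/")[0] in owners})


class StrictLinear(StrictModule):
  def __call__(self, x, params):
    return x * params[self.name + "/w"]


class StrictSequential(StrictModule):
  def __call__(self, x, params):
    for layer in self.children:
      x = layer(x, params)
    return x


params = {"linear/w": 2.0}
mod = StrictLinear("linear")
seq = StrictSequential("sequential", [mod])

before = dict(seq.params_dict(params, initializing=False))
out = seq(3.0, params)
after = dict(seq.params_dict(params, initializing=False))
print(before == after, out)  # True 6.0

try:
  mod.params_dict({}, initializing=True)
except ParamsNotReadyError as err:
  print(err)  # linear.params_dict() is unavailable while initializing
```

Raising during init trades convenience for predictability: a caller never sees a partial dict that would later change. Declaring children up front is what keeps the container's answer independent of whether `__call__` has run.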