dm-haiku
Improve `params_dict()` support
`module.params_dict()` can behave in surprising ways:
```python
import haiku as hk
import jax
import numpy as np


def f(x):
    mod = hk.Linear(8)
    print(mod.params_dict())  # empty during init, full during apply
    sequential = hk.Sequential([mod])
    print(sequential.params_dict())  # always empty
    out = sequential(x)
    print(sequential.params_dict())  # no longer empty
    return out


net = hk.transform(f)
p = net.init(jax.random.PRNGKey(428), np.zeros((2, 3)))
net.apply(p, None, np.zeros((2, 3)))  # rng=None: f draws no random numbers
```
Prints (the first three lines come from `init`, the last three from `apply`):
```
{}
{}
{...}
{...}
{}
{...}
```
We should clean up and clearly define the desired semantics of `params_dict()`.
One specific desideratum: any call to `params_dict()` should either raise an exception or return the same result regardless of when it is called.
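One way to picture this desideratum is a toy model in plain Python with no Haiku dependency. `ParamStore` and `ToyModule` are hypothetical names, not Haiku APIs, and both the flat `"module_name/param"` key layout and the prefix-based ownership rule are assumptions made only for this sketch. Here `params_dict()` either raises while parameters are not yet final (standing in for the middle of `init`) or returns a dict that depends only on the final parameters and the module's name, so calling it before or after the module runs gives the same answer.

```python
from typing import Dict, Mapping, Optional


class ParamStore:
    """Hypothetical stand-in for the parameters of a transformed function."""

    def __init__(self, params: Optional[Mapping[str, float]] = None):
        # None models "init is still running": parameters are not final yet.
        self._params = None if params is None else dict(params)

    def items(self):
        if self._params is None:
            raise RuntimeError("parameters are not available until init has finished")
        return self._params.items()


class ToyModule:
    """Hypothetical module whose params_dict() depends only on its name."""

    def __init__(self, module_name: str, store: ParamStore):
        self.module_name = module_name
        self._store = store

    def params_dict(self) -> Dict[str, float]:
        # Either raise (store not final) or return a result computed only from
        # the final parameters and this module's name, never from call order.
        prefix = self.module_name + "/"
        return {k: v for k, v in self._store.items() if k.startswith(prefix)}


if __name__ == "__main__":
    store = ParamStore({"linear/w": 1.0, "linear/b": 0.0})
    linear = ToyModule("linear", store)
    sequential = ToyModule("sequential", store)
    print(linear.params_dict())  # {'linear/w': 1.0, 'linear/b': 0.0}
    print(sequential.params_dict())  # {}
    try:
        ToyModule("linear", ParamStore()).params_dict()
    except RuntimeError as e:
        print("raised:", e)
```

Under this particular prefix rule `sequential` owns no parameters, because the `Linear` it wraps was created outside it; whether a container should also report its children's parameters is one of the semantics this issue leaves open.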