WIP: Try to relax the constraint that attributes are immutable during setup(), while still having safe and correct cloning
Hi @jheek, @levskaya -- looking for an initial mini-review to see whether this idea makes sense at all, before I continue developing it.
This would help resolve many user questions and make common use cases easier (e.g. normalizing attributes during setup, or converting them into a canonical form like in https://github.com/google/flax/blob/main/flax/linen/linear.py#L84).
The downside of this is that the value of an attribute will depend on whether it is bound to a scope or not. I think it would make more sense to have a setup that is called irrespective of whether the module is bound or not. This way a user would always access the canonical value (after setup) externally like you would expect.
A backward-compatible alternative to this is to have init and setup, where init is meant for construction-time initialization and setup is meant for variable/rng construction that requires a scope and has potential side effects.
Yeah, you're right, that's a good point. I wonder whether the pattern that some folks use, where they override __post_init__ (and call super appropriately), is actually broken in case the module is then cloned. I'll check. We may want to consider effectively deprecating this pattern and forcing people to use a custom init like you say, or maybe call it normalize_attrs or something more explicit.
> is actually broken in case the module is then cloned

Typically normalize acts like a single-iteration fixpoint, so normalize(attrs) == normalize(normalize(attrs)), in which case cloning is not broken.
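To make the fixpoint argument concrete, here is a minimal sketch in plain Python dataclasses. It assumes that cloning re-invokes the constructor with the current attribute values, which is what dataclasses.replace does. The names normalize_axes, Layer and Doubling are hypothetical and only serve the illustration; this is not Flax code.

```python
import dataclasses
from typing import Sequence, Tuple, Union


def normalize_axes(axis: Union[int, Sequence[int]]) -> Tuple[int, ...]:
    """Hypothetical normalizer: turn an int or a sequence of ints into a tuple."""
    if isinstance(axis, int):
        return (axis,)
    return tuple(axis)


@dataclasses.dataclass(frozen=True)
class Layer:
    axis: Union[int, Sequence[int]] = -1

    def __post_init__(self):
        # Frozen dataclass: bypass __setattr__ to store the canonical form.
        object.__setattr__(self, "axis", normalize_axes(self.axis))


@dataclasses.dataclass(frozen=True)
class Doubling:
    width: int = 1

    def __post_init__(self):
        # Not idempotent: every pass through the constructor doubles the value.
        object.__setattr__(self, "width", self.width * 2)


layer = Layer(axis=-1)
# replace() calls the constructor again, so __post_init__ runs a second time,
# now on the already-normalized value.
layer_clone = dataclasses.replace(layer)

doubling = Doubling(width=4)
doubling_clone = dataclasses.replace(doubling)

print(layer.axis, layer_clone.axis)          # idempotent: clone matches
print(doubling.width, doubling_clone.width)  # non-idempotent: clone drifts
```

With an idempotent normalizer the clone compares equal to the original. The Doubling case shows how an override that is not a fixpoint produces a different value after each clone.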
Actually, the other big advantage of having an init/post_init pathway is that you can initialize submodules on unbound scopes, e.g. the following would work:

    class AutoEncoder(nn.Module):
      def init(self):
        self.encoder = Encoder()
        self.decoder = Decoder()

    ae = AutoEncoder()
    ae.decoder.apply(decoder_variables, ...)
> Actually the other big advantage of having a init/post_init pathway is that you can initialize submodules on unbound scopes. e.g. the following would work:

That's actually really nice indeed. It's a bit unfortunate to have to explain the difference between "init" and "setup", but it's for a good reason. Maybe we can find better names? "prebind_init", "postbind_init"? Of course those are too verbose... I'm just thinking out loud here.
To make the below work:

    class AutoEncoder(nn.Module):
      def init(self):
        self.encoder = Encoder()
        self.decoder = Decoder()

We'd have to write more logic to configure these lazily once the module is bound (as inside setup), wouldn't we?

> We'd have to write more logic to config these lazily once the module is bound (as inside setup), wouldn't we?

We already have such logic for modules passed as attributes.