WIP: Try to relax the constraint that attributes are immutable during setup(), while still having safe and correct cloning

Open avital opened this issue 3 years ago • 7 comments

avital avatar Dec 16 '21 17:12 avital

Hi @jheek, @levskaya -- looking for an initial mini-review to see whether this idea makes sense at all, before I continue developing it.

This would help resolve many user questions and make common use cases easier (e.g. normalizing attributes during setup, or converting them into a canonical form like in https://github.com/google/flax/blob/main/flax/linen/linear.py#L84).

avital avatar Dec 16 '21 17:12 avital

The downside of this is that the value of an attribute will depend on whether it is bound to a scope or not. I think it would make more sense to have a setup that is called irrespective of whether the module is bound or not. This way a user would always access the canonical value (after setup) externally like you would expect.

A backward-compatible alternative is to have both init and setup, where init is meant for construction-time initialization and setup is meant for variable/RNG construction that requires a scope and has potential side effects.
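
A minimal plain-Python sketch of that split, under stated assumptions: this is not Flax code, and `ToyConv`, its `bind` method, and the dict standing in for a scope are all hypothetical. `init` runs at construction and only normalizes attributes, so the canonical value is visible whether or not the module is bound; `setup` runs only once a scope is available.

```python
import dataclasses
from typing import Optional, Tuple, Union


@dataclasses.dataclass
class ToyConv:
    """Toy stand-in for a module; not the Flax API."""
    kernel_size: Union[int, Tuple[int, ...]]
    scope: Optional[dict] = None  # hypothetical: a dict plays the role of a scope

    def __post_init__(self):
        self.init()
        if self.scope is not None:
            self.setup()

    def init(self):
        # Construction-time hook: pure attribute normalization, no scope needed.
        if isinstance(self.kernel_size, int):
            self.kernel_size = (self.kernel_size,)
        else:
            self.kernel_size = tuple(self.kernel_size)

    def setup(self):
        # Bind-time hook: creates state in the scope, so it has side effects.
        self.scope.setdefault("kernel", [0.0] * len(self.kernel_size))

    def bind(self, scope):
        # Cloning re-runs init() on the already-normalized attributes.
        return dataclasses.replace(self, scope=scope)


unbound = ToyConv(kernel_size=3)
scope = {}
bound = unbound.bind(scope)
```

In this toy, `unbound.kernel_size` and `bound.kernel_size` are the same canonical tuple, and only binding touches `scope`.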

jheek avatar Dec 20 '21 10:12 jheek

Yeah, you're right, that's a good point. I actually wonder whether the pattern that some folks use, where they override __post_init__ (and call super appropriately), is broken in case the module is then cloned. I'll check. We may want to consider effectively deprecating this pattern and forcing people to use a custom init like you say, or maybe call it normalize_attrs or something more explicit.

avital avatar Dec 23 '21 13:12 avital

is broken in case the module is then cloned

Typically normalize is idempotent (it reaches a fixpoint after a single iteration), so normalize(attrs) == normalize(normalize(attrs)), in which case cloning is not broken.
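
To illustrate why normalize(attrs) == normalize(normalize(attrs)) keeps cloning safe, here is a small plain-Python sketch. It is not Flax code: `normalize`, `Layer`, and `clone` are hypothetical names, and `dataclasses.replace` stands in for whatever cloning mechanism re-runs `__post_init__`.

```python
import dataclasses
from typing import Tuple, Union


def normalize(features: Union[int, Tuple[int, ...]]) -> Tuple[int, ...]:
    """Hypothetical normalizer: an int becomes a 1-tuple, sequences become tuples."""
    if isinstance(features, int):
        return (features,)
    return tuple(features)


@dataclasses.dataclass(frozen=True)
class Layer:
    features: Union[int, Tuple[int, ...]]

    def __post_init__(self):
        # Frozen dataclasses need object.__setattr__ to overwrite a field.
        object.__setattr__(self, "features", normalize(self.features))

    def clone(self, **updates):
        # Re-runs __post_init__, i.e. normalizes the already-normalized value.
        return dataclasses.replace(self, **updates)


layer = Layer(features=4)
copy = layer.clone()
```

Because the second normalization is a no-op, `copy` compares equal to `layer`. A normalizer that is not idempotent (for example, one that appends an extra element on every call) would produce a different value after each clone.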

Actually, the other big advantage of having an init/post_init pathway is that you can initialize submodules on unbound scopes, e.g. the following would work:

class AutoEncoder(nn.Module):
  def init(self):
    self.encoder = Encoder()
    self.decoder = Decoder()

ae = AutoEncoder()
ae.decoder.apply(decoder_variables, ...)

jheek avatar Dec 23 '21 13:12 jheek

Actually, the other big advantage of having an init/post_init pathway is that you can initialize submodules on unbound scopes, e.g. the following would work:

That's actually really nice indeed. It's a bit unfortunate to have to explain the difference between "init" and "setup", but it's for a good reason. Maybe we can find better names? "prebind_init", "postbind_init"? Of course those are too verbose... I'm just thinking out loud here.

avital avatar Dec 23 '21 13:12 avital

To make the below work:

class AutoEncoder(nn.Module):
  def init(self):
    self.encoder = Encoder()
    self.decoder = Decoder()

We'd have to write more logic to configure these lazily once the module is bound (as inside setup), wouldn't we?

levskaya avatar Dec 24 '21 03:12 levskaya

We'd have to write more logic to configure these lazily once the module is bound (as inside setup), wouldn't we?

We already have such logic for modules passed as attributes.
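
As a rough illustration of what "configuring lazily once the module is bound" could look like for modules passed as attributes, here is a toy plain-Python sketch. It is not how Flax implements this: `ToyModule`, `bind`, and the `parent`/`name` bookkeeping are hypothetical and only show the idea that submodules stay unattached until the parent is bound.

```python
import dataclasses


class ToyModule:
    """Toy base class; parent and name stay None until the parent is bound."""
    parent = None
    name = None

    def bind(self):
        # Walk the dataclass fields and adopt any unbound submodule found there.
        for field in dataclasses.fields(self):
            child = getattr(self, field.name)
            if isinstance(child, ToyModule) and child.parent is None:
                child.parent, child.name = self, field.name
        return self


@dataclasses.dataclass
class Encoder(ToyModule):
    width: int = 8


@dataclasses.dataclass
class Decoder(ToyModule):
    width: int = 8


@dataclasses.dataclass
class AutoEncoder(ToyModule):
    encoder: ToyModule
    decoder: ToyModule


enc, dec = Encoder(), Decoder()
ae = AutoEncoder(encoder=enc, decoder=dec)
unbound_parent = enc.parent  # still None: nothing is attached before bind()
ae.bind()
```

After `ae.bind()`, each submodule knows its parent and the attribute name it was stored under, which is the kind of deferred wiring the `init`-constructed submodules would also need.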

jheek avatar Dec 24 '21 10:12 jheek