Why uniform bias initializer?

I noticed that the bias parameters in linear and conv modules use a uniform initializer. Is there a good justification for this? I noticed that PyTorch does this too. I was expecting a zero initializer, which is also what Flax uses.

jenkspt avatar Aug 19 '22 21:08 jenkspt

The justification is really just "that is what PyTorch does". (I copied the initialisers over when making Equinox.)

If you want a zero bias then this can be done by replacing the leaf:

linear = eqx.nn.Linear(...)
linear = eqx.tree_at(lambda l: l.bias, linear, jnp.zeros_like(linear.bias))
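For reference, a self-contained version of the above looks something like this (the layer sizes and PRNG key are just placeholder choices):

import equinox as eqx
import jax.numpy as jnp
import jax.random as jr

key = jr.PRNGKey(0)
linear = eqx.nn.Linear(3, 4, key=key)  # bias starts uniformly initialised
linear = eqx.tree_at(lambda l: l.bias, linear, jnp.zeros_like(linear.bias))
assert (linear.bias == 0).all()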

patrick-kidger avatar Aug 22 '22 11:08 patrick-kidger

Maybe this deserves a separate issue / feature request, but it was non-obvious to me as a first-time user (coming from Haiku) how to override the default initialization scheme in nn.Linear.

I think it would be cool if there were a w_init argument (similar to Haiku), or perhaps just a note in the docs explaining the most idiomatic way to do this.

For what it's worth, I find the eqx.tree_at approach to be a bit "too clever". It's not immediately obvious to a new user how it works. Maybe that's just me though.

As an aside, I am really enjoying Equinox so far and appreciate all the work you have put into it!

angusturner avatar Sep 01 '22 04:09 angusturner

Thanks, I'm glad you like it!

Regarding adding this to the docs: yep, this is on the horizon (#185).

Regarding making this an additional argument: so far I've been resisting this, as it adds quite a lot of complexity. It would add a lot of arguments that I think most folks don't use (PyTorch doesn't have this), and it would necessitate adding a new eqx.nn.init namespace to hold the existing initialisers. Both of these would clutter the documentation.

But that isn't a strong feeling on my part. This request does come up every now and again, and I'd be happy to be persuaded otherwise?

(cc @typedfemale as we were talking about this recently)

patrick-kidger avatar Sep 01 '22 13:09 patrick-kidger

Is there a good way to change all bias initializers to zero?

jenkspt avatar Sep 01 '22 16:09 jenkspt

That is a fair point!

It is true that PyTorch doesn't take a w_init, but because everything is mutable in PyTorch it's very easy to just set linear.bias = ... (or do an in-place modification, or whatever).

But I guess this is just teething pains of adjusting to immutability / PyTrees.

I will give this more thought as I adjust to the eqx.tree_at model surgery approach!

angusturner avatar Sep 02 '22 03:09 angusturner

Is there a good way to change all bias initializers to zero?

+1 for this. Right now I have a hacky method that relies on recursive getattr to first collect a sequence of target attribute strings and then pass them to eqx.tree_at for replacement. I'd like to know of a cleaner way to achieve this.

paganpasta avatar Sep 07 '22 23:09 paganpasta

So I can think of two non-hacky methods for setting all biases to zero.

Option 1:

import equinox as eqx
import jax.numpy as jnp

def linear(*args, **kwargs):
    out = eqx.nn.Linear(*args, **kwargs)
    # Replace the uniformly-initialised bias with zeros.
    out = eqx.tree_at(lambda l: l.bias, out, replace_fn=jnp.zeros_like)
    return out

and then using linear (or an analogous conv2d etc.) everywhere you were previously using eqx.nn.Linear.
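For example (with placeholder sizes and key):

import jax.random as jr

lin = linear(3, 4, key=jr.PRNGKey(0))  # a Linear whose bias is all zeros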

Option 2:

model = ...
has_bias = lambda x: hasattr(x, "bias")
where = lambda m: [x.bias for x in jax.tree_util.tree_leaves(m, is_leaf=has_bias) if has_bias(x)]
model = eqx.tree_at(where, model, replace_fn=jnp.zeros_like)

FWIW I think I am hearing consensus in favour of this proposal! So perhaps let's just add it :D

The idea would be to:

  • Add an eqx.nn.init module with all the current initialisers.
  • Add extra arguments weight_init=..., bias_init=... to eqx.nn.{Linear, Conv, MultiheadAttention, ...}.
  • Make sure that the "more complicated" layers, such as MLP and MultiheadAttention, forward these arguments to their sublayers.
  • Add tests.
  • Add the eqx.nn.init namespace to the documentation.

This is a fair amount of work so this probably isn't something I'll find time to implement in the near future. If anyone involved in this thread feels suitably motivated, then I'd be very happy to accept a PR on this.

patrick-kidger avatar Sep 08 '22 00:09 patrick-kidger

I'm not pushing to add init args. I agree that it adds unnecessary complexity (and work 😅)

jenkspt avatar Sep 08 '22 02:09 jenkspt

Option 2 covers my use-cases with a minor modification: if has_bias(x) and x.bias is not None. Thanks!
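For reference, the modified where from Option 2 would then look something like:

where = lambda m: [
    x.bias
    for x in jax.tree_util.tree_leaves(m, is_leaf=has_bias)
    if has_bias(x) and x.bias is not None
]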

paganpasta avatar Sep 08 '22 07:09 paganpasta

Does either option work with nn.MLP or would I need to do more surgery? It's unclear to me how I would use tree_at to access each linear layer but maybe my pytree skills are just weak...

jloganolson avatar Oct 21 '22 17:10 jloganolson

Approach 1: the above "Option 2" should work out-of-the-box: it will detect each linear layer as these all have a bias attribute.

Approach 2: if you wanted, you could also swap out has_bias for is_linear = lambda x: isinstance(x, eqx.nn.Linear), and then do exactly as before. Since eqx.nn.MLP uses linear layers internally, and its linear layers are the only sublayers with biases, this works equally well.
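A sketch of that variant, reusing the Option 2 pattern from earlier in the thread:

is_linear = lambda x: isinstance(x, eqx.nn.Linear)
where = lambda m: [x.bias for x in jax.tree_util.tree_leaves(m, is_leaf=is_linear) if is_linear(x)]
model = eqx.tree_at(where, model, replace_fn=jnp.zeros_like)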

Approach 3: if you're working with an MLP specifically then you could also just list out all its linear layers manually:

mlp = eqx.nn.MLP(...)
where = lambda m: [lin.bias for lin in m.layers]
mlp = eqx.tree_at(where, mlp, replace_fn=jnp.zeros_like)

Any one of these approaches is equally fine.

Ultimately I think this kind of model surgery is one of the greatest strengths of Equinox. It takes a little getting used to, but variations on the above pattern allow you to perform almost any kind of adjustment to your model.

patrick-kidger avatar Oct 22 '22 15:10 patrick-kidger