Change default initializer to Kaiming uniform

Open lucasb-eyer opened this issue 4 years ago • 6 comments

According to git history, Flax has been defaulting to lecun_normal since the beginning of time. lecun_normal is essentially the same as Xavier, which was derived for tanh units. Nowadays nobody in their right mind uses tanh as the internal non-linearity everywhere; instead everybody uses ReLU and friends, for which kaiming_* was derived. So I suggest defaulting to that.

For what it's worth, PyTorch also defaults to Kaiming init (conv, linear), although they even include a slight negative slope for leaky ReLU by default. (And I'm suggesting kaiming_uniform only because they use the uniform variant too; I don't think it really matters.)

lucasb-eyer avatar Apr 19 '20 14:04 lucasb-eyer
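For context, here is a minimal sketch of what opting into He/Kaiming initialization explicitly looks like with today's flax.linen API (this issue predates linen, so the 2020-era flax.nn code would have differed); the toy MLP, its layer sizes, and the dummy input are made up for illustration:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

# kaiming_uniform (an alias of he_uniform) samples weights with variance
# 2/fan_in, the scaling derived for ReLU; the current default, lecun_normal,
# uses variance 1/fan_in.
kaiming = jax.nn.initializers.kaiming_uniform()

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        # Override the default kernel initializer per layer.
        x = nn.Dense(128, kernel_init=kaiming)(x)
        x = nn.relu(x)
        return nn.Dense(10, kernel_init=kaiming)(x)

params = MLP().init(jax.random.PRNGKey(0), jnp.ones((1, 64)))
```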

I don't feel very strongly about the default initializer, but the reasoning sounds compelling to me. We should probably announce this breaking change in advance, though.

jheek avatar Apr 21 '20 09:04 jheek

I have a relatively strong opinion that having a default initializer was a mistake in all major recent frameworks. But I'm weird that way =)

Anyways, I might soon do some runs to see the effect on ResNet and can let you know about it in internal chat, although I don't expect much difference for "standard" BN-heavy models. But see, for example, Kaiming He's 2015 paper ("Delving Deep into Rectifiers") for how important initializers are for non-BN models!

lucasb-eyer avatar Apr 21 '20 20:04 lucasb-eyer

To be safe I would want to rerun our examples to make sure this change doesn't break them... At least for transformer models, I think people settled on xavier_uniform after a lot of testing (I'm pretty sure we specify these explicitly in the models, though).

levskaya avatar Apr 22 '20 15:04 levskaya
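As an illustration of specifying initializers explicitly, so that a change of framework default cannot silently alter a model, here is a sketch using the current flax.linen API; the attention layer, head count, and feature sizes are illustrative and not taken from Flax's actual transformer examples:

```python
import jax
import flax.linen as nn

# Pin xavier_uniform (a.k.a. glorot_uniform) explicitly instead of relying
# on whatever the framework default happens to be.
xavier = jax.nn.initializers.xavier_uniform()

dense = nn.Dense(512, kernel_init=xavier, bias_init=jax.nn.initializers.zeros)
attention = nn.MultiHeadDotProductAttention(
    num_heads=8, qkv_features=512, kernel_init=xavier)
```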

I think we should merge this, but first we need our end-to-end regression testing to be in place. (See #144 for our first step towards this.)

avital avatar Apr 24 '20 10:04 avital

This change passes the tests, but shows some regressions in the howto tests: in checkpoints and distributed-training, for example, the loss exceeds the original threshold of 2.302 (coming in at 2.40775).

Not sure how people feel about this, so just reporting.

danielsuo avatar May 29 '20 01:05 danielsuo

I'd like to see how this affects larger models. @mohitreddy1996 and @AlexeyG are in the process of setting up our end-to-end regression testing.

avital avatar May 29 '20 15:05 avital

Definitely announce a change like this as widely and loudly as possible! And provide a way to opt out of it for a while? I agree with Lucas that default initializers are super dangerous.

georgedahl avatar Dec 13 '22 05:12 georgedahl

After discussing this internally with @lucasb-eyer and @levskaya, and given that the number of users has grown significantly since this issue was created, it seems infeasible to change this default without breaking many users. So I am closing this issue.

marcvanzee avatar Dec 13 '22 07:12 marcvanzee