
Correct gain value during kaiming weight initialization


Hello! Great work on this activation function! I've been using it in some of my projects with great success.

I want to let you know what I found the gain should be set to for kaiming weight initialization with Mish.

I found this experimentally using this code:

import torch
import torch.nn.functional as F
import pandas as pd

device = 'cpu'

def mish(x):
    return x * torch.tanh(F.softplus(x))

aa = []  # output std after 10 layers, for each candidate slope
bb = []  # candidate negative-slope values passed to kaiming_uniform_
for n in range(100):
    with torch.no_grad():
        a = torch.randn(5000, 5000, device=device)
        b = a
        x = 0.0 + 0.00001 * n
        for i in range(10):
            l = torch.nn.Linear(5000, 5000, bias=False).to(device)
            torch.nn.init.kaiming_uniform_(l.weight, a=x)
            b = mish(l(b))
        aa.append(b.std().item())
        bb.append(x)
        print(x)
        print(f"in: {a.std().item():.8f}, out: {b.std().item():.8f}")
pd.DataFrame(data=aa, index=bb).plot(figsize=(20, 8))

which was talked about here. The "a" hyperparameter for init.kaiming_uniform_ is not actually the gain but the negative slope of a leaky relu, so really I experimentally found the equivalent negative slope of mish for kaiming_uniform_ init. The actual gain is found internally by math.sqrt(2.0 / (1 + a ** 2)).
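
(For concreteness, a small sketch, not part of the original comment, of the gains these slope values correspond to under that formula; torch.nn.init.calculate_gain('leaky_relu', a) computes the same quantity.)

import math

for a in (0.0, 0.0003, math.sqrt(5)):
    # same formula kaiming_uniform_ applies internally for leaky_relu
    gain = math.sqrt(2.0 / (1 + a ** 2))
    print(f"a={a:.4f} -> gain={gain:.4f}")

# a=0.0000 -> gain=1.4142
# a=0.0003 -> gain=1.4142
# a=2.2361 -> gain=0.5774  (PyTorch's conv/linear default, a=sqrt(5))

So a=0.0003 corresponds to essentially the relu gain of sqrt(2), while the PyTorch default of a=sqrt(5) yields a much smaller gain.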

This is an example of the code output. I found through repeated experiments that 0.0003 gives the most consistent signal propagation through the network, so it is almost the zero slope of relu, but not quite. a=0 did produce okay results, as did, say, 0.001, but the best value averaged over many runs is 0.0003. This is important because for deep networks the PyTorch default of a=sqrt(5) for initializing conv layers is not a good value if you're using mish.

I now use something like

for m in self.modules():
    if isinstance(m, (nn.Conv1d, nn.Linear)):
        # a=0.0003: the experimentally found negative-slope equivalent for mish
        torch.nn.init.kaiming_uniform_(m.weight, a=0.0003)

evanatyourservice avatar Nov 16 '20 19:11 evanatyourservice

Thanks for the appreciation of my work. I'm glad Mish has been working well in your projects.

This is an interesting observation; I haven't extensively investigated optimal initialization schemes for Mish. I will look into this further and verify it. Off topic, have you tried orthogonal initialization before?

digantamisra98 avatar Nov 17 '20 20:11 digantamisra98

No, I have not! But now I want to, since you've reminded me -- orthogonal used to work very well for me in RL experiments with complex networks. Maybe I'll try it with the standard deviation experiment above and see what happens. With kaiming uniform I was getting some signal up to about 75 layers.

I've only messed around with kaiming uniform and fan_in, but I'm sure the gain/slope setting would be the same for kaiming normal and fan_out, since the gain depends only on the shape of the activation function. Since 0.0003 is so close to 0, I'd think any of the relu defaults would also work well for mish. I know Less Wright used nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') in his work with mish that beat some Kaggle competitions... he uses nonlinearity='relu', i.e. a=0. This makes me think orthogonal would at least work somewhat well with mish, since it works well with relu.
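
(As a point of reference, a minimal sketch of that relu-default init applied to a model that uses mish; the Conv2d/Linear selection and the helper name are illustrative assumptions, not Less Wright's actual code.)

import torch.nn as nn

def init_relu_defaults(model):
    # relu defaults (a=0, gain=sqrt(2)) applied to a network that uses mish
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)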

evanatyourservice avatar Nov 18 '20 19:11 evanatyourservice

Right, the reason I'm interested in orthogonal initialization is because of this. Let me know if you have made any progress with orthogonal initialization.

digantamisra98 avatar Nov 24 '20 12:11 digantamisra98

Little update: I'm definitely getting some interesting results with the orthogonal init. At first glance, it seems less finicky and more stable than kaiming init. I'm going to turn the experiment above into an optuna optimization problem to find the optimal gain, because doing it manually takes so long and isn't exact.

With the orthogonal gain hyperparameter, too small a gain means vanishing gradients and too large means exploding, so with optuna I can narrow down the best gain, which should allow for pretty deep propagation through a lot of layers. To really see the difference between orthogonal and kaiming, though, I'll have to do an actual training experiment. We'll see! I'll at least post my finding for the best orthogonal gain setting for mish here shortly so others can experiment with this value. I'll also run kaiming through optuna to narrow that hyperparameter down more exactly than 0.0003.
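
(A rough sketch of how that optuna search over the orthogonal gain might look, not the actual experiment; the layer sizes, depth, search range, and objective are assumptions.)

import optuna
import torch
import torch.nn.functional as F

def mish(x):
    return x * torch.tanh(F.softplus(x))

def objective(trial):
    gain = trial.suggest_float("gain", 0.5, 2.0)
    with torch.no_grad():
        b = torch.randn(2000, 2000)
        in_std = b.std().item()
        for _ in range(10):
            l = torch.nn.Linear(2000, 2000, bias=False)
            torch.nn.init.orthogonal_(l.weight, gain=gain)
            b = mish(l(b))
        # penalize both vanishing and exploding activations:
        # minimize how far the output std drifts from the input std
        return abs(torch.log(b.std() / in_std).item())

study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=50)
print(study.best_params)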

evanatyourservice avatar Dec 11 '20 16:12 evanatyourservice

@evanatyourservice Thanks for the update. Interesting: with orthogonal, you need to keep the init on the EOC (edge of chaos), otherwise you'll get either vanishing or exploding gradients. Keep me posted on your progress.

digantamisra98 avatar Dec 12 '20 00:12 digantamisra98