
[Enhancement]: Wrong gains for weight initialization

Open · OliEfr opened this issue 2 years ago · 2 comments

Enhancement

The recommended gains for weight initialization depend on the activation function used, see the torch docs. However, as of now the gains are hard-coded and always the same in ActorCriticPolicies. See here.
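For reference, the current behaviour (paraphrased from the linked code, so the exact source may differ slightly) is a fixed per-module gain that does not look at the activation function:

```python
import numpy as np

# Paraphrased sketch of the current ActorCriticPolicy behaviour when ortho_init=True:
# every sub-network gets a fixed gain, regardless of whether Tanh or ReLU is used.
module_gains = {
    "features_extractor": np.sqrt(2),
    "mlp_extractor": np.sqrt(2),
    "action_net": 0.01,
    "value_net": 1.0,
}
```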

I recommend making the gains dependent on the activation function used (i.e., probably mainly ReLU and tanh).
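A rough sketch of what I have in mind (a hypothetical helper, not existing SB3 code), using `torch.nn.init.calculate_gain` to pick the gain from the configured activation function:

```python
import torch.nn as nn

def gain_for_activation(activation_fn: type) -> float:
    """Pick the orthogonal-init gain from the policy's activation function
    instead of hard-coding sqrt(2). Hypothetical helper, not SB3 API."""
    if activation_fn is nn.ReLU:
        return nn.init.calculate_gain("relu")  # sqrt(2) ~= 1.414
    if activation_fn is nn.Tanh:
        return nn.init.calculate_gain("tanh")  # 5/3 ~= 1.667
    return 1.0  # neutral fallback for other activations
```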

If you agree with this, I would like to implement it myself and open a PR.

Thanks and a good day!

To Reproduce

--

Relevant log output / Error message

--

System Info

--

Checklist

  • [X] I have checked that there is no similar issue in the repo
  • [X] I have read the documentation
  • [X] I have provided a minimal working example to reproduce the bug
  • [X] I've used the markdown code blocks for both code and stack traces.

OliEfr · Jun 16 '23 11:06

Hello, those gains are for orthogonal initialization only (https://pytorch.org/docs/stable/modules/torch/nn/init.html#orthogonal); when orthogonal init is not used, the default PyTorch initialization is applied.
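A minimal sketch of that distinction (illustrative only, not SB3's actual code):

```python
import torch.nn as nn

# ortho_init=False: nothing special is done, so nn.Linear keeps PyTorch's
# default initialization (Kaiming uniform) and the gain is never used.
default_layer = nn.Linear(64, 64)

# ortho_init=True: the weights are re-initialized orthogonally, and only
# then does the gain come into play.
ortho_layer = nn.Linear(64, 64)
nn.init.orthogonal_(ortho_layer.weight, gain=2 ** 0.5)
nn.init.constant_(ortho_layer.bias, 0.0)
```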

The gains are from OpenAI Baselines, to keep results consistent. However, I haven't seen any investigation of the effect of the gain compared to other initializations so far (that would already be a good contribution), or at least of whether using tanh/ReLU with a constant gain has an effect.

araffin · Jul 20 '23 12:07

Yes, I am talking about orthogonal init. I agree that it is useful to keep it consistent with OpenAI Baselines. A study on the effect of the gain on convergence would be useful.

It seems to be a coincidence (?) that the standard gain listed for ReLU for any initialization is also sqrt(2) (Link). (The gain implemented in OpenAI Baselines and SB3 is also sqrt(2); maybe they just used ReLU by default and never investigated the gain?)
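A quick check of the values involved (standard PyTorch, just to show that the tanh gain would differ):

```python
import math
import torch.nn as nn

print(nn.init.calculate_gain("relu"), math.sqrt(2))  # both ~1.4142
print(nn.init.calculate_gain("tanh"))                # 5/3 ~= 1.6667
```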

One study that partly investigates the impact of weight init is this. They find:

  • "initializing the policy MLP with smaller weights in the last layer"
  • "network initialization scheme (C56) does not matter too much"

OliEfr · Jul 21 '23 07:07