jaxgptoolbox
LipMLP: Xavier vs. He weight initialization
Hi, congrats on your LipMLP work! According to Appendix A of the paper, the weight matrices should use He initialization when ReLU is the activation, i.e. std = sqrt(2 / fan_in). Looking at the code (https://github.com/ml-for-gp/jaxgptoolbox/blob/main/demos/lipschitz_mlp/model.py#L33), the weights appear to be initialized with std = gain / sqrt(fan_in) (Xavier-style). Since ReLU is used, I was wondering whether the initialization in the code is correct.
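For reference, here is a minimal sketch (not taken from the repo) contrasting the two schemes for a dense layer's weight matrix; the function name and shapes are made up for illustration. The only difference is the scale factor: He uses sqrt(2 / fan_in), while the Xavier-style rule in the code amounts to gain / sqrt(fan_in) (gain = 1 here):

```python
import jax
import jax.numpy as jnp

def init_weights(key, fan_in, fan_out, scheme="he"):
    # He init (recommended for ReLU): std = sqrt(2 / fan_in)
    # Xavier-style init with gain = 1:  std = 1 / sqrt(fan_in)
    if scheme == "he":
        std = jnp.sqrt(2.0 / fan_in)
    else:
        std = 1.0 / jnp.sqrt(fan_in)
    return std * jax.random.normal(key, (fan_in, fan_out))

key = jax.random.PRNGKey(0)
W_he = init_weights(key, 256, 256, scheme="he")
W_xavier = init_weights(key, 256, 256, scheme="xavier")
# With the same key, the He weights are exactly sqrt(2) times the Xavier ones,
# so ReLU layers keep the forward variance from shrinking by half each layer.
```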