Improve RTD for initializers
Problem you have encountered:
I was trying to exactly mimic the default bias initialization of PyTorch's Linear layer.
The first problem I encountered is that the bias_init argument to nn.Dense is not documented well enough. In particular, it's not clear what parameters the callback is expected to take. I did eventually read the source code, but that shouldn't be necessary.
The second problem I encountered is that the initializers in jax.nn.initializers.*, e.g. lecun_normal(), can't be used easily (or maybe at all) with bias_init. These initializers expect to be passed a 2D shape (as one would have for kernel_init), but the bias weights are only 1D.
Suggested action:
- Document kernel_init and bias_init args better
- Figure out how to exactly emulate the PyTorch linear layer's bias_init. Note that you'll need to look at the PyTorch source code because the PyTorch docs aren't precise enough.
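For reference, a rough sketch of what I mean by the second point (the helper name torch_linear_bias_init and the explicit fan_in argument are my own; as far as I can tell from the PyTorch source, nn.Linear draws its bias from U(-1/sqrt(fan_in), 1/sqrt(fan_in))):

import math
import jax
import jax.numpy as jnp
import flax.linen as nn

def torch_linear_bias_init(fan_in):
    # Hypothetical helper (not part of Flax): bias ~ U(-1/sqrt(fan_in), 1/sqrt(fan_in)),
    # which is what PyTorch's nn.Linear appears to use by default.
    bound = 1.0 / math.sqrt(fan_in)
    def init(key, shape, dtype=jnp.float32):
        return jax.random.uniform(key, shape, dtype, minval=-bound, maxval=bound)
    return init

# fan_in must be passed in by hand, because the bias initializer only ever sees the 1D bias shape.
layer = nn.Dense(features=16, bias_init=torch_linear_bias_init(fan_in=32))
params = layer.init(jax.random.PRNGKey(0), jnp.ones((1, 32)))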
Hi @billmark , thanks for your issue! Could you tell me in a bit more detail what you were missing from the description of the kernel_init and bias_init arguments?
As I said in the original bug, "In particular, it's not clear what the parameters to the callback need to be." That is, what arguments are passed to the kernel_init and bias_init callbacks? How many of them are there? What types and shapes? What do they mean? There is no such information in the documentation.
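From reading the source, my understanding (a sketch, not documented behaviour) is that both callbacks receive a PRNG key, the parameter shape, and the dtype:

import jax.numpy as jnp

def my_init(key, shape, dtype=jnp.float32):
    # key:   a jax.random.PRNGKey used to draw the random values
    # shape: (in_features, features) for kernel_init, (features,) for bias_init
    # dtype: the requested parameter dtype
    return jnp.zeros(shape, dtype)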
I had the same issue. I found this issue on the jax repo which I think is relevant: the initializer functions are intentionally made not to work with 1D shapes. https://github.com/google/jax/issues/2075#issuecomment-578465814
I think a page in the docs explaining the initializers would be useful, or perhaps a link to the jax docs which list the available functions. An example notebook explaining how module weights and biases are initialized could also help.
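To illustrate (a rough sketch; the exact error message may vary between jax versions): the variance-scaling style initializers need at least a 2D shape to compute fan_in/fan_out, so calling one with a bias shape fails.

import jax

init = jax.nn.initializers.lecun_normal()
kernel = init(jax.random.PRNGKey(0), (32, 16))   # 2D kernel shape: works
# init(jax.random.PRNGKey(0), (16,))             # 1D bias shape: raises an error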
Thanks for the feedback, I've updated the issue description. Indeed, we don't seem to have our initializers on RTD, while we do have the activation functions. We can simply add them here as well: https://github.com/google/flax/blob/main/docs/flax.linen.rst
I've been having this same issue as well; namely, I'm not sure what the argument to kernel_init should be. Shouldn't I be able to do this when creating an nn.Dense() layer, for instance?
layer = nn.Dense(features=16, kernel_init = nn.initializers.lecun_normal(scale = 0.1))
This throws the following error when calling model.init(rng_key, dummy_input):
jax._src.traceback_util.UnfilteredStackTrace: TypeError: variance_scaling() got multiple values for argument 'scale'
It only works if I either A) don't provide an input argument to nn.initializers.lecun_normal(), or B) use the following code instead, where I explicitly do the same thing as lecun_normal using the variance_scaling function.
nn.initializers.variance_scaling(0.1, "fan_in", "truncated_normal")
Sorry if this is not appropriate to paste here -- I can start a different issue if necessary. But it seems related to the problem of not knowing what the kernel_init argument expects.
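For completeness, the workaround in context looks roughly like this (sketch only):

import flax.linen as nn

# Scale the variance directly via variance_scaling instead of passing `scale` to lecun_normal.
layer = nn.Dense(
    features=16,
    kernel_init=nn.initializers.variance_scaling(0.1, "fan_in", "truncated_normal"),
)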
It only works if I either A) don't provide an input argument to nn.initializers.lecun_normal() or B) use the following code instead, where I explicitly do the same thing as lecun_normal using the variance_scaling function.
I agree it's a bit frustrating that you can only rescale variance_scaling and not lecun_normal, kaiming, etc. These initializers come from the jax.nn.initializers module though, so it's best to open an issue in the google/jax repo about this.
I'll look into this. I'll document the signature for each and clarify that bias_initializers don't work with the variance_scaling family.
@cgarciae any update on this?
@zaxtax will be taking the issue :)