
Improve RTD for initializers

Open billmark opened this issue 3 years ago • 7 comments


Problem you have encountered:

I was trying to exactly mimic the default bias initialization of PyTorch's Linear layer.

The first problem I encountered is that the bias_init arg to nn.Dense is not documented well enough. In particular, it's not clear what the parameters to the callback need to be. Of course I eventually read the source code but that shouldn't be necessary.

The second problem I encountered is that the initializers in jax.nn.initializers.*, e.g. lecun_normal(), can't be used easily (or maybe at all) with bias_init. These initializers expect to be passed a 2D shape (as one would have for kernel_init), but the bias weights are only 1D.
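For illustration, a minimal sketch of the mismatch (shapes are made up):

```python
import jax
import flax.linen as nn

# Variance-scaling initializers need at least a 2D shape to infer fan-in/fan-out.
init = nn.initializers.lecun_normal()
key = jax.random.PRNGKey(0)

kernel = init(key, (8, 16))   # works: a (fan_in, fan_out) kernel shape
# bias = init(key, (16,))     # fails: fan-in/fan-out can't be inferred from a 1D shape
```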

Suggested action:

  1. Document kernel_init and bias_init args better
  2. Figure out how to exactly emulate the PyTorch linear layer's bias_init (a sketch of one possible approach follows below). Note that you'll need to look at the PyTorch source code because the PyTorch docs aren't precise enough.
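For reference, a rough sketch of what emulating PyTorch's defaults might look like. This assumes (from the PyTorch source, not from this thread) that nn.Linear draws both weight and bias from U(-1/sqrt(fan_in), 1/sqrt(fan_in)), and fan_in has to be supplied by hand because Flax's bias_init only ever sees the 1D bias shape:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

def torch_style_bias_init(fan_in):
    """Sketch of PyTorch Linear's bias init: U(-1/sqrt(fan_in), 1/sqrt(fan_in))."""
    bound = 1.0 / jnp.sqrt(fan_in)
    def init(key, shape, dtype=jnp.float32):
        return jax.random.uniform(key, shape, dtype, minval=-bound, maxval=bound)
    return init

layer = nn.Dense(
    features=16,
    # variance_scaling(1/3, "fan_in", "uniform") gives the same
    # U(-1/sqrt(fan_in), 1/sqrt(fan_in)) bound for the kernel.
    kernel_init=nn.initializers.variance_scaling(1.0 / 3.0, "fan_in", "uniform"),
    bias_init=torch_style_bias_init(fan_in=8),  # must match the input feature size
)
params = layer.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))
```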

billmark avatar Jun 24 '21 17:06 billmark

Hi @billmark , thanks for your issue! Could you tell me in a bit more detail what you were missing from the description of kernel_init and bias_init arguments?

marcvanzee avatar Jul 09 '21 14:07 marcvanzee

As I said in the original bug, "In particular, it's not clear what the parameters to the callback need to be." That is, what arguments are passed to the kernel_init and bias_init callbacks? How many of them are there? What are their types/shapes? What do they mean? None of this information is in the documentation.
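To make it concrete, what I'm after is enough documentation to confidently write something like this (signature inferred from reading the source; names are illustrative):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

def my_kernel_init(key, shape, dtype=jnp.float32):
    # For Dense, shape is (in_features, out_features).
    return jax.random.normal(key, shape, dtype) * 0.01

def my_bias_init(key, shape, dtype=jnp.float32):
    # For Dense, shape is (features,).
    return jnp.zeros(shape, dtype)

layer = nn.Dense(features=16, kernel_init=my_kernel_init, bias_init=my_bias_init)
params = layer.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))
```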

billmark avatar Jul 13 '21 21:07 billmark

Had the same issue. I found this issue on the jax repo which I think is relevant: the initializer functions are intentionally made not to work with 1D shapes. https://github.com/google/jax/issues/2075#issuecomment-578465814

I think a page in the docs explaining the initializers would be interesting, or maybe a link to the jax docs, which list the available functions. An example notebook explaining how module weights and biases are initialized could also be useful.
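As an aside (an illustration, not something from the linked issue): initializers outside the variance-scaling family do accept 1D shapes, so they can at least be used directly as bias_init:

```python
import flax.linen as nn

# The stddev value here is illustrative.
layer = nn.Dense(
    features=16,
    bias_init=nn.initializers.normal(stddev=0.02),
)
```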

FelipeMartins96 avatar Jul 29 '21 15:07 FelipeMartins96

Thanks for the feedback, I've updated the issue description. Indeed, we don't seem to have our initializers on RTD, while we do have the activation functions. We can simply add them here as well: https://github.com/google/flax/blob/main/docs/flax.linen.rst

marcvanzee avatar Sep 06 '21 12:09 marcvanzee

I've been having this same issue as well; namely, I'm not sure what the argument to kernel_init should be. Shouldn't I be able to do this when creating an nn.Dense() layer, for instance?

layer = nn.Dense(features=16, kernel_init = nn.initializers.lecun_normal(scale = 0.1))

This throws the following error when doing model.init(rng_key, dummy_input):

jax._src.traceback_util.UnfilteredStackTrace: TypeError: variance_scaling() got multiple values for argument 'scale'

It only works if I either A) don't provide an input argument to nn.initializers.lecun_normal() or B) use the following code instead, where I explicitly do the same thing as lecun_normal using the variance_scaling function.

nn.initializers.variance_scaling(0.1, "fan_in", "truncated_normal")

Sorry if this is not appropriate to paste here -- I can start a different issue if necessary. But it seems related to the problem of not knowing what the kernel_init argument expects.
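For completeness, the working version of what I'm doing looks like this (the 0.1 scale and dummy input shape are just illustrative):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

layer = nn.Dense(
    features=16,
    kernel_init=nn.initializers.variance_scaling(0.1, "fan_in", "truncated_normal"),
)
params = layer.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))
```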

conorheins avatar Sep 27 '21 14:09 conorheins

It only works if I either A) don't provide an input argument to nn.initializers.lecun_normal() or B) use the following code instead, where I explicitly do the same thing as lecun_normal using the variance_scaling function.

I agree it's a bit frustrating that you can only rescale via variance_scaling and not lecun_normal, kaiming, etc. These initializers come from the jax.nn.initializers module, though, so it's best to open an issue in the google/jax repo about this.
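For what it's worth, a rough sketch of why the error occurs (check the jax source for the exact definition): lecun_normal() is essentially variance_scaling with scale already bound to 1.0, so passing scale= again collides with the bound value:

```python
import jax
import jax.numpy as jnp
from jax import nn as jnn

key = jax.random.PRNGKey(0)
a = jnn.initializers.lecun_normal()(key, (4, 8))
b = jnn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal")(key, (4, 8))
assert jnp.allclose(a, b)  # same distribution, same key -> same samples
```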

jheek avatar Sep 27 '21 16:09 jheek

I'll look into this. I'll document the signature for each initializer and clarify that the variance_scaling family can't be used for bias initialization.

cgarciae avatar Sep 08 '22 22:09 cgarciae

@cgarciae any update on this?

marcvanzee avatar Dec 12 '22 15:12 marcvanzee

@zaxtax will be taking the issue :)

cgarciae avatar Dec 12 '22 15:12 cgarciae