Add a PyTorch-like apply method to Module to initialize weights and biases
Add an `apply` method to `nn.Module` so that weights and biases can be initialized the same way as in PyTorch, where `Module.apply(fn)` calls `fn` recursively on every submodule and then on the module itself.
Reference to Python documentation here. Code here.
This code is required to complete issue #51.
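For reference, PyTorch's `nn.Module.apply(fn)` applies `fn` recursively to every submodule (as returned by `.children()`) and then to the module itself, and returns the module. Below is a minimal sketch of that traversal in Scala. It assumes a hypothetical `Module` trait that exposes its direct submodules as `children: Seq[Module]`. The `Leaf` and `Container` classes are also hypothetical. None of this is Storch's actual API.

```scala
// Hypothetical stand-in for a Storch module: only what apply needs.
trait Module:
  def name: String
  // Direct submodules, analogous to PyTorch's Module.children()
  def children: Seq[Module] = Seq.empty

  // Post-order traversal as in PyTorch: children first, then this module.
  // Returns this so calls can be chained (PyTorch returns self as well).
  def apply(fn: Module => Unit): this.type =
    children.foreach(_.apply(fn))
    fn(this)
    this

// Hypothetical leaf and container modules, for illustration only.
final class Leaf(val name: String) extends Module
final class Container(val name: String, subs: Module*) extends Module:
  override def children: Seq[Module] = subs
```

With a nested model, leaves are visited before the containers that hold them. An init function passed to `apply` therefore sees every layer exactly once.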
I am trying to re-implement the following Python function that initializes the values of a module's weights and biases:
```python
# better init, not covered in the original GPT video, but important, will cover in followup video
self.apply(self._init_weights)

def _init_weights(self, module):
    if isinstance(module, nn.Linear):
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
```
After adding some additional init functions to Storch, I wrote the following function:
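```scala
private def init_weights[D <: FloatNN | ComplexNN](m: Module with HasWeight[D]): Unit =
  m match
    case lm: nn.Linear[?] =>
      torch.nn.init.normal_(lm.weight, mean = 0.0, std = 0.02)
      // Placeholder: this should test lm.options.bias(), but LinearOptions is private
      if true then
        torch.nn.init.zeros_(lm.bias)
    case _: nn.Embedding[?] =>
      // Every module reaching this point has a weight (HasWeight[D])
      torch.nn.init.normal_(m.weight, mean = 0.0, std = 0.02)
    case _ => () // other modules are left unchanged, as in the Python version
```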
The first thing to note is that `Module` does not have a `weight` member, so I had to use `HasWeight[D]`. Unlike the other traits in `Module`, `HasWeight[D]` does not extend `nn.Module`.
The second thing of note is that we don't have a `HasBias` trait (adapted here from `HasWeight[D]`):
```scala
trait HasBias[ParamType <: FloatNN | ComplexNN]:
  def bias: Tensor[ParamType]
```
The problem I now have is finding a way to test whether a `Module` has a bias. `nn.Linear`, for example, has `LinearOptions` that I could use, but they are private. I assume the objective is to keep them hidden to maintain an idiomatic Scala API. Moreover, not all modules have options that include a bias (`Embedding`, for example).
The simplest solution is a `hasBias(): Boolean` method. The `Module` trait could provide a default implementation that returns `false`. Any class that can have a bias would override it and read its options to return the correct value.
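Below is a minimal sketch of this option. It uses hypothetical stand-ins, not Storch's real classes. `Module` gets a default `hasBias` of `false`. A `Linear`-like class overrides it from its own constructor flag, which stands in for the private `LinearOptions`. `FakeTensor` and the `initBias` helper are hypothetical too. They only show how an init function could branch on `hasBias`.

```scala
// Hypothetical minimal Module; Storch's real Module has far more members.
trait Module:
  // Default: no bias unless a subclass says otherwise.
  def hasBias: Boolean = false

// Hypothetical stand-in for a tensor: a mutable array of doubles.
final class FakeTensor(val data: Array[Double])

// Linear-like layer: the constructor flag plays the role of the private
// LinearOptions, and the override exposes it through hasBias.
final class Linear(in: Int, out: Int, withBias: Boolean = true) extends Module:
  val weight: FakeTensor = FakeTensor(Array.fill(in * out)(1.0))
  // Empty when the layer was built without a bias.
  val bias: FakeTensor =
    FakeTensor(if withBias then Array.fill(out)(1.0) else Array.empty[Double])
  override def hasBias: Boolean = withBias

// Embedding-like layer: never has a bias, so it keeps the default.
final class Embedding(num: Int, dim: Int) extends Module:
  val weight: FakeTensor = FakeTensor(Array.fill(num * dim)(1.0))

// Hypothetical init step: zero the bias only when the module reports one.
def initBias(m: Module): Unit = m match
  case l: Linear if l.hasBias => java.util.Arrays.fill(l.bias.data, 0.0)
  case _                      => () // no bias: leave the module unchanged
```

Modules that never have a bias need no extra code. Only classes where the bias is configurable have to override the default.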
Alternatively, one could add a `HasBias` trait with the `hasBias(): Boolean` method. In this case, overriding the method to return `true` may not be safe (does it depend on the order in which the class and traits are extended?).
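On the ordering question: Scala resolves overrides through class linearization, which is deterministic. Below is a minimal sketch with hypothetical names, assuming `HasBias` extends `Module`. In that case the trait's override comes before `Module`'s default whichever way the traits are listed. The real risk lies elsewhere: a class whose bias can be switched off still has to override `hasBias` itself. Its own definition comes first in its linearization, so it takes precedence over `HasBias`.

```scala
// Hypothetical Module with a default.
trait Module:
  def hasBias: Boolean = false

// Trait variant: mixing it in claims "this module has a bias".
trait HasBias extends Module:
  override def hasBias: Boolean = true

// HasBias precedes Module in the linearization of both classes,
// so its override wins regardless of the order in the extends clause.
class A extends Module with HasBias
class B extends HasBias with Module

// A class whose bias is optional overrides again; its own definition
// comes first in its linearization, so it takes precedence over HasBias.
class OptionalBias(enabled: Boolean) extends Module with HasBias:
  override def hasBias: Boolean = enabled
```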
Finally, we could try something fancier with type parameters so that the presence of a bias is known at compile time, but I am unsure about this.
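One way to make the presence of a bias visible at compile time is a phantom type parameter. Below is a minimal sketch in which all names are hypothetical and `FakeTensor` stands in for a real tensor. The `bias` accessor requires evidence that the flag is `WithBias`, so calling it on a layer built without a bias does not compile.

```scala
// Phantom types: never instantiated, used only as compile-time flags.
sealed trait BiasFlag
sealed trait WithBias extends BiasFlag
sealed trait NoBias extends BiasFlag

// Hypothetical stand-in for a tensor: a mutable array of doubles.
final class FakeTensor(val data: Array[Double])

// Hypothetical Linear whose type records whether it has a bias.
final class Linear[B <: BiasFlag] private (
    val weight: FakeTensor,
    biasOpt: Option[FakeTensor]
):
  // Requires evidence that B is WithBias, so the call does not compile
  // on a Linear[NoBias].
  def bias(using B =:= WithBias): FakeTensor = biasOpt.get

object Linear:
  // The choice that would normally be a constructor flag selects the type.
  def withBias(in: Int, out: Int): Linear[WithBias] =
    new Linear[WithBias](
      FakeTensor(Array.fill(in * out)(1.0)),
      Some(FakeTensor(Array.fill(out)(0.0)))
    )

  def noBias(in: Int, out: Int): Linear[NoBias] =
    new Linear[NoBias](FakeTensor(Array.fill(in * out)(1.0)), None)
```

The catch is that `B` is erased at runtime. Inside an `apply`-style callback that receives a plain module, a pattern such as `case l: Linear[WithBias]` cannot be checked; the compiler only emits an unchecked warning. A runtime `hasBias` would therefore probably still be needed there.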
Any suggestions on how I should proceed?
TIA
Sorry @hmf, I somehow missed this. I'd suggest we start with the simplest option: adding `hasBias(): Boolean` to `Module`.
Since enabling or disabling the bias is often a constructor parameter, I think it is harder to encode in the type than `HasWeight`. We can still improve on this later if it turns out to make sense.