Common neural network operations
I am trying out the provided examples, especially those for neural networks. The following example does not seem to work:
x = einx.multiply("... [c]", x, einn.param(init=1e-5))
Calling einn.param this way fails with an error about a missing required argument. I could not find a similar standalone example, and even with help from AI tools I cannot get the param() example to work. There is also no example in the tests/ folder that works with torch. Could somebody please provide a working example to start from? Thanks.
einn.param requires different arguments depending on the backend you are using.
In PyTorch, model parameters always have to be defined in the constructor of a module (as nn.parameter.UninitializedParameter when using einn) and then passed to einn.param in the forward method (see this). einx then initializes the parameter with the appropriate shape:
import torch
import torch.nn as nn
import einx
import einx.nn.torch as einn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        # The shape is not known yet; einx fills it in on the first forward pass
        self.w = nn.parameter.UninitializedParameter(dtype=torch.float32)

    def forward(self, x):
        # initializes self.w with shape (c,) based on x.shape
        x = einx.multiply("... [c]", x, einn.param(self.w, init=1e-5))
        return x
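For reference, the lazy-initialization pattern that this einn.param call relies on can be approximated in plain PyTorch. The sketch below is not einx's implementation: the class name LayerScale and the init_value argument are hypothetical, and it assumes that a float init such as 1e-5 means "fill the parameter with this constant". The UninitializedParameter is materialized on the first forward call with shape (c,), taken from the last axis of x.

```python
import torch
import torch.nn as nn


class LayerScale(nn.Module):
    """Hypothetical plain-PyTorch stand-in for
    einx.multiply("... [c]", x, einn.param(self.w, init=1e-5))."""

    def __init__(self, init_value: float = 1e-5):
        super().__init__()
        self.init_value = init_value
        # Registered as a parameter of the module, but without a shape yet
        self.w = nn.parameter.UninitializedParameter(dtype=torch.float32)

    def forward(self, x):
        if isinstance(self.w, nn.parameter.UninitializedParameter):
            with torch.no_grad():
                # Give w the shape (c,) from the last axis of x, then fill it
                self.w.materialize((x.shape[-1],), device=x.device)
                nn.init.constant_(self.w, self.init_value)
        # Broadcasts w over the leading "..." axes, like "... [c]"
        return x * self.w


model = LayerScale()
x = torch.ones(2, 3, 8)
y = model(x)  # first call turns model.w into a regular Parameter of shape (8,)
```

Because the parameter only gets its shape on the first call, it is generally safest to run one forward pass on a dummy batch before creating the optimizer or loading a checkpoint. The same applies to the einx-based MyModule above.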
Other backends, such as Haiku, allow parameters to be defined directly in the forward/__call__ method instead of the constructor. In this case, einn.param creates the parameter internally and associates it with the current module:
import haiku as hk
import einx
import einx.nn.haiku as einn

class MyModule(hk.Module):
    def __call__(self, x):
        # creates a new parameter with shape (c,) based on x.shape and stores it in this module
        x = einx.multiply("... [c]", x, einn.param(init=1e-5))
        return x