einx

Common neural network operations

Open erlebach opened this issue 10 months ago • 1 comment

I am trying out the provided examples, particularly the ones for neural networks, and the following one does not seem to work:

x = einx.multiply("... [c]", x, einn.param(init=1e-5))

The call to einn.param fails because a required argument is missing. I could not find a similar standalone example, and even with the help of AI I cannot get the param() example to work. There is also no example in the tests/ folder that works with Torch. Could somebody please provide a working example to start from? Thanks.

erlebach • Feb 08 '25 16:02

The arguments that einn.param requires depend on the backend you are using.

In Torch, model parameters must always be defined in the module's constructor (as nn.parameter.UninitializedParameter when using einn) and then passed to einn.param in the forward method (see this). einx then initializes the parameter with the appropriate shape:

import torch
import torch.nn as nn
import einx
import einx.nn.torch as einn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.w = nn.parameter.UninitializedParameter(dtype=torch.float32)

    def forward(self, x):
        # initializes self.w with shape (c,) based on x.shape
        x = einx.multiply("... [c]", x, einn.param(self.w, init=1e-5))
        return x

Other backends, such as Haiku, allow parameters to be defined directly in the forward/__call__ method instead of the constructor. In that case, einn.param creates the parameter internally and associates it with the current module:

import haiku as hk
import einx
import einx.nn.haiku as einn

class MyModule(hk.Module):
    def __call__(self, x):
        # creates new parameter with shape (c,) based on x.shape, and stores it in self
        x = einx.multiply("... [c]", x, einn.param(init=1e-5))
        return x

fferflo • Feb 13 '25 07:02