equinox
Question regarding composable `LSTMCells`
Hi Patrick!
Thanks for the great work on the library. I'm just getting started with JAX after years of PyTorch, and I have a question about composing `LSTMCell`s in Equinox. In PyTorch, I'm used to doing this:
```python
...
# model definition
def forward(self, x, state):
    # nn.LSTMCell takes the (h, c) state as a single tuple
    h1, c1 = self.lstm1(x, (state[0][0], state[0][1]))
    h2, c2 = self.lstm2(h1, (state[1][0], state[1][1]))
    return self.linear(h2)
```
I tried to write this in Equinox and got stuck implementing the function passed to `jax.lax.scan`. In this tutorial, the output is computed like this:
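```python
def __call__(self, input, state):
    def f(carry, x):
        # the cell maps (input, carry) -> new carry
        return self.cell(x, carry), None
    # scan over the leading (time) axis of `input`, starting from `state`
    out, _ = jax.lax.scan(f, state, input)
    return out
```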
My confusion is whether, for my two-layer model, I would need to do something like this:
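```python
def __call__(self, input, state):
    # Layer 1: emit each step's h so layer 2 can scan over the full sequence
    def f1(carry, x):
        carry = self.cell1(x, carry)
        return carry, carry[0]
    # Layer 2: consumes layer 1's hidden states
    def f2(carry, x):
        carry = self.cell2(x, carry)
        return carry, carry[0]
    _, out = jax.lax.scan(f1, state[0], input)
    _, out2 = jax.lax.scan(f2, state[1], out)
    return out2
```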
Or is there a better way to go about this?