Question regarding composable `LSTMCells`

Open aicaffeinelife opened this issue 6 months ago • 4 comments

Hi Patrick!

Thanks for the great work on the library. I'm just getting started with JAX after years of PyTorch, and I have a question about composing LSTMCells in Equinox. In PyTorch, I'm used to doing this:

...
# model definition
def forward(self, x, state):
    # nn.LSTMCell takes the (h, c) pair as a single tuple argument
    h1, c1 = self.lstm1(x, (state[0][0], state[0][1]))
    h2, c2 = self.lstm2(h1, (state[1][0], state[1][1]))
    return self.linear(h2)

I tried to write this in Equinox and got stuck implementing the step function passed to jax.lax.scan. In this tutorial, the output is computed as follows:

def __call__(self, input, state):
    def f(carry, input):
        return self.cell(input, carry), None
    out, _ = jax.lax.scan(f, state, input)

What confuses me is whether my definition would need something like this:

def __call__(self, input, state):
    def f1(carry, input):
        return self.cell1(input, carry), None
    def f2(carry, input):
        return self.cell2(input, carry), None
    out, _ = jax.lax.scan(f1, state, input)
    out2, _ = jax.lax.scan(f2, state, out)

Or is there a better way to go about this?

aicaffeinelife, Dec 11 '23 18:12