
Make redundant `features` argument optional for recurrent cells

Open carlosgmartin opened this issue 11 months ago • 4 comments

For recurrent cells such as `LSTMCell`, `GRUCell`, and `OptimizedLSTMCell`, the `features` argument of the constructor is redundant: it can be inferred from the carry input to the cell's `__call__` method. (The only cell that currently uses `self.features` in its `__call__` method is `ConvLSTMCell`, which ought to be modified to infer it from its carry input.)

For each cell, the only place where `self.features` is needed is in the `initialize_carry` method. But in many models, the initial carry comes from "upstream" in the model, so this method is never used.
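To make the inference concrete, here is a minimal sketch (using NumPy arrays in place of JAX arrays; the helper name is invented for illustration and is not part of Flax):

```python
import numpy as np

def infer_features_from_carry(carry):
    """Infer the feature count from an LSTM-style carry (c, h).

    `features` is just the trailing dimension of the carry arrays,
    so no constructor argument is needed inside `__call__`.
    """
    c, h = carry
    assert c.shape == h.shape, "carry halves must have matching shapes"
    return c.shape[-1]

# A carry for batch size 8 and 32 features:
carry = (np.zeros((8, 32)), np.zeros((8, 32)))
print(infer_features_from_carry(carry))  # 32
```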

Proposal:

  1. Edit `ConvLSTMCell` to infer `features` in its `__call__` method from its carry input.

  2. Set `features=None` by default in each cell's constructor.

  3. Add the following line to each `initialize_carry` method:

     assert self.features is not None, "features cannot be None when calling initialize_carry"
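Taken together, the three steps might look like the following sketch (a plain-Python stand-in for a `flax.linen` cell; the class name and structure are assumptions, not the actual Flax implementation):

```python
from dataclasses import dataclass
from typing import Optional

import numpy as np


@dataclass
class SketchLSTMCell:
    features: Optional[int] = None  # step 2: features defaults to None

    def initialize_carry(self, batch_size: int):
        # step 3: features is only required here
        assert self.features is not None, \
            "features cannot be None when calling initialize_carry"
        shape = (batch_size, self.features)
        return np.zeros(shape), np.zeros(shape)

    def __call__(self, carry, inputs):
        c, h = carry
        features = c.shape[-1]  # step 1: infer features from the carry
        # ... the real gate computations would use `features` here;
        # this stub just passes the carry through unchanged.
        return (c, h), h
```

With this shape, a cell constructed as `SketchLSTMCell()` works whenever the carry comes from upstream, while `SketchLSTMCell().initialize_carry(8)` fails fast with the assertion.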

I can submit a PR for this, if desired.

An alternative would be to pass `features` directly to the `initialize_carry` method.
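Under that alternative, the signature change might look roughly like this (again a plain-Python sketch, not Flax's actual API):

```python
import numpy as np

class SketchCell:
    """A cell whose constructor takes no `features` argument at all."""

    def initialize_carry(self, batch_size, features):
        # features is supplied at carry-initialization time instead
        shape = (batch_size, features)
        return np.zeros(shape), np.zeros(shape)

carry = SketchCell().initialize_carry(8, 32)
print(carry[0].shape)  # (8, 32)
```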

carlosgmartin avatar Feb 25 '24 20:02 carlosgmartin

This feature is needed for `RNN`, I think; we added it in just for this 😅 Also, it feels more natural to specify hparams explicitly in the constructor.

cgarciae avatar Mar 06 '24 18:03 cgarciae

@cgarciae Isn't shape inference from inputs, as is done for the `inputs` argument, more in line with Flax's init philosophy?

The RNN situation could be resolved as follows:

  1. Add a `features` argument to the cell's `initialize_carry` method.
  2. Add a `features` argument to `RNN`'s constructor, and pass it along to the `self.cell.initialize_carry` call.

That seems more natural and elegant to me, since the number of features may ultimately be determined by stuff upstream in the model (as opposed to being intrinsic to the cell itself).
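A minimal sketch of that two-step change (plain Python standing in for Flax's `RNN`; all names here are illustrative assumptions):

```python
import numpy as np

class SketchCell:
    # step 1: initialize_carry takes features as an argument
    def initialize_carry(self, batch_size, features):
        shape = (batch_size, features)
        return np.zeros(shape), np.zeros(shape)

class SketchRNN:
    # step 2: RNN's constructor accepts features ...
    def __init__(self, cell, features):
        self.cell = cell
        self.features = features

    def init_carry(self, batch_size):
        # ... and forwards it to the cell's initialize_carry call
        return self.cell.initialize_carry(batch_size, self.features)

rnn = SketchRNN(SketchCell(), features=16)
print(rnn.init_carry(4)[0].shape)  # (4, 16)
```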

carlosgmartin avatar Mar 08 '24 01:03 carlosgmartin

What you are describing is how the Flax recurrent API worked before. However, it was a bit inconsistent: some classes like `ConvLSTM` required passing the output `features` while others did not, and it also lacked some of the structure needed to implement the `RNN` class in simple terms. The solution was to add `features` to all RNN layers and slightly simplify `initialize_carry`.

cgarciae avatar Apr 18 '24 13:04 cgarciae

@cgarciae Hmm, is there any reason `ConvLSTM` can't infer `features` from its inputs, like the other recurrent modules? I submitted a PR to address that.

carlosgmartin avatar Apr 18 '24 16:04 carlosgmartin