Make redundant `features` argument optional for recurrent cells
For recurrent cells such as the following:
the `features` argument of the constructor is redundant: it can be inferred from the `carry` input to its `__call__` method. (The only cell that currently uses `self.features` in its `__call__` method is `ConvLSTMCell`, which ought to be modified to infer it from its `carry` input.)
For each cell, the only place where `self.features` is needed is in the `initialize_carry` method. But in many models, the initial carry comes from "upstream" in the model, so this method is never used.
Proposal:
- Edit `ConvLSTMCell` to infer `features` in its `__call__` method from its `carry` input.
- Set `features=None` by default in each cell's constructor.
- Add the following line to each `initialize_carry` method:
assert self.features is not None, "features cannot be None when calling initialize_carry"
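To make the proposal concrete, here is a minimal, framework-free sketch of the pattern. `ToyRecurrentCell` is a hypothetical stand-in written in plain Python, not a Flax class, and its lazily created weights only imitate shape inference; the point is that `__call__` reads the feature count from `carry`, while `initialize_carry` is the only method that needs `self.features`.

```python
import math
import random


class ToyRecurrentCell:
    """Hypothetical plain-Python stand-in for a recurrent cell (not a Flax class)."""

    def __init__(self, features=None, seed=0):
        # `features` is optional: only initialize_carry needs it.
        self.features = features
        self._rng = random.Random(seed)
        self._w_in = None   # input-to-hidden weights, created on first call
        self._w_rec = None  # hidden-to-hidden weights, created on first call

    def initialize_carry(self, batch_size):
        assert self.features is not None, (
            "features cannot be None when calling initialize_carry"
        )
        return [[0.0] * self.features for _ in range(batch_size)]

    def _random_matrix(self, rows, cols):
        return [[self._rng.uniform(-0.1, 0.1) for _ in range(cols)]
                for _ in range(rows)]

    def __call__(self, carry, inputs):
        # Infer both sizes from the arguments instead of from self.features.
        features = len(carry[0])
        in_dim = len(inputs[0])
        if self._w_in is None:
            self._w_in = self._random_matrix(in_dim, features)
            self._w_rec = self._random_matrix(features, features)
        new_carry = []
        for h, x in zip(carry, inputs):
            row = []
            for j in range(features):
                total = sum(x[i] * self._w_in[i][j] for i in range(in_dim))
                total += sum(h[k] * self._w_rec[k][j] for k in range(features))
                row.append(math.tanh(total))
            new_carry.append(row)
        return new_carry, new_carry


# The carry arrives from "upstream", so the cell is built without `features`.
cell = ToyRecurrentCell()
upstream_carry = [[0.5, -0.5, 0.1]]  # batch of 1, 3 features
carry, out = cell(upstream_carry, [[1.0, 2.0]])
```

Calling `cell.initialize_carry(1)` on this instance would trip the assertion, whereas `ToyRecurrentCell(features=4).initialize_carry(2)` returns a zero carry as usual.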
I can submit a PR for this, if desired.
An alternative would be to pass features directly to the initialize_carry method.
I think the `features` argument is needed for `RNN`; we added it to the cells just for that 😅
Also, it feels more natural to specify hparams explicitly in the constructor.
@cgarciae Isn't shape inference from inputs, as done for the inputs argument, more in line with flax's init philosophy?
The RNN situation could be resolved as follows:
- Add a `features` argument to the cell's `initialize_carry` method.
- Add a `features` argument to `RNN`'s constructor, and on this line, pass it to the `self.cell.initialize_carry` call.
That seems more natural and elegant to me, since the number of features may ultimately be determined by stuff upstream in the model (as opposed to being intrinsic to the cell itself).
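A rough sketch of that arrangement, in plain Python rather than Flax: `ToyCell` and `ToyRNN` are hypothetical names, and the `initialize_carry(batch_size, features)` signature is an assumption made for illustration, not the signature Flax uses. The cell holds no feature count; the wrapper owns it and passes it along only when no initial carry is supplied.

```python
class ToyCell:
    """Hypothetical cell that stores no feature count of its own."""

    def initialize_carry(self, batch_size, features):
        # `features` is supplied by the caller, not by the constructor.
        return [[0.0] * features for _ in range(batch_size)]

    def __call__(self, carry, inputs):
        # Toy update: add the mean of each input row to every carry entry.
        # The carry keeps whatever size it arrived with.
        new_carry = [[h + sum(x) / len(x) for h in row]
                     for row, x in zip(carry, inputs)]
        return new_carry, new_carry


class ToyRNN:
    """Hypothetical wrapper that owns `features` and forwards it to the cell."""

    def __init__(self, cell, features):
        self.cell = cell
        self.features = features

    def __call__(self, sequence, initial_carry=None):
        # `sequence` is a list of time steps; each step is a batch of input rows.
        batch_size = len(sequence[0])
        carry = initial_carry
        if carry is None:
            carry = self.cell.initialize_carry(batch_size, self.features)
        outputs = []
        for inputs in sequence:
            carry, out = self.cell(carry, inputs)
            outputs.append(out)
        return carry, outputs


rnn = ToyRNN(ToyCell(), features=3)
sequence = [[[1.0, 3.0]], [[2.0, 2.0]]]  # 2 time steps, batch of 1, 2 inputs each
final_carry, outputs = rnn(sequence)
```

When the caller passes `initial_carry` explicitly, `features` is never consulted, which matches the case where the carry comes from upstream.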
What you are describing is how the Flax recurrent API worked before. However, it was a bit inconsistent: some classes, like ConvLSTM, required passing the output features while others did not. It also lacked some of the structure needed to implement the RNN class in simple terms. The solution was to add features to all RNN layers and slightly simplify initialize_carry.
@cgarciae Hmm, is there any reason ConvLSTM can't infer features from its inputs, like the other recurrent modules? I submitted a PR to address that here.