Make redundant `features` argument optional for recurrent cells
For recurrent cells such as `LSTMCell` and `GRUCell`, the `features` argument of the constructor is redundant: it can be inferred from the carry input to the `__call__` method. (The only cell that currently uses `self.features` in its `__call__` method is `ConvLSTMCell`, which ought to be modified to infer it from its carry input.) For each cell, the only place where `self.features` is actually needed is the `initialize_carry` method. But in many models the initial carry comes from "upstream" in the model, so this method is never used.
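
To make the redundancy concrete, here is a minimal sketch (not Flax's actual cell code) of a hypothetical `ToyRNNCell` whose `__call__` reads the hidden size off the carry rather than off `self.features`:

```python
import flax.linen as nn
import jax.numpy as jnp


class ToyRNNCell(nn.Module):
    """Hypothetical minimal cell; illustrative only, not Flax's implementation."""
    features: int  # only ever needed to build the initial carry

    @nn.compact
    def __call__(self, carry, inputs):
        # The hidden size is already encoded in the carry, so `self.features`
        # is not needed here.
        hidden_features = carry.shape[-1]
        new_carry = jnp.tanh(
            nn.Dense(hidden_features, name='i')(inputs)
            + nn.Dense(hidden_features, name='h')(carry)
        )
        return new_carry, new_carry
```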
Proposal:
- Edit `ConvLSTMCell` to infer `features` in its `__call__` method from its `carry` input.
- Set `features=None` by default in each cell's constructor.
- Add the following line to each `initialize_carry` method:

  assert self.features is not None, "features cannot be None when calling initialize_carry"
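
A sketch of what the proposal would look like for the hypothetical `ToyRNNCell` above (assuming a zero-initialized carry, which is not necessarily how Flax builds its carries):

```python
from typing import Optional

import flax.linen as nn
import jax.numpy as jnp


class ToyRNNCell(nn.Module):
    """Hypothetical cell with the proposed optional `features`."""
    features: Optional[int] = None  # optional: only consulted by initialize_carry

    @nn.compact
    def __call__(self, carry, inputs):
        hidden_features = carry.shape[-1]  # inferred from the carry, not self.features
        new_carry = jnp.tanh(
            nn.Dense(hidden_features, name='i')(inputs)
            + nn.Dense(hidden_features, name='h')(carry)
        )
        return new_carry, new_carry

    def initialize_carry(self, rng, input_shape):
        assert self.features is not None, (
            "features cannot be None when calling initialize_carry")
        # rng is unused for a zero carry; kept only for signature compatibility.
        batch_dims = input_shape[:-1]
        return jnp.zeros(batch_dims + (self.features,))
```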
I can submit a PR for this, if desired.
An alternative would be to pass `features` directly to the `initialize_carry` method.
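
Under that alternative, the `initialize_carry` method of the `ToyRNNCell` sketch above might instead look like this (hypothetical signature, just to illustrate the shape of the API):

```python
    def initialize_carry(self, rng, input_shape, features):
        # `features` is supplied by the caller rather than stored on the cell.
        batch_dims = input_shape[:-1]
        return jnp.zeros(batch_dims + (features,))
```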
The `features` attribute is needed for `RNN`, I think; I think we added it just for this 😅.
Also, it feels more natural to specify hparams explicitly in the constructor.
@cgarciae Isn't shape inference from inputs, as done for the `inputs` argument, more in line with flax's `init` philosophy?
The `RNN` situation could be resolved as follows:

- Add a `features` argument to the cell's `initialize_carry` method.
- Add a `features` argument to `RNN`'s constructor, and on this line, pass it to the `self.cell.initialize_carry` call.

That seems more natural and elegant to me, since the number of features may ultimately be determined by stuff upstream in the model (as opposed to being intrinsic to the cell itself).
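
A rough sketch of that shape of API, using a hypothetical `RNNWrapper` stand-in for `flax.linen.RNN` (with a plain Python loop instead of `nn.scan` and a fixed PRNG key, purely for brevity), calling the hypothetical `initialize_carry(rng, input_shape, features)` signature from the bullets above:

```python
from typing import Optional

import jax
import jax.numpy as jnp
import flax.linen as nn


class RNNWrapper(nn.Module):
    """Hypothetical stand-in for flax.linen.RNN; illustrative only."""
    cell: nn.Module
    features: Optional[int] = None

    @nn.compact
    def __call__(self, inputs):
        # inputs: (batch, time, in_features)
        # Forward `features` to the cell's initialize_carry instead of
        # storing it on the cell itself.
        carry = self.cell.initialize_carry(
            jax.random.PRNGKey(0), inputs[:, 0].shape, self.features)
        outputs = []
        for t in range(inputs.shape[1]):  # unrolled for clarity; real code would use nn.scan
            carry, y = self.cell(carry, inputs[:, t])
            outputs.append(y)
        return jnp.stack(outputs, axis=1)
```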
What you are describing is how the Flax recurrent API was before. However, it was a bit inconsistent (e.g. some classes like `ConvLSTM` required passing the output features while others did not), and it also lacked some of the structure needed to implement the `RNN` class in simple terms. The solution was to add `features` to all RNN layers and slightly simplify `initialize_carry`.
@cgarciae Hmm, is there any reason `ConvLSTM` can't infer `features` from its inputs, like the other recurrent modules? I submitted a PR to address that here.