Initial state cannot be taken from base layer
I've run into a bug: if I have a base layer state and want to use it to initialize an LSTM that is used within the recurrent beam search, the base layer state is not wrapped by an ExtendWithBeamLayer. Therefore, the batch dimension does not match, and the search aborts.
I've created a simple test case for this in my PR #189. But I'm not sure how to fix it.
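To make the mismatch concrete, here is a minimal plain-TensorFlow sketch (not RETURNN code; the shapes and variable names are made up for illustration). The base layer state keeps shape [batch, dim], while everything inside the beam search carries the beam in the batch axis, i.e. [batch * beam, dim]:

```python
import tensorflow as tf

batch, beam, dim = 2, 4, 8

base_state = tf.zeros([batch, dim])         # state coming from the base network
beam_state = tf.zeros([batch * beam, dim])  # what the rec loop expects during search

# Combining them directly fails because the batch axes differ (2 vs. 8),
# which is essentially the abort described above.
try:
    base_state + beam_state
except (tf.errors.InvalidArgumentError, ValueError) as exc:
    print("shape mismatch:", exc)
```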
As far as I can tell, this issue arises because the initial states of the tf.while_loop are collected before the _SubnetworkRecCell layer is constructed (which is where the ExtendWithBeamLayer would be added). But since tf.while_loop depends on the initial states, this cannot easily be swapped.
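(Just to illustrate the ordering with plain TensorFlow, independent of RETURNN internals: tf.while_loop receives its initial loop variables as an argument, so their shapes are fixed before the body is traced; nothing done inside the body can change the initial shape anymore.)

```python
import tensorflow as tf

batch, beam, dim, max_steps = 2, 4, 8, 5

# Initial state collected *outside* the loop. If it still has shape
# [batch, dim] instead of [batch * beam, dim], every step inherits the
# wrong batch dimension.
init_state = tf.zeros([batch, dim])

def cond(i, state):
    return i < max_steps

def body(i, state):
    # Layers built here (the analogue of the rec cell body) come too late
    # to fix the initial shape.
    return i + 1, state + 1.0

_, final_state = tf.while_loop(cond, body, loop_vars=(tf.constant(0), init_state))
print(final_state.shape)  # (2, 8) -- the beam never entered the batch axis
```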
Any ideas? If yes, I can implement that and complete the PR.
Without looking at the test or the code: At the time we construct the initial state, we already know the beam size, and we pass it to that function get_rec_initial_output. Maybe we don't correctly handle the case when we get a layer there.
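For illustration only (this is an assumption about what such handling could look like, not RETURNN's actual implementation): once the beam size is known, a state coming from a base-layer tensor of shape [batch, dim] could be tiled to [batch * beam, dim] at that point; tile_to_beam below is a hypothetical helper:

```python
import tensorflow as tf

def tile_to_beam(state, beam_size):
    """Repeat each batch entry beam_size times along axis 0:
    [batch, dim] -> [batch * beam_size, dim]."""
    batch = tf.shape(state)[0]
    tiled = tf.tile(state[:, None, :], [1, beam_size, 1])  # [batch, beam, dim]
    return tf.reshape(tiled, [batch * beam_size, tf.shape(state)[1]])

base_state = tf.random.normal([2, 8])
print(tile_to_beam(base_state, beam_size=4).shape)  # (8, 8)
```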
I'm at a loss here. Currently, the only place where an ExtendWithBeamLayer is created is from within the get_layer method of the _SubnetworkRecCell. But since this layer is only created inside the tf.while_loop call, this is too late for the collection of initial states. I'm not sure whether the _SubnetworkRecCell should be created earlier, or whether the creation of the ExtendWithBeamLayer should be moved somewhere else. And in any case, I wouldn't know how to do that.
Could you give me some more pointers?
The initial states/outputs already include the beam. That happens e.g. in get_rec_initial_output. There you already have the batch with beam size.
For reference, the failure can be seen here (but I did not really check the test case yet...).
I wonder whether this is still a bug or whether it already works now. Someone should check.