Initial state cannot be taken from base layer
I've run into a bug: if I have a base layer state and want to use it to initialize an LSTM that is used within the recurrent beam search, the base layer state is not wrapped by an ExtendWithBeamLayer. Therefore, the batch dimension does not match, and the search aborts.
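Roughly, the setup I mean looks like this (a minimal hypothetical sketch, not the actual test case from the PR; layer names and exact options are illustrative and may differ):

```python
network = {
    # Some layer in the base network whose output should become the LSTM state.
    "enc_state": {"class": "linear", "activation": "tanh", "n_out": 10, "from": "data"},
    "output": {
        "class": "rec", "from": [], "target": "classes",
        "unit": {
            "embed": {"class": "linear", "activation": None, "n_out": 10,
                      "from": "prev:output"},
            # LSTM inside the rec loop, initialized from the base-network layer.
            # During beam search, "base:enc_state" keeps batch dim = batch,
            # while everything inside the loop runs with batch * beam.
            "s": {"class": "rec", "unit": "lstm", "n_out": 10, "from": "embed",
                  "initial_state": "base:enc_state"},
            "prob": {"class": "softmax", "from": "s", "target": "classes"},
            "output": {"class": "choice", "from": "prob",
                       "beam_size": 4, "target": "classes"},
        },
    },
}
```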
I've created a simple test case for this in my PR #189. But I'm not sure how to fix it.
As far as I can tell, this issue arises because the initial states of the tf.while_loop are collected before the _SubnetworkRecCell is constructed (which is where the ExtendWithBeamLayer would be added).
But since tf.while_loop depends on the initial states, this cannot easily be swapped.
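To illustrate why the ordering matters, in plain TensorFlow (not RETURNN code): tf.while_loop fixes the shapes of its loop variables from their initial values, so a state that only gets the beam dimension inside the loop body no longer matches:

```python
import tensorflow as tf

batch, beam, dim = 2, 3, 5

@tf.function  # graph mode, similar to how RETURNN builds its computation graph
def loop_with_late_beam():
    init_state = tf.zeros([batch, dim])  # base-layer state, no beam dim yet

    def cond(i, state):
        return i < 1

    def body(i, state):
        # Inside the loop, beam search effectively works on batch * beam entries.
        return i + 1, tf.tile(state, [beam, 1])

    return tf.while_loop(cond, body, [tf.constant(0), init_state])

# Raises a ValueError: the state's shape changes from [batch, dim] to
# [batch * beam, dim] inside the loop, which tf.while_loop rejects unless
# shape_invariants are given. So the initial state must already carry the
# beam before the loop is built.
loop_with_late_beam()
```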
Any ideas? If yes, I can implement that and complete the PR.
Without looking at the test or the code: At the time we construct the initial state, we already know the beam size, and we pass it to that function get_rec_initial_output. Maybe we don't correctly handle the case when we get a layer there.
I'm at a loss here. Currently, the only place where an ExtendWithBeamLayer is created is from within the get_layer method of the _SubnetworkRecCell. But since this layer is only created inside the tf.while_loop call, this is too late for the collection of initial states. I'm not sure whether the _SubnetworkRecCell should be created earlier, or whether the creation of the ExtendWithBeamLayer should be moved somewhere else. And in any case, I wouldn't know how to do that.
Could you give me some more pointers?
The initial states/outputs already include the beam. That happens e.g. in get_rec_initial_output. There you already have the batch dim with the beam size.
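Conceptually, extending a base-layer state with the beam just means tiling each batch entry beam_size times before it enters the loop, roughly like this (plain TF sketch, not the actual RETURNN implementation):

```python
import tensorflow as tf

def expand_state_with_beam(state, beam_size):
    """Repeat each batch entry beam_size times: [batch, dim] -> [batch * beam, dim].

    Only a sketch of what wrapping the base-layer state with the beam would mean;
    RETURNN's actual handling lives in ExtendWithBeamLayer / get_rec_initial_output.
    """
    state = tf.expand_dims(state, axis=1)            # [batch, 1, dim]
    state = tf.tile(state, [1, beam_size, 1])        # [batch, beam, dim]
    return tf.reshape(state, [-1, state.shape[-1]])  # [batch * beam, dim]

base_state = tf.random.normal([2, 5])
print(expand_state_with_beam(base_state, beam_size=3).shape)  # (6, 5)
```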
For reference, the failure can be seen here (but I did not really check the test case yet...).
I wonder whether this is still a bug or whether it already works now. Someone should check.