returnn
GatherLayer on batch axis
This PR fixes #1087. Since I ran into the issue in the context of supervised multilingual training, I also added a more general test case for that scenario, which does not necessarily need to go into the main branch. The fix is similar to how the size_placeholder is modified in the ShiftAxisLayer.
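To make the size issue concrete, here is a minimal plain-Python sketch (not RETURNN code; the helper name `gather_on_batch_axis` and the toy data are purely illustrative assumptions). It shows why gathering on the batch axis also has to gather the per-sequence lengths: otherwise the lengths would still refer to the old batch entries.

```python
def gather_on_batch_axis(data, seq_lens, indices):
    """Select batch entries `indices` from `data` ([B][T][F]) and `seq_lens` ([B]).

    Hypothetical helper for illustration only; this is not the RETURNN API.
    """
    gathered_data = [data[i] for i in indices]
    # The sequence lengths must follow the same indices as the data,
    # otherwise they describe the wrong batch entries.
    gathered_lens = [seq_lens[i] for i in indices]
    return gathered_data, gathered_lens


# Toy batch: B=3 sequences, padded to T=2, feature dim F=1.
data = [[[1.0], [2.0]], [[3.0], [0.0]], [[4.0], [5.0]]]
seq_lens = [2, 1, 2]

out, out_lens = gather_on_batch_axis(data, seq_lens, indices=[2, 1])
```

With the indices `[2, 1]`, the new batch has size 2 and its lengths are taken from the original entries 2 and 1, in that order.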
Hi @albertz, what do you think about the way the size placeholder and dim tag are modified in general? Right now there is a failing test case where we first do flatten_batch and then gather on the batch axis. I'm not sure what the desired behavior should be in this case. It would be nice if you could comment on this.
When you modify the batch dim, you should create a new BatchInfo object as well, and assign that to output.
As I said, the fix is similar to what is done in the ShiftAxisLayer. Do you have another layer which modifies the batch axis and could serve as a good example?
ShiftAxisLayer does not modify the batch dim. You probably mean the size adaptation. That code is a bit ugly/outdated/deprecated/hacky in ShiftAxisLayer, and might not work correctly in all cases (but anyway it's simpler because the batch dim is not changed).
Do you have another layer which modifies the batch axis and could serve as a good example?
Not many layers do that. I just recall FlattenBatchLayer right now.
As discussed offline, it is possible to get the desired results in my use case using the MaskedComputationLayer. Instead of the indices to gather, we need a boolean mask over the batch axis. In my use case, I have this anyway and only computed the indices from the mask. We can use the mask like this:
network = {
    "encoder": {...},  # B, T, F
    "boolean_mask": {...},  # B
    "encoder_masked": {
        "class": "masked_computation",
        "mask": "boolean_mask",
        "unit": {"class": "copy", "from": "encoder"}
    },  # B', T, F
    ...
}
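As a rough illustration of why the mask can stand in for the indices here, the following plain-Python sketch (not RETURNN code; the helper names are hypothetical) compares selecting batch entries by a boolean mask with gathering by the indices derived from that mask.

```python
def indices_from_mask(mask):
    """Return the positions where the boolean mask is True, in ascending order."""
    return [i for i, keep in enumerate(mask) if keep]


def select_by_mask(data, mask):
    """Keep only the batch entries whose mask value is True."""
    return [entry for entry, keep in zip(data, mask) if keep]


# Toy stand-ins for the "encoder" output ([B][T][F]) and "boolean_mask" ([B]).
encoder = [[[1.0], [2.0]], [[3.0], [0.0]], [[4.0], [5.0]]]
boolean_mask = [True, False, True]

indices = indices_from_mask(boolean_mask)
via_gather = [encoder[i] for i in indices]         # gather on the batch axis
via_mask = select_by_mask(encoder, boolean_mask)   # mask on the batch axis
```

Both paths give the same reduced batch of size B'. Note that a mask can only keep entries in their original order, whereas a gather with arbitrary indices can also reorder or repeat batch entries, so the two are interchangeable only when the indices come from a mask.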
Since that does exactly what I need, I'll close this PR and the corresponding issue.
Well, GatherLayer on the batch axis may still sometimes be a valid thing someone wants to do. I would leave this PR open.