
GatherLayer on batch axis

Open · vieting opened this issue 3 years ago · 7 comments

This PR fixes #1087. Since I ran into the issue in the context of supervised multilingual training, I also added a more general test case for that setup, which does not necessarily need to go into the main branch. The fix is similar to how the size_placeholder is modified in ShiftAxisLayer.
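To illustrate why the sizes have to be adjusted together with the data, here is a minimal pure-Python sketch. It is not RETURNN's implementation; the function name gather_batch and the nested-list [B][T][F] layout are assumptions for illustration. Gathering on the batch axis selects whole sequences, so the per-sequence lengths (what size_placeholder holds for a dynamic axis) must be gathered with the same indices.

```python
from typing import List, Sequence, Tuple


def gather_batch(
    batch: Sequence[Sequence[Sequence[float]]],  # [B][T][F], padded to a common T
    seq_lens: Sequence[int],                     # [B], number of valid frames per sequence
    indices: Sequence[int],                      # [B'], batch entries to select (may repeat)
) -> Tuple[List[List[List[float]]], List[int]]:
    """Hypothetical helper: select batch entries by index, keeping lengths in sync."""
    for i in indices:
        if not 0 <= i < len(batch):
            raise IndexError(f"batch index {i} out of range for B={len(batch)}")
    new_batch = [[list(frame) for frame in batch[i]] for i in indices]
    # The lengths are gathered with the same indices; otherwise the selected
    # sequences would be paired with the lengths of other sequences.
    new_lens = [seq_lens[i] for i in indices]
    return new_batch, new_lens
```

Note that the output batch size is len(indices), which in general differs from the input batch size.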

vieting, Aug 04 '22 09:08

Hi @albertz, what do you think about the way the size placeholder and dim tag are modified in general? Right now there is a failing test case where we first do flatten_batch and then gather on the batch axis. I'm not sure what the desired behavior would be in this case. It would be nice to hear what you think here.

vieting, Aug 10 '22 12:08

When you modify the batch dim, you should also create a new BatchInfo object and assign it to the output.
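The reasoning behind a fresh batch object can be sketched with a tiny stand-in class. BatchDesc and gathered_batch_desc are hypothetical names for illustration only and do not reflect the actual BatchInfo API; the point is that the gathered output describes a different batch, so it gets its own object (with a link back to its origin) instead of reusing the input's.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass(frozen=True, eq=False)  # eq=False: compare by identity, like distinct batch objects
class BatchDesc:
    """Hypothetical stand-in for a batch description (NOT RETURNN's BatchInfo API)."""
    dim: int
    parent: Optional["BatchDesc"] = None  # the batch this one was derived from
    note: str = ""


def gathered_batch_desc(src: BatchDesc, indices: Sequence[int]) -> BatchDesc:
    """Describe the batch produced by gathering `indices` from `src`."""
    # A new object rather than `src` itself, so downstream code cannot mistake
    # the gathered batch for the original one, even if the sizes happen to match.
    return BatchDesc(dim=len(indices), parent=src, note=f"gather({list(indices)})")
```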

albertz, Aug 10 '22 12:08

> When you modify the batch dim, you should create a new BatchInfo object as well, and assign that to output.

As I said, the fix is similar to what is done in ShiftAxisLayer. Do you have another layer that modifies the batch axis and could serve as a good example?

vieting, Aug 10 '22 13:08

> > When you modify the batch dim, you should create a new BatchInfo object as well, and assign that to output.
>
> As I said, the fix is similar to what is done in the ShiftAxisLayer. Do you have another layer which modifies the batch axis and could serve as a good example?

ShiftAxisLayer does not modify the batch dim. You probably mean the size adaptation. That code in ShiftAxisLayer is a bit ugly/outdated/deprecated/hacky and might not work correctly in all cases (but it is simpler anyway, because the batch dim is not changed).
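For contrast, here is a sketch of the kind of size adaptation that suffices when only a time axis changes. It assumes, for illustration only, that a shift without padding removes |amount| frames from each sequence (clipped at zero); this is not a description of ShiftAxisLayer's exact code, and shifted_seq_lens is a hypothetical name. The number of sequences stays the same, so no new batch object is needed.

```python
from typing import List, Sequence


def shifted_seq_lens(seq_lens: Sequence[int], amount: int) -> List[int]:
    """Hypothetical: sequence lengths after an unpadded shift of the time axis.

    Assumes every sequence loses |amount| frames, clipped at zero. The batch
    dim (len(seq_lens)) is unchanged; only the per-sequence sizes are adapted.
    """
    return [max(n - abs(amount), 0) for n in seq_lens]
```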

albertz, Aug 10 '22 14:08

> Do you have another layer which modifies the batch axis and could serve as a good example?

Not many layers do that. Right now I can only think of FlattenBatchLayer.
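Conceptually, flattening merges the batch and time axes into a single new batch axis and drops the padding frames, so the new batch dim is the sum of the sequence lengths. The pure-Python sketch below shows this; the function name and nested-list layout are assumptions, and it uses batch-major frame order, whereas the actual FlattenBatchLayer may order frames differently (e.g. time-major). After such a flattening, each batch entry is a single frame, so gathering on the batch axis (as in the failing flatten_batch + gather test mentioned earlier in the thread) selects frames rather than whole sequences.

```python
from typing import List, Sequence


def flatten_batch(
    batch: Sequence[Sequence[Sequence[float]]],  # [B][T][F], padded to a common T
    seq_lens: Sequence[int],                     # [B], number of valid frames per sequence
) -> List[List[float]]:
    """Hypothetical helper: merge batch and time into one axis of shape [B'][F].

    B' = sum(seq_lens): the batch dim itself changes. Frames are emitted in
    batch-major order (all frames of sequence 0, then sequence 1, ...).
    """
    out: List[List[float]] = []
    for seq, n in zip(batch, seq_lens):
        out.extend(list(frame) for frame in seq[:n])  # keep only the valid frames
    return out
```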

albertz, Aug 10 '22 14:08

As discussed offline, the desired result for my use case can be obtained with MaskedComputationLayer. Instead of the indices to gather, it needs a boolean mask over the batch axis. In my use case I have this mask anyway and only computed the indices from it. The mask can be used like this:

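```python
network = {
    "encoder": {...},  # B, T, F
    "boolean_mask": {...},  # B, True for batch entries to keep
    "encoder_masked": {
        "class": "masked_computation",
        "mask": "boolean_mask",
        "unit": {"class": "copy", "from": "encoder"},
    },  # B', T, F with B' = number of True entries in boolean_mask
    # ... (further layers)
}
```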
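Why the mask can replace the gather indices here: masking the batch axis keeps the same entries, in the same order, as gathering with the ascending positions of the True values. A plain-Python sketch (mask_to_indices and apply_batch_mask are hypothetical names, not RETURNN API). The converse does not hold in general: gather indices may repeat or reorder batch entries, which a boolean mask cannot express.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def mask_to_indices(mask: Sequence[bool]) -> List[int]:
    """Positions where the boolean batch mask is True, in ascending order."""
    return [i for i, keep in enumerate(mask) if keep]


def apply_batch_mask(items: Sequence[T], mask: Sequence[bool]) -> List[T]:
    """Keep only the batch entries whose mask value is True, preserving order."""
    if len(items) != len(mask):
        raise ValueError("mask must have exactly one entry per batch element")
    return [x for x, keep in zip(items, mask) if keep]
```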

Since that does exactly what I need, I'll close this PR and the corresponding issue.

vieting, Aug 30 '22 07:08

Well, using GatherLayer on the batch axis may still sometimes be a valid thing to do. I would leave this PR open.
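For comparison with the masked_computation example earlier in the thread, a config fragment sketching GatherLayer on the batch axis. It assumes the fix from this PR is applied; position and axis are GatherLayer's parameters, while the layer names and the int index layer are hypothetical placeholders. Unlike the mask, the indices can reorder or repeat batch entries.

```python
network = {
    "encoder": {...},  # B, T, F (placeholder)
    "batch_indices": {...},  # B', int batch indices to gather (hypothetical placeholder)
    "encoder_gathered": {
        "class": "gather",
        "from": "encoder",
        "position": "batch_indices",  # which batch entries to select
        "axis": "B",  # gather on the batch axis
    },  # B', T, F
}
```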

albertz, Aug 30 '22 07:08