
Batch dim tag special handling can be problematic

Open albertz opened this issue 3 years ago • 4 comments

For the contrastive loss implementation (#918), we flatten the masked encoder frames via FlattenBatchLayer and end up with a B&Packed{'input_masked_frames:masked:time'} batch dim. For each of those frames, we want to create a fixed number K=10 of candidate samples. So the natural way would be to use RandIntLayer and specify shape=[packed_batch_dim, samples_dim] with samples_dim = SpatialDim(..., K). However, that does not work, because packed_batch_dim is never treated as different from the normal batch dim. This follows from the current definition of dim tag equality (#634).
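To make the problem concrete, here is a toy Python model of the equality rule as described: any two batch dims compare equal, while all other dim tags compare by identity. ToyDim is a hypothetical class for illustration only, not RETURNN's Dim implementation, and the real rules from #634 cover more cases.

```python
class ToyDim:
    """Hypothetical, simplified dim tag (not RETURNN's actual Dim class)."""

    def __init__(self, description, is_batch=False, size=None):
        self.description = description
        self.is_batch = is_batch
        self.size = size

    def __eq__(self, other):
        if not isinstance(other, ToyDim):
            return NotImplemented
        # Simplified batch exception: any two batch dims compare equal,
        # regardless of packing or anything else they contain.
        if self.is_batch and other.is_batch:
            return True
        # All other dim tags are only equal to themselves.
        return self is other

    def __hash__(self):
        # Equal objects must hash equally, so all batch dims share one hash.
        return hash("batch") if self.is_batch else id(self)

    def __repr__(self):
        return f"ToyDim({self.description!r})"


batch_dim = ToyDim("B", is_batch=True)
packed_batch_dim = ToyDim("B&Packed{'input_masked_frames:masked:time'}", is_batch=True)
samples_dim = ToyDim("samples", size=10)  # K = 10 candidates per frame

shape = [packed_batch_dim, samples_dim]
print(packed_batch_dim == batch_dim)  # True: the packing info is ignored
print(shape.index(batch_dim))  # 0: plain B "matches" the packed dim in the shape
```

In this model nothing can tell the packed batch dim apart from B, which is exactly why a shape like [packed_batch_dim, samples_dim] cannot be handled correctly.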

The workaround in #918 is to first change the packed batch dim into a spatial dim via ReinterpretDataLayer, so that RandIntLayer works, and afterwards convert it back into a batch dim via ReinterpretDataLayer. For that, ReinterpretDataLayer got a new option batch_base, which specifies where the batch dim is taken from. This is obviously ugly.

So I question whether our equality exception for the batch dim (#634) makes sense, or whether it should be changed so that the batch dim behaves just like any other dim tag.

albertz, Jan 29 '22 22:01

@Zettelkasten any thoughts on this?

albertz, Jan 29 '22 22:01

Seems reasonable to me to treat different batch dims (e.g. with different packing formats) differently. We need to be a bit careful that each Data instance has at most one batch dim. As a consequence, get_common_data of two inputs with different batch dims would simply fail (currently, I think it always works). Well, maybe we could actually make it always work, by letting copy_compatible_to also support a "non-packed to packed" conversion, plus the many other cases that would be needed. That sounds messy and complicates things, though, so I would rather leave it out and make get_common_data throw an error when inputs have different batch dims.
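A minimal sketch of the stricter behavior suggested here, where combining inputs fails loudly on mismatching batch dims instead of attempting a non-packed-to-packed conversion. ToyBatchDim and common_batch_dim are hypothetical helpers for illustration; RETURNN's actual get_common_data works on full Data templates and does considerably more than this.

```python
class ToyBatchDim:
    """Hypothetical batch dim tag that compares by its full description,
    so different packings/merges give different batch dims."""

    def __init__(self, description):
        self.description = description

    def __eq__(self, other):
        return isinstance(other, ToyBatchDim) and self.description == other.description

    def __hash__(self):
        return hash(self.description)

    def __repr__(self):
        return f"ToyBatchDim({self.description!r})"


def common_batch_dim(inputs):
    """Hypothetical helper: return the batch dim shared by all inputs (or None).

    Each input is a list of dims; non-batch dims are plain strings here.
    Raises ValueError instead of attempting any packed/non-packed conversion.
    """
    common = None
    for dims in inputs:
        batch_dims = [d for d in dims if isinstance(d, ToyBatchDim)]
        if len(batch_dims) > 1:
            # Each input may carry at most one batch dim.
            raise ValueError(f"at most one batch dim per input, got {batch_dims}")
        if not batch_dims:
            continue  # in this sketch, inputs without a batch dim always fit
        if common is None:
            common = batch_dims[0]
        elif batch_dims[0] != common:
            raise ValueError(f"different batch dims: {common} vs {batch_dims[0]}")
    return common


b = ToyBatchDim("B")
packed = ToyBatchDim("B&Packed{'input_masked_frames:masked:time'}")
print(common_batch_dim([[b, "T", "F"], [b, "F"], ["F"]]))
try:
    common_batch_dim([[b, "T"], [packed, "K"]])
except ValueError as exc:
    print("error:", exc)
```

The design choice is to keep the check trivial: the only supported case is "all batch dims are the same", and everything else is an error the user has to resolve explicitly.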

Zettelkasten, Jan 29 '22 23:01

One reason the batch dim was always treated as equal is that, when it also included the beam, the logic for resolving the beam and making sure all inputs end up in the same beam was handled separately, so this was a valid assumption.

However, when it contains other things (e.g. it is merged with some other dim) or is packed (flattened batch), this is not the case, and we probably also do not want them to be made equal automatically.

albertz, Jan 30 '22 11:01

Maybe the equality should cover exactly that: ignore any contained beam, but otherwise check BatchInfo equality.
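A rough sketch of this proposed rule, assuming for illustration that a batch is described by a sequence of components such as the global batch, a beam, or a packing. ToyBatchInfo and batch_dims_equal are hypothetical names, and the (kind, name) encoding is an assumption, not RETURNN's actual BatchInfo structure.

```python
from typing import NamedTuple, Tuple

# Hypothetical encoding: a batch is a sequence of (kind, name) components,
# e.g. ("global", "B"), ("beam", "dec"), ("packed", "input_masked_frames:masked:time").
Component = Tuple[str, str]


class ToyBatchInfo(NamedTuple):
    """Simplified stand-in for BatchInfo (not RETURNN's actual class)."""
    components: Tuple[Component, ...]

    def without_beam(self) -> "ToyBatchInfo":
        return ToyBatchInfo(tuple(c for c in self.components if c[0] != "beam"))


def batch_dims_equal(a: ToyBatchInfo, b: ToyBatchInfo) -> bool:
    # Beams are resolved by separate logic, so they are ignored here.
    # Any other difference (packing, merged dims) keeps the batch dims distinct.
    return a.without_beam() == b.without_beam()


plain = ToyBatchInfo((("global", "B"),))
with_beam = ToyBatchInfo((("global", "B"), ("beam", "dec")))
packed = ToyBatchInfo((("global", "B"), ("packed", "input_masked_frames:masked:time")))

print(batch_dims_equal(plain, with_beam))  # True
print(batch_dims_equal(plain, packed))  # False
```

With this rule the old assumption still holds for beams, while a packed batch dim, as in the contrastive loss case, is no longer silently identified with the plain batch dim.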

albertz, Jan 31 '22 11:01