`Dim` internals and API should be refactored
The logic involving get_for_batch_ctx and get_same_base is ugly, unintuitive, and a potential source of problems.
One big problem is that we never really defined the logic around get_for_batch_ctx well, and it was added only at a later point. And similarly for some of the other concepts. E.g. the equality of dim tags (#634) is also not well defined.
CumConcatLayer (#589) and generalized self-attention (#391) were among the main reasons for introducing this, along with the beam search logic.
CumConcatLayer also introduced the concept of implicit dims.
Defining a dim tag (sequence lengths) inside a loop, and then having it accumulated outside is still somewhat straightforward and non-ambiguous, so this is not really a problem.
Maybe it was a problem to treat them as the same, though? But I think this is important so that the rec loop optimization (move out of loop) works correctly.
Note also that the whole logic around get_for_batch_ctx is basically just for having different dyn_size_ext (dyn_size) for the same dim tag, under different conditions, such as inside a loop or outside, and with beam or without.
Direct assignments to or reads from dyn_size or dyn_size_ext, but also all the other flags, even description, depend on get_for_batch_ctx or get_same_base.
Some older code ignores get_for_batch_ctx or get_same_base and directly accesses (reads or writes) attributes like dyn_size. This works when it happens to be the correct instance, but otherwise it can lead to unexpected behavior.
Related is also declare_same_as, although this might not be much of a problem, even less after such a refactoring. However, its logic is currently quite complicated, and should be simplified.
I'm also not sure about a good way to solve this. Maybe dyn_size and dyn_size_ext should be hidden away, and only be accessible through functions get_... and set_..., which would also require the batch and ctx.
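A minimal sketch of what such an accessor-based API could look like; all names and the internal structure here are hypothetical, not the current RETURNN API:

```python
from typing import Any, Dict, Optional, Tuple


class Dim:
    """Hypothetical sketch: dyn_size_ext is hidden behind explicit accessors
    keyed by (batch, ctx), instead of living on whichever Dim instance
    get_for_batch_ctx happened to return."""

    def __init__(self, name: str, dimension: Optional[int] = None):
        self.name = name
        self.dimension = dimension  # static size, or None if dynamic
        # (batch, ctx) -> dynamic sizes (dyn_size_ext) for that variant
        self._dyn_size_ext: Dict[Tuple[Any, Any], Any] = {}

    def get_dyn_size_ext(self, *, batch, ctx):
        """Explicit read; the caller must say which batch/ctx variant it wants."""
        return self._dyn_size_ext.get((batch, ctx))

    def set_dyn_size_ext(self, value, *, batch, ctx):
        """Explicit write; no direct assignment to dyn_size / dyn_size_ext."""
        self._dyn_size_ext[(batch, ctx)] = value
```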
Another big problem is the special role of the batch dim (#920), and all the extra logic around BatchInfo, and when Data.batch or Dim.batch should be set, and what it actually should be set to. In principle, we should not need any special logic for the batch dim, and it should be treated just like other dims.
Features which are available for the batch dim (BatchInfo) but not for other dims are the flattening logic and the special beam search logic.
I think the flattening should be done in a way that you could combine multiple dims (any dims, not just the batch dim) and you would get a new flattened dim. Nothing would be specific about the batch dim. There would be meta info attached to the combined dim to be able to recover the individual dims.
The beam should just be another separate dim, not merged into the batch dim. And then there could also be meta information attached to it, basically what we have in SearchBeam now.
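A rough sketch of what generic flattening, and the beam as a plain dim with attached meta info, could look like; all names here (FlattenInfo, flatten_dims, beam_info) are made up for illustration:

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class Dim:
    """Minimal stand-in for RETURNN's Dim, just enough for this sketch."""
    name: str
    dimension: Optional[int] = None  # static size, None if dynamic
    flatten_info: Optional["FlattenInfo"] = None  # set if this is a flattened dim
    beam_info: Optional[dict] = None  # roughly what SearchBeam holds (e.g. beam size)


@dataclass
class FlattenInfo:
    """Meta info on a flattened dim, enough to recover the individual dims."""
    parts: Tuple[Dim, ...]


def flatten_dims(*dims: Dim) -> Dim:
    """Combine any dims into a new flattened dim; nothing is batch-specific."""
    flat = Dim(name="*".join(d.name for d in dims))
    flat.flatten_info = FlattenInfo(parts=dims)
    return flat


# The beam would just be another dim with meta info, not merged into the batch dim:
beam_dim = Dim(name="beam", dimension=12, beam_info={"beam_size": 12})
```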
Related is also the definition of dim tag equality (#634). This is still not well defined in all cases.
A bit related is also dim tag math, and especially questions of equality in those cases. However, I think this was not too much of a problem so far, except that the equality logic was also broken in its own ways in those cases.
Further, there is also the derived_from_tag and derived_from_op logic, which is yet another heuristic for certain equality matching. Maybe this is not needed when dim tags are used everywhere consistently.
And there are also is_dim_known and undefined, which are also not well defined.
Such changes might break some older code. But on RETURNN side, this can all be fixed. And there should not be much (or any) external code yet using this.
Some examples of issues and resulting PRs caused by the API of Dim, where things were not really well defined:
- #666
- #865
- #1046
- #1054 and #1055
- #1057 and #1058
- #1069 and #1068
- #1102 and #1104
- #1107
- #1112
- #1114
- #1151
- #1152
- #1167 and #1168
- #1246
Related issues:
- #1153
- #920
- #634
- #589
- #391
Note, in case we try to address this, i.e. clean up or refactor some of this: We should also check and run the test cases of returnn-common, as there are some problems which might only occur via those tests.
@Zettelkasten any thoughts or feedback on this?
I wonder whether we should maybe avoid the whole get_for_batch_ctx logic completely. To reiterate why we have it:
- Automatically carry around a version with beam (tiled) of the `dyn_size_ext`.
- Different variant of `dyn_size_ext` within a loop and without. This was introduced for `CumConcatLayer` (https://github.com/rwth-i6/returnn/pull/589) and generalized self-attention (https://github.com/rwth-i6/returnn/issues/391). The important aspect is that this should work seamlessly with recurrent automatic optimization.
When we have the beam dim always separate and explicit, the first point would not be needed. At least with returnn-common, we probably want to have this anyway.
For the second point, maybe we really want to treat them as separate dim tags: one variant inside the loop, and really a separate dim tag outside the loop. But this will need some logic to get from one variant of the dim tag to the other. So it again gets us to something similar to get_for_batch_ctx.
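As a hypothetical sketch of that direction: keep two distinct dim tags, one per-step inside the loop and one accumulated outside, and link them explicitly rather than returning variants of one tag from get_for_batch_ctx. The class names here are made up.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Dim:
    """Minimal stand-in for RETURNN's Dim for this sketch."""
    name: str
    dimension: Optional[int] = None


@dataclass
class LoopDimPair:
    """Hypothetical explicit link between the dim defined per loop iteration
    and the accumulated dim outside the loop."""
    inside: Dim   # the dim as seen within one loop iteration
    outside: Dim  # the accumulated dim, as seen outside the loop


# Example: defined per step inside the loop, accumulated counterpart outside.
pair = LoopDimPair(inside=Dim("chunk"), outside=Dim("chunk-accum"))
```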
Maybe the problem is also more that we treat all those variants returned by get_for_batch_ctx as equal (see definition of equality #634)?
Or maybe the problem is that Data will automatically adapt the dim tags in Data._adapt_batch_consistent_dim_tags? This is because we treat them all as equal.
Maybe treating them as equal is also not so much the problem. What exactly is actually the problem?
Sometimes we don't really know which is the right variant to pick in _adapt_batch_consistent_dim_tags, e.g. when there is no batch dim in the tensor. "Batch consistent" does not really make sense then. It should maybe pick some reasonable variant then. Is this always well defined? What matters more is the context (e.g. within a loop or not), which should always be available as information. And if the tensor does not have a batch dim, it should pick whatever dim tag variant is available with the simplest batch.
I think we should go through all the attribs and see which of those are really needed.
One specific function I wonder about is `declare_same_as`, whether we really need it. The user can replace some existing dim tag by another dim tag via `ReinterpretDataLayer`.
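For example, a rough config sketch of that replacement; it assumes ReinterpretDataLayer's `set_dim_tags` option (mapping an axis to a dim tag), with the exact option name quoted from memory:

```python
# Hypothetical sketch: instead of new_time_dim.declare_same_as(old_time_dim),
# reinterpret the layer output so its time axis uses the already existing dim tag.
from returnn.tf.util.data import SpatialDim

old_time_dim = SpatialDim("old-time")

network = {
    "replaced": {
        "class": "reinterpret_data",
        "from": "some_layer",
        # assign the existing dim tag to the time axis (option name from memory)
        "set_dim_tags": {"T": old_time_dim},
    },
}
```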
When making Data/Dim framework independent, and especially usable for eager mode (#1165), it becomes clear that we need to simplify it and speed it up a lot. So solving #1165 with this requirement would probably also solve this issue here.
> One specific function I wonder about is `declare_same_as`, whether we really need it.
So, when looking for current use cases in RETURNN, what we have is:
- In many layers, it will do sth like `out_spatial_dim_.declare_same_as(out_spatial_dim)`, where `out_spatial_dim_` is the automatically calculated dimension based on the input, and `out_spatial_dim` is passed by the user. The common cases are that `out_spatial_dim` is actually yet undefined, and this is the implicit way to define it. Or, it can also be that it was already defined before, but then we expect it to be the same. `out_spatial_dim_` is usually via dim tag math, but it might also be more dynamic and then only later defined in the layer `__init__`. How to replace this? Maybe some code like this:

  ```python
  if not out_spatial_dim:
      out_spatial_dim = Dim(<undefined>)
  if out_spatial_dim.is_not_defined:
      out_spatial_dim.set(...)  # maybe just dyn_size_ext template?
  ```

  That way, we don't need `same_as`, as there are not multiple objects referring to the same dim. I'm not sure if we really need the dim tag math or dim derived-from logic here. In `__init__`, it can calculate and set the actual `dyn_size_ext` values, if needed. I'm not sure if, how and where there should be any error checking later whether the dim is actually correct, if it is already defined.
- In the returnn_common API, it is unnatural to pass an existing `out_spatial_dim` to some function like `pool1d`, and actually not recommended, and instead, you get a new out spatial dim as return value. Currently, if that is expected to be the same as some other existing dim, it uses `declare_same_as`. Not just for implicit validation, but also to mark it the same, if the dim tags are not equal otherwise. We can avoid this use of `declare_same_as` by replacing the new dim by the old existing dim tag via `ReinterpretDataLayer`.
- In some cases, it is an implicit check that two dims are actually the same. It's kind of an assert. E.g. in `RecLayer` or `RecUnstackLayer` for the time dim. And if one is undefined, it would define it this way.
- There is one use case in `CondLayer`, to merge some dynamic dim inside the condition (although this is anyway incomplete). See https://github.com/rwth-i6/returnn_common/commit/ffc5083d04f2cc0f159f73c52cf3421814cd7e80.
- `Data` `same_dim_tags_as`
Related, the get_for_batch_ctx logic, which also uses the same_as logic:
`CumConcatLayer`, and in general when some dim is defined inside a loop, per iteration, and then accessed outside, it gets a new dimension in `dyn_size_ext`. So, do we really need or want the dim tag outside the loop to be the same as inside the loop? Or does it need a reference to it? Note that it was intended to work both when optimized out of the loop or being inside the loop. And the `out_spatial_dim` is a necessary arg. And it will be the same arg when optimized out (unless we somehow transform it then). The optimized-out case matters when the layer is accessed outside, or also when other follow-up layers access it. They would refer to the same dim then. Those follow-up layers might also end up being inside the loop. So, does this imply that we must have equality, and must have `same_as`, and the whole `get_for_batch_ctx` logic? Note that the `batch` aspect here is not needed when we don't merge a beam into it or do other things. Note that the `ctx` aspect is not needed for eager mode, including the PyTorch backend.
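To make the situation concrete, here is a rough config sketch along the lines of the `CumConcatLayer` / generalized-self-attention setup from #589 (layer options from memory, simplified and illustrative): the same `out_spatial_dim` object is passed inside the loop and referenced by a follow-up layer, whether or not the layers get optimized out of the loop.

```python
from returnn.tf.util.data import SpatialDim

new_dim = SpatialDim("cum-concat-time")

network = {
    "loop": {
        "class": "rec", "from": "data",
        "unit": {
            # Accumulates the per-step input over the loop iterations into new_dim.
            "accum": {"class": "cum_concat", "from": "data:source",
                      "out_spatial_dim": new_dim},
            # Follow-up layer referring to the same dim tag; it may run inside
            # the loop or be moved out by the automatic recurrent optimization.
            "red": {"class": "reduce", "mode": "max", "from": "accum",
                    "axis": new_dim},
            "output": {"class": "copy", "from": "red"},
        },
    },
    "output": {"class": "copy", "from": "loop"},
}
```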