`rf.RelPosCausalSelfAttention` fails with `single_step_dim`
Hi,
I'm having a problem with rf.RelPosCausalSelfAttention when using it in a transformer decoder. It fails because it tries to remove single_step_dim from a tensor that does not have it, in the function _rel_pos_enc_shift here: https://github.com/rwth-i6/returnn/blob/23d666ccf3ac9e748fce4e0d27afe353133eca48/returnn/frontend/attention.py#L412
The input matrix_bd looks like this: Tensor{'dot', ['initial-beam'(1),B?,'num_heads'(8),'self_att_expand_dim_init+1'(1)]}
The error I get looks like this:
line: matrix_bd = _rel_pos_enc_shift(matrix_bd, axis, pos_emb_spatial_dim, hist_dim)
locals:
matrix_bd = <local> Tensor{'dot', ['initial-beam'(1),B?,'num_heads'(8),'self_att_expand_dim_init+1'(1)]}
_rel_pos_enc_shift = <global> <function _rel_pos_enc_shift at 0x7f78c8937ac0>
axis = <local> Dim{'single-step'!}
pos_emb_spatial_dim = <local> Dim{'self_att_expand_dim_init+1'(1)}
hist_dim = <local> Dim{'self_att_expand_dim_init+1'(1)}
File "returnn/returnn/frontend/attention.py", line 412, in _rel_pos_enc_shift
line: batch_dims = x.remaining_dims((axis, pos_emb_spatial_dim))
locals:
batch_dims = <not found>
x = <local> Tensor{'dot', ['initial-beam'(1),B?,'num_heads'(8),'self_att_expand_dim_init+1'(1)]}
x.remaining_dims = <local> <bound method _TensorMixin.remaining_dims of Tensor{'dot', ['initial-beam'(1),B?,'num_heads'(8),'self_att_expand_dim_init+1'(1)]}>
axis = <local> Dim{'single-step'!}
pos_emb_spatial_dim = <local> Dim{'self_att_expand_dim_init+1'(1)}
File "returnn/returnn/tensor/_tensor_extra.py", line 1849, in _TensorMixin.remaining_dims
line: batch_dims.remove(remove_)
locals:
batch_dims = <local> [Dim{'initial-beam'(1)}, Dim{B}, Dim{'num_heads'(8)}, Dim{'self_att_expand_dim_init+1'(1)}]
batch_dims.remove = <local> <built-in method remove of list object at 0x7f7811ab6900>
remove_ = <local> Dim{'single-step'!}
ValueError: list.remove(x): x not in list
I don't have an easy setup for you to reproduce this yet. However, I think it should be easily reproducible by using rf.RelPosCausalSelfAttention with single_step_dim.
I also need to look deeper into the functionality behind this to understand what the correct behaviour would be.
If I have any new information on this, I will post it here.
I think it's just not implemented yet.
What's the state here? @LucaG1, do you have a fix for this? I thought you were already using this?
Right, sorry, I forgot to post it here. For me, the following fix is currently working, but I am still not sure whether this is the correct way to do it:
if axis == single_step_dim:
    matrix_bd = rf.expand_dim(matrix_bd, axis)
matrix_bd = _rel_pos_enc_shift(matrix_bd, axis, pos_emb_spatial_dim, hist_dim)
if axis == single_step_dim:
    matrix_bd = rf.squeeze(matrix_bd, axis)
So I just add and remove single_step_dim around the call to _rel_pos_enc_shift and hope that it does the right thing for that case as well.
relative_positional_encoding needs a proper query_offset in the single-step case, or not?
Also, rf.expand_dim(matrix_bd, single_step_dim) does not make sense. I'm surprised that even works; it should throw an exception. single_step_dim is not allowed to be part of the shape of an actual tensor.
I checked and I think it does not need any query offset. In my case, the _rel_pos_enc_shift function does not affect matrix_bd anymore.
I guess the correct thing to do would then be:
if axis != single_step_dim:
    matrix_bd = _rel_pos_enc_shift(matrix_bd, axis, pos_emb_spatial_dim, hist_dim)
If you want, I can push this fix.
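For intuition, here is a small standalone NumPy sketch (illustration only, not the RETURNN implementation) of what the shift has to achieve on a full sequence, and why there is nothing left to do for a single-step query whose positional embeddings already cover exactly the history (pos_emb_spatial_dim == hist_dim, as in the trace above):

import numpy as np

q_len, k_len = 4, 4
# Relative positions covered by the positional embeddings: -(k_len-1) .. q_len-1.
rel_pos = np.arange(-(k_len - 1), q_len)  # length q_len + k_len - 1
# Fake scores: scores_rel[i, p] belongs to query i and relative position rel_pos[p].
scores_rel = np.arange(q_len * len(rel_pos), dtype=float).reshape(q_len, len(rel_pos))

# The shift has to produce scores_abs[i, j] for key index j, i.e. relative position j - i.
gather_idx = (np.arange(k_len)[None, :] - np.arange(q_len)[:, None]) + (k_len - 1)
assert (rel_pos[gather_idx] == np.arange(k_len)[None, :] - np.arange(q_len)[:, None]).all()
scores_abs = np.take_along_axis(scores_rel, gather_idx, axis=1)  # shape [q_len, k_len]
print(scores_abs.shape)  # (4, 4)

# Single-step case: there is only one query, and the positional embeddings are
# computed for exactly the k_len history positions relative to that query, so
# the scores already have shape [1, k_len] and are already aligned; nothing to shift.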
I checked and I think it does not need any query offset.
Why? That sounds incorrect. Surely a (rel or abs) positional encoding must somehow depend on the position?
It would be good if we also had a test case where we operate on the whole sequence in one case, then operate step-by-step, and check that we get exactly the same output.
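A rough sketch of the structure of such a test, using torch.nn.MultiheadAttention as a stand-in for the module (this is not the rf API; in the real test the step-by-step branch would call rf.RelPosCausalSelfAttention with axis=single_step_dim and carry its state from step to step):

import torch

torch.manual_seed(0)
seq_len, dim, num_heads = 6, 16, 4
mha = torch.nn.MultiheadAttention(dim, num_heads, batch_first=True)
mha.eval()

x = torch.randn(1, seq_len, dim)
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

with torch.no_grad():
    # Whole sequence at once, with a causal mask.
    full, _ = mha(x, x, x, attn_mask=causal_mask, need_weights=False)
    # Step by step: at step t, the query is frame t, keys/values are frames 0..t.
    steps = []
    for t in range(seq_len):
        out_t, _ = mha(x[:, t : t + 1], x[:, : t + 1], x[:, : t + 1], need_weights=False)
        steps.append(out_t)
    stepwise = torch.cat(steps, dim=1)

torch.testing.assert_close(full, stepwise)  # must match up to float tolerance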
Why? That sounds incorrect. Surely a (rel or abs) positional encoding must somehow depend on the position?
Ah sorry, my bad, I was thinking of something else. It seems to me that query_offset is computed automatically here:
https://github.com/rwth-i6/returnn/blob/61ad52a72916d5834a211ea11a8536388a0d7864/returnn/frontend/attention.py#L762
for the case of single_step_dim.
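As a tiny illustration of what that offset means (plain Python, not the RETURNN code): in single-step decoding the query is the newest frame of the history, so the effective query offset is the current history length minus one, and all relative key positions are non-positive:

hist_len = 4
query_pos = hist_len - 1  # offset of the single query frame within the history
rel_positions = [key_pos - query_pos for key_pos in range(hist_len)]
print(rel_positions)  # [-3, -2, -1, 0]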
This should be fixed now.