returnn
RF scaled_dot_product_attention
Add scaled_dot_product_attention as a function to RF, and use it in our attention code. (Does this also work with RelPosSelfAttention?)
In the case of PyTorch, wrap torch.nn.functional.scaled_dot_product_attention. That should be much more compute- and memory-efficient than the direct implementation. It uses FlashAttention or potentially one of a number of other efficient kernels. (Although on older GPUs, probably not. See my question on the PyTorch discussion forum.)
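To make the comparison concrete, here is a minimal sketch in plain PyTorch (not RF), assuming PyTorch 2.0 or newer and tensors laid out as [batch, heads, time, feature]. `direct_dot_attention` is a hypothetical helper standing in for the direct implementation. On CPU both paths should agree up to floating-point tolerance. On a suitable GPU, the fused call may dispatch to FlashAttention or another efficient kernel.

```python
import math

import torch
import torch.nn.functional as F


def direct_dot_attention(q, k, v, scale=None):
    """Direct attention: softmax(q @ k^T * scale) @ v.

    q: [batch, heads, T_q, E], k: [batch, heads, T_k, E], v: [batch, heads, T_k, E_v]
    """
    if scale is None:
        scale = 1.0 / math.sqrt(q.shape[-1])  # same default as F.scaled_dot_product_attention
    energy = torch.matmul(q, k.transpose(-2, -1)) * scale  # [batch, heads, T_q, T_k]
    att_weights = torch.softmax(energy, dim=-1)
    return torch.matmul(att_weights, v)  # [batch, heads, T_q, E_v]


torch.manual_seed(0)
q = torch.randn(2, 4, 5, 8)
k = torch.randn(2, 4, 7, 8)
v = torch.randn(2, 4, 7, 8)

out_direct = direct_dot_attention(q, k, v)
# Fused path: PyTorch selects a backend (FlashAttention, memory-efficient, or the math fallback).
out_fused = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(out_direct, out_fused, atol=1e-5))
```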
(cc @NeoLegends @dorian-K) (If anyone plans to work on this, please say so.)
I'm thinking about the implications of this right now. Do we deal with ONNX export of RF models at all? How would using F.scaled_dot_product_attention affect this? Does it need an equivalent replacement path using "traditional ops" during export? In i6_models code the MHSA implementation w/o positional encoding uses pytorch's optimized implementation, but I still have to check if we do any model surgery (and replace it w/ a traditional impl) before export.
We can simply use a fallback implementation when exporting to ONNX. I think we already do that in some other places, e.g. in TorchBackend.full:
```python
if torch.onnx.is_in_onnx_export():
    # onnx::ConstantOfShape (via torch.full) must get shape as int64.
    # https://github.com/rwth-i6/returnn/issues/1333#issuecomment-1607236783
    shape = [dim.long() if isinstance(dim, torch.Tensor) else dim for dim in shape]
```
Also in TorchBackend.conv.
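Following the same pattern as TorchBackend.full, a fused call could be guarded with `torch.onnx.is_in_onnx_export()`. The sketch below assumes plain PyTorch tensors and a boolean mask where True means "may attend". `sdpa_with_onnx_fallback` is a hypothetical name, not an existing RETURNN function. Whether the fallback is needed at all depends on the PyTorch version and ONNX opset in use, which this sketch does not check.

```python
import math

import torch
import torch.nn.functional as F


def sdpa_with_onnx_fallback(q, k, v, attn_mask=None):
    """Hypothetical helper: fused attention normally, plain ops during ONNX export.

    attn_mask: optional bool tensor, True = may attend (same convention as F.scaled_dot_product_attention).
    """
    if torch.onnx.is_in_onnx_export():
        # Fallback with "traditional" ops (matmul, masked_fill, softmax) for the exporter.
        energy = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
        if attn_mask is not None:
            energy = energy.masked_fill(~attn_mask, float("-inf"))
        return torch.matmul(torch.softmax(energy, dim=-1), v)
    # Normal execution: use the fused implementation.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
```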
I'm working on this now.
The way causal self attention is currently implemented is somewhat problematic, as the spatial dimension for the key and value matrices depends on the spatial dimension of the query matrix (see https://github.com/rwth-i6/returnn/blob/9678032af5228d7e837371cb641439f8932ec6e7/returnn/frontend/attention.py#L249). I.e., when the key and value matrices are passed to the dot attention, a dimension of those matrices is not well-defined, and it only becomes valid once multiplied with the query matrix.
Of course we can change the implementation (and we have to, to utilize the optimized causal attention kernels in pytorch), but there are usages of rf.dot_attention in i6_experiments that depend on this behaviour. In these cases we still need to fall back to the old implementation, which we need to keep anyway to support non-pytorch backends.
The pytorch implementation also does not support broadcast dropout. So maybe we just raise an error when it is set in this case? Or we remove it entirely?
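For reference, "broadcast" dropout means drawing the dropout mask with size 1 along some dims, so the same pattern is reused across them. F.scaled_dot_product_attention only takes a scalar `dropout_p` and samples an independent mask per element. A minimal sketch of the manual path, assuming plain PyTorch tensors. Sharing the mask across heads and query positions is an illustrative choice here, not necessarily the set of dims that RF's att_dropout_broadcast shares.

```python
import torch


def attention_with_broadcast_dropout(q, k, v, dropout_p, training=True):
    """Manual attention with one dropout mask shared (broadcast) over heads and query positions.

    q: [B, H, T_q, E], k/v: [B, H, T_k, E]. Returns (output, att_weights after dropout).
    """
    energy = torch.matmul(q, k.transpose(-2, -1)) * q.shape[-1] ** -0.5
    att_weights = torch.softmax(energy, dim=-1)  # [B, H, T_q, T_k]
    if training and dropout_p > 0.0:
        batch, _, _, t_k = att_weights.shape
        # Mask of shape [B, 1, 1, T_k]: the same pattern for every head and query position.
        keep = torch.rand(batch, 1, 1, t_k, device=att_weights.device) >= dropout_p
        # Inverted dropout: rescale kept weights so the expectation is unchanged.
        att_weights = att_weights * keep / (1.0 - dropout_p)
    return torch.matmul(att_weights, v), att_weights
```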
torch.nn.functional.scaled_dot_product_attention is quite a bit more general than our current dot_attention. For example, it allows a different number of query heads than key/value heads (and, similarly, a different embed dimension for query/key than for value). For now I will only add the scale and attn_mask parameters.
Another interesting pytorch feature is FlexAttention, which allows much greater customization of the attention mechanism. It can be used to implement e.g. sliding window attention or packed/jagged tensors. This is very useful for research, but I don't think it fits well into returnn atm, because pytorch is the only framework that implements this API.
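As an illustration of sliding window attention, here is a sketch that expresses it as a dense boolean mask for F.scaled_dot_product_attention. This is a stand-in, not the FlexAttention API (torch.nn.attention.flex_attention). FlexAttention would express the same rule as a mask function and can exploit block sparsity, which a dense mask does not. `sliding_window_mask` is a hypothetical helper, and the window here is causal: the current frame plus the previous `window - 1` frames.

```python
import torch
import torch.nn.functional as F


def sliding_window_mask(seq_len: int, window: int) -> torch.Tensor:
    """Bool mask [seq_len, seq_len]: query i may attend to keys j with i - window < j <= i."""
    pos = torch.arange(seq_len)
    diff = pos[:, None] - pos[None, :]  # diff[i, j] = i - j
    return (diff >= 0) & (diff < window)


mask = sliding_window_mask(5, window=2)
print(mask.int().tolist())
# [[1, 0, 0, 0, 0], [1, 1, 0, 0, 0], [0, 1, 1, 0, 0], [0, 0, 1, 1, 0], [0, 0, 0, 1, 1]]

x = torch.randn(1, 2, 5, 8)  # [batch, heads, time, feature]
out = F.scaled_dot_product_attention(x, x, x, attn_mask=mask)
```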
I'm currently limited in time and not following the full argumentation. But if the existing API of rf.dot_attention does not really fit scaled_dot_product_attention, I don't think this is a problem: Let's just make a separate RF function, specifically for scaled_dot_product_attention. (But let me later understand why it does not fit.)
For FlexAttention, it's probably again similar, and we might need yet another separate RF function.
> The way causal self attention is currently implemented is somewhat problematic, as the spatial dimension for the key and value matrices depends on the spatial dimension of the query matrix
Are you talking about the case axis == single_step_dim, or about execution on the full sequence? I assume the latter (axis != single_step_dim).
I don't understand why this is problematic.
> I.e., when the key and value matrices are passed to the dot attention, a dimension of those matrices is not well-defined, and it only becomes valid once multiplied with the query matrix.
Which dimension is not well-defined? Why?
You are specifically speaking about rf.CausalSelfAttention? Or rf.dot_attention in general?
> Of course we can change the implementation (and we have to, to utilize the optimized causal attention kernels in pytorch), but there are usages of rf.dot_attention in i6_experiments that depend on this behaviour.
On what behavior exactly?
I don't really understand the problem yet.
Look at this code snippet: https://github.com/rwth-i6/returnn/blob/9678032af5228d7e837371cb641439f8932ec6e7/returnn/frontend/attention.py#L249-L251
Before this is executed, axis is the (shared) spatial dimension of both query and key/value.
But for the key and value matrices, this spatial dimension is then replaced with hist_dim, and its size tensor depends on the dimension axis, which is no longer present because it has been replaced by hist_dim in the key and value matrices. So those matrices are not well-defined without the axis dimension which is only present in the query matrix. So this is a problem in rf.CausalSelfAttention.
This is not a problem with the current rf.dot_attention, because the masked dimension doesn't matter in the query @ key matmul; after that, the result tensor has both the hist_dim and axis dimensions, and everything is well-defined again. But if we want to use pytorch for the whole attention operation, we would need to convert these dimensions to a mask matrix before the matmul happens. And in theory, both the query and key spatial dimensions could have masking, which makes the logic confusing.
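To illustrate how masks on the key side can be combined, here is a sketch in plain PyTorch. It assumes right-padded sequences with a per-sequence length tensor and boolean masks where True means "may attend". A causal [T_q, T_k] mask is combined with a per-sequence key padding mask by broadcasting. In this sketch the query side is not masked: outputs at padded query frames are simply ignored afterwards. With a causal mask, every query row keeps at least key 0 visible, so no row is fully masked.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, time, feat = 2, 2, 4, 8
x = torch.randn(batch, heads, time, feat)
seq_lens = torch.tensor([4, 2])  # per-sequence lengths; frames beyond are padding

pos = torch.arange(time)
causal = pos[None, :] <= pos[:, None]  # [T_q, T_k]: key j visible to query i iff j <= i
key_valid = pos[None, :] < seq_lens[:, None]  # [B, T_k]: False on padded key frames
# Combined mask [B, 1, T_q, T_k], broadcast over heads.
attn_mask = causal[None, None, :, :] & key_valid[:, None, None, :]

out = F.scaled_dot_product_attention(x, x, x, attn_mask=attn_mask)
```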
Honestly I'm not sure if I am understanding this all correctly myself.
But what we actually want to do is use the is_causal parameter of the torch scaled_dot_product_attention, which roughly halves the computation time, because all those unnecessary computations for positions that are masked out anyway no longer happen. So we need to change the CausalSelfAttention code anyway.
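A quick check in plain PyTorch that `is_causal=True` gives the same result as an explicit lower-triangular mask. The numbers agree up to floating-point tolerance. The difference is that with the flag, the kernel knows the structure and (depending on the backend) can skip the masked-out upper triangle, whereas an explicit `attn_mask` is just a dense tensor to it. As far as I understand, that is where the speedup would come from.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(2, 4, 6, 8)  # [batch, heads, time, feature]
k = torch.randn(2, 4, 6, 8)
v = torch.randn(2, 4, 6, 8)

# Causal attention via the flag.
out_flag = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# The same via an explicit lower-triangular boolean mask (True = may attend).
tril = torch.ones(6, 6, dtype=torch.bool).tril()
out_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=tril)
print(torch.allclose(out_flag, out_mask, atol=1e-5))
```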
Another issue is that some tests, and some code in i6_experiments for analyzing attention weights, rely on the pre-softmax energies variable being present, as that code uses the Python tracer. This obviously won't be possible anymore once we use the torch implementation. So maybe there should be some parameter that disables the pytorch implementation.
We could also automatically disable the pytorch implementation as soon as we see that a tracer is present (sys.gettrace() is not None), but I'm concerned that this is quite unintuitive behaviour. E.g., in the future there may be a bug which magically disappears once you attach a Python debugger that uses Python tracing.
> Another issue is that some tests, and some code in i6_experiments for analyzing attention weights, rely on the pre-softmax energies variable being present, as that code uses the Python tracer. This obviously won't be possible anymore once we use the torch implementation. So maybe there should be some parameter that disables the pytorch implementation.
That is easy to solve. There can be some env var, and/or flag in the global config, or some other global flag to disable the optimized case. I would not worry about that. When the user forgets about it, there will be an error and then it's clear how to solve it.
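A minimal sketch of such a switch, assuming plain PyTorch tensors. The env var name `RETURNN_DISABLE_FUSED_ATTENTION` and the module-level flag `use_fused_attention` are hypothetical and do not exist in RETURNN. The unfused path keeps `energy` as a local variable, so tracer-based analysis code (via `sys.settrace`) can still read it, as the small usage example at the end shows.

```python
import math
import os
import sys

import torch
import torch.nn.functional as F

# Hypothetical names: neither this env var nor this flag exists in RETURNN.
DISABLE_ENV_VAR = "RETURNN_DISABLE_FUSED_ATTENTION"
use_fused_attention = True  # e.g. would be set from the global config


def fused_attention_enabled() -> bool:
    return use_fused_attention and os.environ.get(DISABLE_ENV_VAR, "0") != "1"


def dot_attention_sketch(q, k, v):
    if fused_attention_enabled():
        return F.scaled_dot_product_attention(q, k, v)
    # Unfused path: the pre-softmax energies exist as the local variable `energy`.
    energy = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
    att_weights = torch.softmax(energy, dim=-1)
    return torch.matmul(att_weights, v)


# Usage: capture `energy` with the Python tracer, as the analysis code does.
captured = {}


def tracer(frame, event, arg):
    # Called with "call" for new frames; returning itself also traces that frame's "return" event.
    if event == "return" and frame.f_code.co_name == "dot_attention_sketch":
        if "energy" in frame.f_locals:
            captured["energy"] = frame.f_locals["energy"]
    return tracer


os.environ[DISABLE_ENV_VAR] = "1"  # force the unfused path
q = torch.randn(1, 1, 3, 4)
sys.settrace(tracer)
try:
    dot_attention_sketch(q, q, q)
finally:
    sys.settrace(None)
print(captured["energy"].shape)  # torch.Size([1, 1, 3, 3])
```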
> Before this is executed, axis is the (shared) spatial dimension of both query and key/value. But for the key and value matrices, this spatial dimension is then replaced with hist_dim, and its size tensor depends on the dimension axis, which is no longer present because it has been replaced by hist_dim in the key and value matrices. So those matrices are not well-defined without the axis dimension which is only present in the query matrix. So this is a problem in rf.CausalSelfAttention.
I still don't understand. What do you mean by "those matrices are not well-defined without the axis dimension"? The dimension (hist_dim) is fully defined, or not? Its size is also already calculated, or not?
Is there some specific exception you get? (Or expect to get?)
> But what we actually want to do is use the is_causal parameter of the torch scaled_dot_product_attention, which roughly halves the computation time, because all those unnecessary computations for positions that are masked out anyway no longer happen. So we need to change the CausalSelfAttention code anyway.
Yes. We can also add an is_causal argument to dot_attention, and maybe also an is_causal class attribute to SelfAttentionBase, and then in SelfAttentionBase.attention, pass is_causal=self.is_causal to dot_attention:
```python
def attention(self, q: Tensor, k: Tensor, v: Tensor, *, kv_axis: Dim) -> Tensor:
    """apply attention"""
    att = dot_attention(
        q,
        k,
        v,
        key_dim=self.key_dim_per_head,
        axis=kv_axis,
        att_dropout=self.att_dropout,
        att_dropout_broadcast=self.att_dropout_broadcast,
        is_causal=self.is_causal,  # proposed new argument, not yet part of dot_attention
    )
    # Merge heads: [..., num_heads, value_dim_per_head] -> [..., value_dim_total]
    output, _ = rf.merge_dims(att, dims=(self.num_heads, self.value_dim_per_head), out_dim=self.value_dim_total)
    if self.proj:
        output = self.proj(output)
    return output
```
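A plain-PyTorch sketch of the class-attribute idea, outside of RF. `SelfAttentionSketch` and `CausalSelfAttentionSketch` are hypothetical stand-ins for SelfAttentionBase and CausalSelfAttention, with `is_causal` as a class attribute forwarded to F.scaled_dot_product_attention. It only covers the full-sequence case, not the single_step_dim case with state that CausalSelfAttention also handles.

```python
import torch
import torch.nn.functional as F
from torch import nn


class SelfAttentionSketch(nn.Module):
    """Hypothetical plain-PyTorch analogue of SelfAttentionBase (full-sequence case only)."""

    is_causal = False  # class attribute, overridden by the causal variant

    def __init__(self, model_dim: int, num_heads: int):
        super().__init__()
        assert model_dim % num_heads == 0
        self.num_heads = num_heads
        self.qkv = nn.Linear(model_dim, 3 * model_dim)
        self.proj = nn.Linear(model_dim, model_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, time, dim = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # Split heads: [B, T, D] -> [B, heads, T, D // heads]
        q, k, v = (t.reshape(batch, time, self.num_heads, -1).transpose(1, 2) for t in (q, k, v))
        # Forward the class attribute to the fused kernel.
        att = F.scaled_dot_product_attention(q, k, v, is_causal=self.is_causal)
        # Merge heads: [B, heads, T, D // heads] -> [B, T, D]
        output = att.transpose(1, 2).reshape(batch, time, dim)
        return self.proj(output)


class CausalSelfAttentionSketch(SelfAttentionSketch):
    """Hypothetical analogue of CausalSelfAttention: only flips the flag."""

    is_causal = True
```

With `is_causal = True`, changing the input at the last frame must leave the outputs of all earlier frames unchanged, which makes for an easy sanity check.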
> I still don't understand. What do you mean by "those matrices are not well-defined without the axis dimension"? The dimension (hist_dim) is fully defined, or not? Its size is also already calculated, or not?
> Is there some specific exception you get? (Or expect to get?)
The key matrix will have dimensions [batch, ..., hist_dim, embed_dim]. The hist_dim dimension has a size tensor with dimensions [axis], but axis is not a dimension of the key matrix. So if you want to know whether a certain element of the key matrix is valid or masked out, this is not possible: hist_dim cannot be interpreted without an index for the axis dimension, and we do not have such an index because the key matrix does not have this dimension.
> The key matrix will have dimensions [batch, ..., hist_dim, embed_dim]. The hist_dim dimension has a size tensor with dimensions [axis]. But axis is not a dimension of the key matrix. So if you want to know if a certain element is part of the key matrix, or masked out, this is not possible because the hist_dim cannot be interpreted without an index for the axis dimension, but we do not have an index for the axis dimension because the key matrix does not have this dimension.
Yes, axis is not a dim of the key matrix. So the masking is not even properly defined. But it becomes defined when you add it. It actually must be added to properly apply the masking - without having it, you anyway cannot apply masking in a meaningful way. But actually you would never really need to apply masking on the key matrix.
You multiply the keys with the queries (code from dot_attention, adapted to your dim names):
```python
energy = rf.matmul(query, keys, reduce=embed_dim)
# energy: [batch, ..., hist_dim, axis]
att_weights = rf.softmax(energy, axis=hist_dim)
```
This is where you get back the axis dim: it comes from the matmul with the queries. And the masking only needs to be applied on the energies (energy here). You never need to apply it to the keys.
I still don't understand how this is now relevant for torch.nn.functional.scaled_dot_product_attention. I guess you want to know how to set attn_mask and/or maybe also is_causal? It should be:
```python
attn_mask = rf.compare_bc(rf.range_over_dim(hist_dim, device=query.device), "<", hist_dim.get_size_tensor(device=query.device))
# attn_mask: [axis, hist_dim]
```
(Note, in dot_attention, the dim is actually called axis, what we call hist_dim here.)
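The same logic in plain PyTorch, as a sketch assuming a single causal sequence without batch padding. `keys` has no query-time axis. The axis only enters through the matmul with the queries, and the mask built from the hist_dim size tensor [axis] is applied to the energies only. The tensor names mirror the dim names above but are otherwise made up.

```python
import torch

torch.manual_seed(0)
batch, time, embed = 2, 4, 8
query = torch.randn(batch, time, embed)  # [batch, axis, embed_dim]
keys = torch.randn(batch, time, embed)  # [batch, hist_dim, embed_dim]: no query-time axis here

# The axis dim comes back only through the matmul with the queries.
energy = torch.einsum("bqe,bke->bqk", query, keys)  # [batch, axis, hist_dim]

# hist_dim size tensor, indexed by axis: query position t sees t + 1 key frames.
hist_sizes = torch.arange(time) + 1  # [axis]
# Counterpart of compare_bc(range_over_dim(hist_dim), "<", hist_dim size tensor).
attn_mask = torch.arange(time)[None, :] < hist_sizes[:, None]  # [axis, hist_dim]

# Masking is applied to the energies only, never to `keys`.
energy = energy.masked_fill(~attn_mask, float("-inf"))
att_weights = torch.softmax(energy, dim=-1)  # normalize over hist_dim
```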
For is_causal, you could write code that checks whether the attention mask is a square matrix and a lower triangular matrix. I'm not sure though that such a check is necessary, when you provide the attn_mask. I think it anyway only computes the attention where necessary based on the mask?
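Such a check could look like the sketch below (plain PyTorch; `is_causal_mask` and `sdpa_maybe_causal` are hypothetical helpers). It passes only one of `attn_mask`/`is_causal`, since the documentation of F.scaled_dot_product_attention states that an error is thrown if both are set. Note that the check is data-dependent: on GPU, `bool(...)` forces a device sync, and it will not trace cleanly for ONNX export, so a structural check on the dims might be preferable in RF.

```python
import torch
import torch.nn.functional as F


def is_causal_mask(mask: torch.Tensor) -> bool:
    """True if the bool mask [..., T_q, T_k] is square and exactly lower triangular (incl. diagonal)."""
    t_q, t_k = mask.shape[-2:]
    if t_q != t_k:
        return False
    tril = torch.ones(t_q, t_k, dtype=torch.bool, device=mask.device).tril()
    return bool((mask == tril).all())


def sdpa_maybe_causal(q, k, v, attn_mask):
    # Pass exactly one of is_causal / attn_mask to F.scaled_dot_product_attention.
    if is_causal_mask(attn_mask):
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
```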