TinyNeuralNetwork
[converter] implement torch's `aten::scaled_dot_product_attention` operator
Is there any chance of implementing torch's aten::scaled_dot_product_attention operator?
According to https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html, it can be implemented as follows:
import math
import torch

# Efficient implementation equivalent to F.scaled_dot_product_attention:
def scaled_dot_product_attention(Q, K, V, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
    L, S = Q.size(-2), K.size(-2)
    scale_factor = 1 / math.sqrt(Q.size(-1)) if scale is None else scale
    attn_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) if is_causal else attn_mask
    if attn_mask is not None and attn_mask.dtype == torch.bool:
        # A boolean mask becomes an additive float mask: -inf where attention is disallowed
        attn_mask = torch.zeros_like(attn_mask, dtype=Q.dtype).masked_fill(~attn_mask, float('-inf'))
    attn_weight = torch.softmax(Q @ K.transpose(-2, -1) * scale_factor + (attn_mask if attn_mask is not None else 0), dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ V
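
As a quick sanity check (my own snippet, not from the docs; it assumes the scaled_dot_product_attention function defined above and disables dropout so the comparison is deterministic):

import torch
import torch.nn.functional as F

# Arbitrary example shapes: batch=2, heads=4, seq_len=8, head_dim=16
Q = torch.randn(2, 4, 8, 16)
K = torch.randn(2, 4, 8, 16)
V = torch.randn(2, 4, 8, 16)

ref = scaled_dot_product_attention(Q, K, V, is_causal=True, dropout_p=0.0)
fused = F.scaled_dot_product_attention(Q, K, V, is_causal=True, dropout_p=0.0)
print(torch.allclose(ref, fused, atol=1e-5))  # expected: True, up to numerical tolerance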
Until it is implemented, I guess the easiest workaround is to replace F.scaled_dot_product_attention with the reference implementation above; see the sketch below.
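
A rough sketch of that workaround (my own; ToyAttention is just a hypothetical stand-in for a model that calls F.scaled_dot_product_attention directly, and scaled_dot_product_attention is the decomposed function defined above). Patching the functional entry point before tracing means the traced graph only contains matmul, softmax, masked_fill and similar ops that the converter already supports:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyAttention(nn.Module):
    # Hypothetical toy module standing in for a real model that uses SDPA
    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Replace the fused op with the decomposed reference implementation defined above
F.scaled_dot_product_attention = scaled_dot_product_attention

q = k = v = torch.randn(1, 4, 8, 16)
traced = torch.jit.trace(ToyAttention().eval(), (q, k, v))
print(traced.graph)  # the fused aten::scaled_dot_product_attention node should no longer appear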
In order to get it supported, we will need to support the new TorchScript schema first. Also, we should probably figure out a better way to reuse the conversion logic of the previously-supported ops.
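
For reference, the schema in question can be inspected with plain PyTorch (run in a fresh session, before any patching like the above; nothing converter-specific here):

import torch
import torch.nn.functional as F

def fn(q, k, v):
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

q = k = v = torch.randn(1, 4, 8, 16)
traced = torch.jit.trace(fn, (q, k, v))
# The printed graph should show the aten::scaled_dot_product_attention node whose
# inputs (query, key, value, attn_mask, dropout_p, is_causal, and possibly scale
# depending on the PyTorch version) are what a converter-side handler would need
# to map or decompose into already-supported ops.
print(traced.graph)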