
[converter] implement torch's `aten::scaled_dot_product_attention` operator

Open • mjamroz opened this issue on Oct 28, 2023 • 2 comments

Is there any chance of implementing torch's `aten::scaled_dot_product_attention`?

According to https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html, it could be implemented as:

# Efficient implementation equivalent to the following:
import math
import torch

def scaled_dot_product_attention(Q, K, V, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
    L, S = Q.size(-2), K.size(-2)
    scale_factor = 1 / math.sqrt(Q.size(-1)) if scale is None else scale
    attn_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) if is_causal else attn_mask
    # convert a boolean mask into an additive float mask (-inf at masked-out positions)
    if attn_mask is not None and attn_mask.dtype == torch.bool:
        attn_mask = torch.zeros(L, S, dtype=Q.dtype).masked_fill(~attn_mask, float('-inf'))
    attn_weight = torch.softmax(Q @ K.transpose(-2, -1) * scale_factor + (attn_mask if attn_mask is not None else 0), dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, True)  # train=True
    return attn_weight @ V

mjamroz avatar Oct 28 '23 09:10 mjamroz

Until it is implemented, I guess the easiest workaround is to replace F.scaled_dot_product_attention in your model with the equivalent implementation given above (see the sketch after this comment).

peterjc123 avatar Oct 28 '23 14:10 peterjc123
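As an illustration (not from the thread itself), a minimal sketch of that workaround, assuming the reference function quoted above is available as `scaled_dot_product_attention` and that the model resolves the call through the torch.nn.functional module at call time (i.e. it calls F.scaled_dot_product_attention(...)); code that imported the function directly beforehand keeps the original reference and is not affected:

import torch.nn.functional as F

# keep the original implementation so it can be restored after conversion
_original_sdpa = F.scaled_dot_product_attention

# monkey-patch the functional API with the pure-PyTorch reference implementation above,
# so tracing only records ops the converter already understands
F.scaled_dot_product_attention = scaled_dot_product_attention

# ... trace / convert the model with TinyNeuralNetwork here ...

F.scaled_dot_product_attention = _original_sdpa  # restore the fused kernel afterwards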

In order to get it supported, we will need to support the new TorchScript schema first. Also, we should probably figure out a better way to reuse the conversion logic of the previously supported ops.

peterjc123 avatar Oct 28 '23 14:10 peterjc123
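For reference (not part of the original discussion), a small sketch of how to see the TorchScript node the converter would need to handle; on recent PyTorch versions, tracing a module that calls F.scaled_dot_product_attention should produce an aten::scaled_dot_product_attention node in the graph:

import torch
import torch.nn.functional as F

class TinySDPA(torch.nn.Module):
    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v)

q = k = v = torch.randn(1, 4, 16, 8)
traced = torch.jit.trace(TinySDPA(), (q, k, v))
# the printed graph should contain an aten::scaled_dot_product_attention node,
# whose schema is what the converter needs to parse and lower
print(traced.graph)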