
[converter] implement torch's `aten::scaled_dot_product_attention` operator

mjamroz opened this issue on Oct 28, 2023

Is there any chance of implementing torch's `aten::scaled_dot_product_attention` operator in the converter?

The documentation at https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html shows that it can be decomposed into primitive ops as follows:

# Efficient implementation equivalent to the following:
import math
import torch

def scaled_dot_product_attention(Q, K, V, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
    L, S = Q.size(-2), K.size(-2)
    scale_factor = 1 / math.sqrt(Q.size(-1)) if scale is None else scale
    attn_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) if is_causal else attn_mask
    if attn_mask is not None and attn_mask.dtype == torch.bool:
        # convert the boolean mask into an additive float mask (-inf where masked out)
        attn_mask = torch.zeros(L, S, dtype=Q.dtype).masked_fill(attn_mask.logical_not(), float('-inf'))
    scores = Q @ K.transpose(-2, -1) * scale_factor
    attn_weight = torch.softmax(scores if attn_mask is None else scores + attn_mask, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ V
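
For reference, here is a minimal sanity check (my own sketch; the tensor shapes and tolerance are arbitrary) comparing the decomposition above against the built-in kernel:

import torch
import torch.nn.functional as F

# Hypothetical shapes: (batch, heads, seq_len, head_dim)
Q = torch.randn(2, 4, 16, 64)
K = torch.randn(2, 4, 16, 64)
V = torch.randn(2, 4, 16, 64)

expected = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
actual = scaled_dot_product_attention(Q, K, V, is_causal=True)  # decomposition above
print(torch.allclose(expected, actual, atol=1e-5))  # True when dropout_p == 0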
