Trans2D
question about code
In your paper this should be a scaled dot product, but in your code it is an element-wise multiplication:
full_attn = (query_layer.unsqueeze(-1).unsqueeze(-1) * key_layer.unsqueeze(2).unsqueeze(2).transpose(4, 6)).sum(4)
time_attn = (query_layer.unsqueeze(3) * key_layer.unsqueeze(2)).sum([4, 5]).unsqueeze(-2).unsqueeze(-2)
feature_attn = (query_layer.unsqueeze(4) * key_layer.unsqueeze(3)).sum([2, 5]).unsqueeze(2).unsqueeze(-1)
attention_scores = self.alpha[0] * full_attn + self.alpha[1] * feature_attn + self.alpha[2] * time_attn
So how should this code snippet be described?
This is a scaled dot product if you look carefully. We use broadcasting combined with summation to achieve the same effect as a dot product.
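To make the equivalence concrete, here is a minimal sketch (using NumPy for brevity; PyTorch broadcasting and summation behave the same way, and the shapes are illustrative rather than the exact Trans2D ones). Broadcasting a query against a key and summing over the head dimension produces exactly the matrix of dot products that `q @ k.T` would:

```python
import numpy as np

rng = np.random.default_rng(0)
q = rng.standard_normal((3, 8))  # 3 query vectors, head_dim = 8
k = rng.standard_normal((5, 8))  # 5 key vectors,   head_dim = 8

# Broadcast-multiply, then sum over the last (head_dim) axis:
# q[:, None, :] has shape (3, 1, 8), k[None, :, :] has shape (1, 5, 8);
# the product broadcasts to (3, 5, 8), and summing axis -1 gives (3, 5).
scores_broadcast = (q[:, None, :] * k[None, :, :]).sum(-1)

# The ordinary dot-product formulation gives the identical (3, 5) matrix.
scores_matmul = q @ k.T

assert np.allclose(scores_broadcast, scores_matmul)
```

The `unsqueeze(...)` calls in the code above play the role of the `None` indexing here: they insert the singleton axes that let `query_layer` and `key_layer` broadcast against each other before the `.sum(...)` collapses the shared dimension, so element-wise multiply plus sum is the dot product, just written out explicitly.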