Trans2D

question about code

Open bb67ao opened this issue 1 year ago • 1 comments

In your paper, this should be a scaled dot product, but your code uses element-wise multiplication:

full_attn = (query_layer.unsqueeze(-1).unsqueeze(-1) * key_layer.unsqueeze(2).unsqueeze(2).transpose(4, 6)).sum(4)
time_attn = (query_layer.unsqueeze(3) * key_layer.unsqueeze(2)).sum([4, 5]).unsqueeze(-2).unsqueeze(-2)
feature_attn = (query_layer.unsqueeze(4) * key_layer.unsqueeze(3)).sum([2, 5]).unsqueeze(2).unsqueeze(-1)
attention_scores = self.alpha[0] * full_attn + self.alpha[1] * feature_attn + self.alpha[2] * time_attn

So how should this code snippet be described?
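One possible reading, as a minimal sketch rather than the authors' explanation: an element-wise product followed by a sum over the hidden dimension is itself a dot product, so the three terms can be checked against `torch.einsum`. The tensor layout `(batch, heads, time, features, head_dim)` and the sizes below are assumptions, since the snippet does not state them. No `1/sqrt(d)` scaling appears in the lines quoted above; whether it is applied elsewhere in the code cannot be told from this excerpt.

```python
import torch

# Assumed layout (not stated in the issue): (batch, heads, time, features, head_dim).
B, H, T, F, D = 2, 3, 4, 5, 6
torch.manual_seed(0)
query_layer = torch.randn(B, H, T, F, D, dtype=torch.float64)
key_layer = torch.randn(B, H, T, F, D, dtype=torch.float64)

# The three expressions exactly as quoted in the issue.
full_attn = (query_layer.unsqueeze(-1).unsqueeze(-1)
             * key_layer.unsqueeze(2).unsqueeze(2).transpose(4, 6)).sum(4)
time_attn = (query_layer.unsqueeze(3) * key_layer.unsqueeze(2)).sum([4, 5]).unsqueeze(-2).unsqueeze(-2)
feature_attn = (query_layer.unsqueeze(4) * key_layer.unsqueeze(3)).sum([2, 5]).unsqueeze(2).unsqueeze(-1)

# Equivalent dot products written with einsum.
# t, f index the query's time/feature; s, g index the key's time/feature.
full_ref = torch.einsum("bhtfd,bhsgd->bhtfgs", query_layer, key_layer)
time_ref = torch.einsum("bhtfd,bhsfd->bhts", query_layer, key_layer)[:, :, :, None, None, :]
feature_ref = torch.einsum("bhtfd,bhtgd->bhfg", query_layer, key_layer)[:, :, None, :, :, None]

print(full_attn.shape)                            # (B, H, T, F, F, T)
print(torch.allclose(full_attn, full_ref))        # multiply + sum over d == dot product
print(torch.allclose(time_attn, time_ref))        # dot product over features and d
print(torch.allclose(feature_attn, feature_ref))  # dot product over time and d
```

Under these assumptions, `full_attn` scores every (time, feature) query position against every (time, feature) key position, `time_attn` compares whole time steps, and `feature_attn` compares whole feature columns; the last two broadcast against `full_attn` when they are summed into `attention_scores`.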

bb67ao, Nov 26 '22 14:11