FlowQA
[Flow operation] Equivalent code using permutation
I noticed that the flow operation uses several transposes rather than a single permutation. Is that for better performance?
I think the following code is equivalent to the original:
def flow_operation(cur_h, flow):
    # x1_full comes from the enclosing scope: [bsz, max_qa_pair, context_length]
    n, t, c = x1_full.size()
    # flow_in = cur_h.transpose(0, 1).view(c, n, t, -1)
    # flow_in = flow_in.transpose(0, 2).contiguous().view(t, n * c, -1).transpose(0, 1)
    flow_in = cur_h.view(n, t, c, -1).permute(0, 2, 1, 3).contiguous().view(n * c, t, -1)
    # [bsz * context_length, max_qa_pair, hidden_state]
    flow_out = flow(flow_in)
    # [bsz * context_length, max_qa_pair, flow_hidden_state_dim (hidden_state/2)]
    if self.opt['no_dialog_flow']:
        flow_out = flow_out * 0
    # flow_out = flow_out.transpose(0, 1).view(t, n, c, -1).transpose(0, 2).contiguous()
    # flow_out = flow_out.view(c, n * t, -1).transpose(0, 1)
    # [bsz * max_qa_pair, context_length, flow_hidden_state_dim]
    flow_out = flow_out.view(n, c, t, -1).permute(0, 2, 1, 3).contiguous().view(n * t, c, -1)
    return flow_out
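
To back up the equivalence claim, here is a small standalone check. It is only a sketch: the helper names (`flow_in_transpose`, `flow_in_permute`, and so on) and the tiny sizes are made up for illustration, and an identity function stands in for the real `flow` module, so it says nothing about speed.

```python
import torch

def flow_in_transpose(cur_h, n, t, c):
    # Transpose-based reshaping, as in the commented-out lines above.
    x = cur_h.transpose(0, 1).view(c, n, t, -1)
    return x.transpose(0, 2).contiguous().view(t, n * c, -1).transpose(0, 1)

def flow_in_permute(cur_h, n, t, c):
    # Single-permute version proposed in this issue.
    return cur_h.view(n, t, c, -1).permute(0, 2, 1, 3).contiguous().view(n * c, t, -1)

def flow_out_transpose(flow_out, n, t, c):
    x = flow_out.transpose(0, 1).view(t, n, c, -1).transpose(0, 2).contiguous()
    return x.view(c, n * t, -1).transpose(0, 1)

def flow_out_permute(flow_out, n, t, c):
    return flow_out.view(n, c, t, -1).permute(0, 2, 1, 3).contiguous().view(n * t, c, -1)

# Hypothetical small sizes: bsz, max_qa_pair, context_length, hidden_state.
n, t, c, h = 2, 3, 5, 4
# Distinct values per element, so any misplaced entry would be detected.
cur_h = torch.arange(n * t * c * h, dtype=torch.float32).view(n * t, c, h)

a = flow_in_transpose(cur_h, n, t, c)
b = flow_in_permute(cur_h, n, t, c)
in_equal = torch.equal(a, b)

# Identity stand-in for flow(flow_in); the real module is not modelled here.
flow_out = b
out_equal = torch.equal(
    flow_out_transpose(flow_out, n, t, c),
    flow_out_permute(flow_out, n, t, c),
)
# With an identity flow, the output reshaping should undo the input reshaping.
round_trip = torch.equal(flow_out_permute(flow_out, n, t, c), cur_h)

print(in_equal, out_equal, round_trip)
print(a.is_contiguous(), b.is_contiguous())
```

One difference this exposes: the transpose-based version ends with a `transpose`, so it returns a non-contiguous view, while the permute-based version returns a contiguous tensor. Whether that affects performance would depend on what `flow` does with its input, which this check does not measure.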