R-transformer
Implementation of LocalRNN?
It looks like you only use the last position along the ksize dimension, i.e. the RNN output at the final time step of each local window:
def forward(self, x):
    nbatches, l, input_dim = x.shape
    x = self.get_K(x)  # b x seq_len x ksize x d_model
    batch, l, ksize, d_model = x.shape
    # rnn input: (batch*l, ksize, d_model)
    # rnn returns (output, h_n); output: (batch*l, ksize, d_model)
    h = self.rnn(x.view(-1, self.ksize, d_model))[0][:, -1, :]
    return h.view(batch, l, d_model)
The line

    h = self.rnn(x.view(-1, self.ksize, d_model))[0][:, -1, :]

extracts only the last element of the RNN output, i.e. the last position along the ksize dimension. Have you considered an alternative, such as averaging the outputs over the window?
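To make the question concrete, here is a minimal sketch of the two options side by side. It is not the repository's code: `LocalRNNPool` and its `pool` argument are hypothetical names, the RNN is assumed to be a `nn.GRU` with `batch_first=True`, and the module takes the already-windowed tensor (what `get_K` returns in the snippet above) as input.

```python
import torch
import torch.nn as nn


class LocalRNNPool(nn.Module):
    """Hypothetical LocalRNN variant with a selectable window pooling mode."""

    def __init__(self, d_model, ksize, pool="last"):
        super().__init__()
        if pool not in ("last", "mean"):
            raise ValueError("pool must be 'last' or 'mean'")
        self.ksize = ksize
        self.pool = pool
        # Assumption: a GRU stands in for whichever RNN cell the model uses.
        self.rnn = nn.GRU(d_model, d_model, batch_first=True)

    def forward(self, windows):
        # windows: (batch, seq_len, ksize, d_model)
        batch, l, ksize, d_model = windows.shape
        # Run the RNN over every local window independently.
        out, _ = self.rnn(windows.reshape(-1, ksize, d_model))  # (batch*l, ksize, d_model)
        if self.pool == "last":
            h = out[:, -1, :]    # output at the final step of each window
        else:
            h = out.mean(dim=1)  # average over all ksize steps
        return h.reshape(batch, l, d_model)


# Usage: same weights, two pooling modes, same output shape.
torch.manual_seed(0)
windows = torch.randn(2, 5, 3, 8)  # batch=2, seq_len=5, ksize=3, d_model=8
last = LocalRNNPool(8, 3, pool="last")
mean = LocalRNNPool(8, 3, pool="mean")
mean.load_state_dict(last.state_dict())
print(last(windows).shape, mean(windows).shape)
```

One point to weigh when comparing them: the output at the last step has already seen every position in the window, whereas averaging also mixes in earlier outputs that saw only part of it. Whether that helps or hurts would need to be checked empirically.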