R-transformer

Implementation of LocalRNN?

Open · zjplab opened this issue · 0 comments

Looks like you only use the last position along the `self.ksize` dimension of the RNN output:

    def forward(self, x):
        nbatches, l, input_dim = x.shape
        x = self.get_K(x)  # (batch, seq_len, ksize, d_model)
        batch, l, ksize, d_model = x.shape
        # RNN input: (batch * seq_len, ksize, d_model)
        # self.rnn returns (output, h_n); output is (batch * seq_len, ksize, d_model)
        h = self.rnn(x.view(-1, self.ksize, d_model))[0][:, -1, :]
        return h.view(batch, l, d_model)
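To make the shapes in this forward pass concrete, here is a minimal NumPy sketch of what it computes. It assumes that `get_K` prepends `ksize - 1` zero vectors and slices one window per position; that is an assumption about the repo's `get_K`, not something checked against it. `get_windows` and `simple_rnn` are hypothetical stand-ins: a plain bias-free tanh RNN replaces `self.rnn`, and NumPy replaces PyTorch so the sketch runs on its own.

```python
import numpy as np


def get_windows(x: np.ndarray, ksize: int) -> np.ndarray:
    """Hypothetical stand-in for get_K: (batch, seq_len, d) -> (batch, seq_len, ksize, d).

    Assumption: ksize - 1 zero vectors are prepended, so window t holds
    x[:, t - ksize + 1 : t + 1], with zeros where the index would be negative.
    """
    batch, seq_len, d = x.shape
    pad = np.zeros((batch, ksize - 1, d), dtype=x.dtype)
    padded = np.concatenate([pad, x], axis=1)
    # Window t is padded[:, t : t + ksize]
    return np.stack([padded[:, t:t + ksize] for t in range(seq_len)], axis=1)


def simple_rnn(x: np.ndarray, w_in: np.ndarray, w_h: np.ndarray) -> np.ndarray:
    """Hypothetical bias-free Elman RNN, batch-first: (n, steps, d) -> outputs (n, steps, d)."""
    n, steps, d = x.shape
    h = np.zeros((n, d), dtype=x.dtype)
    outputs = []
    for s in range(steps):
        h = np.tanh(x[:, s] @ w_in + h @ w_h)
        outputs.append(h)
    return np.stack(outputs, axis=1)


def local_rnn_forward(x, ksize, w_in, w_h):
    batch, seq_len, d = x.shape
    windows = get_windows(x, ksize)       # (batch, seq_len, ksize, d)
    flat = windows.reshape(-1, ksize, d)  # (batch * seq_len, ksize, d)
    out = simple_rnn(flat, w_in, w_h)     # (batch * seq_len, ksize, d)
    h = out[:, -1, :]                     # keep only the last step of each window
    return h.reshape(batch, seq_len, d)


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 5, 4))
w_in = rng.standard_normal((4, 4)) * 0.5
w_h = rng.standard_normal((4, 4)) * 0.5
y = local_rnn_forward(x, 3, w_in, w_h)
print(y.shape)  # (2, 5, 4)
```

Because the RNN runs left to right inside each window, index `-1` is the only output that has seen all `ksize` inputs; each earlier output has seen only a prefix of the window.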
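The line `h = self.rnn(x.view(-1, self.ksize, d_model))[0][:,-1,:]` keeps only the RNN output at the last position of each window (index -1 along the `ksize` axis) and discards the other `ksize - 1` outputs. Have you considered an alternative, such as averaging over all of them?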
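As a sketch of the averaging idea, the hypothetical helper `pool_window_outputs` below takes the per-window RNN outputs (shape `(batch, seq_len, ksize, d_model)`, i.e. before the `[:, -1, :]` slice) and reduces them in three ways. The `masked_mean` mode is only an illustration of one possible refinement; it again assumes `ksize - 1` zero-padding steps at the front of the sequence and skips them, so the first few positions are not averaged over padding.

```python
import numpy as np


def pool_window_outputs(out: np.ndarray, ksize_mode: str = "last") -> np.ndarray:
    """Reduce RNN outputs over the window axis: (batch, seq_len, ksize, d) -> (batch, seq_len, d).

    "last":        out[:, :, -1, :], the same readout as the current forward()
    "mean":        plain average over all ksize steps
    "masked_mean": average only over steps that read real tokens, assuming
                   ksize - 1 zero-padding steps are prepended (hypothetical)
    """
    batch, seq_len, ksize, d = out.shape
    if ksize_mode == "last":
        return out[:, :, -1, :]
    if ksize_mode == "mean":
        return out.mean(axis=2)
    if ksize_mode == "masked_mean":
        t = np.arange(seq_len)[:, None]  # position in the sequence
        k = np.arange(ksize)[None, :]    # step inside the window
        # Step k of window t reads x[t - ksize + 1 + k]; it is padding if that index < 0
        valid = (t - ksize + 1 + k) >= 0  # (seq_len, ksize)
        mask = valid[None, :, :, None].astype(out.dtype)
        return (out * mask).sum(axis=2) / mask.sum(axis=2)
    raise ValueError(f"unknown mode: {ksize_mode}")


# Toy outputs: batch=1, seq_len=3, ksize=2, d=1
outputs = np.arange(6, dtype=float).reshape(1, 3, 2, 1)
print(pool_window_outputs(outputs, "last").ravel().tolist())         # [1.0, 3.0, 5.0]
print(pool_window_outputs(outputs, "mean").ravel().tolist())         # [0.5, 2.5, 4.5]
print(pool_window_outputs(outputs, "masked_mean").ravel().tolist())  # [1.0, 2.5, 4.5]
```

One trade-off to keep in mind: with a left-to-right RNN, the first token of a window feeds into all `ksize` outputs while the last token feeds into only one, so a plain mean roughly gives early tokens in the window more weight than the last-step readout does.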

Posted by zjplab, Jan 20 '20 16:01