GraphRNN

Why does the model call pack_padded_sequence after the linear transform?

omega0x16 opened this issue 2 years ago • 0 comments

Hello @JiaxuanYou. I have a question about model.py.

In the forward pass of GRU_plain, why is pack_padded_sequence applied after the linear transform by self.input?

Don't the non-zero rows pick up information from the padded zeros?

Thanks!

import torch
import torch.nn as nn
from torch.nn import init
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class GRU_plain(nn.Module):
    def __init__(self, input_size, embedding_size, hidden_size, num_layers, has_input=True, has_output=False, output_size=None):
        super(GRU_plain, self).__init__()
        self.num_layers = num_layers 
        self.hidden_size = hidden_size 
        self.has_input = has_input 
        self.has_output = has_output
        if has_input:
            self.input = nn.Linear(input_size, embedding_size) 
            self.rnn = nn.GRU(input_size=embedding_size, hidden_size=hidden_size, num_layers=num_layers,
                              batch_first=True) 
        else:
            self.rnn = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
        if has_output:
            self.output = nn.Sequential(
                nn.Linear(hidden_size, embedding_size),
                nn.ReLU(),
                nn.Linear(embedding_size, output_size)
            )
    
        self.relu = nn.ReLU()
        # initialize
        self.hidden = None  # need initialize before forward run
    
        for name, param in self.rnn.named_parameters():
            if 'bias' in name:
                nn.init.constant(param, 0.25)
            elif 'weight' in name:
                nn.init.xavier_uniform(param,gain=nn.init.calculate_gain('sigmoid'))
        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.weight.data = init.xavier_uniform(m.weight.data, gain=nn.init.calculate_gain('relu'))
    
    def init_hidden(self, batch_size): 
        return Variable(torch.zeros(self.num_layers, batch_size, self.hidden_size)).cuda()
    
    def forward(self, input_raw, pack=False, input_len=None):
        if self.has_input:
            input = self.input(input_raw) 
            input = self.relu(input)
        else:
            input = input_raw
        if pack:
            input = pack_padded_sequence(input, input_len, batch_first=True) 
        output_raw, self.hidden = self.rnn(input, self.hidden)
        if pack:
            output_raw = pad_packed_sequence(output_raw, batch_first=True)[0]
        if self.has_output:
            output_raw = self.output(output_raw)
        # return hidden state at each time step
        return output_raw
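
To make the question concrete, here is a minimal standalone sketch of the two steps involved (the layer sizes, batch, and lengths below are made up for illustration): self.input is an nn.Linear, which transforms every time step independently, and pack_padded_sequence afterwards drops the padded time steps before the GRU sees them.

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

torch.manual_seed(0)
linear = nn.Linear(4, 8)

# batch of 2 sequences padded to length 3; the second has true length 2
x = torch.randn(2, 3, 4)
x[1, 2] = 0.0                        # padded time step (all zeros)
lengths = torch.tensor([3, 2])

# nn.Linear acts on each time step independently, so the real rows of a
# sequence come out the same whether or not the padded row is present
y = linear(x)
y_real = linear(x[1, :2])            # transform only the real rows of sequence 1
print(torch.allclose(y[1, :2], y_real))   # True

# packing afterwards removes the padded time steps before the GRU
packed = pack_padded_sequence(y, lengths, batch_first=True)
print(packed.data.shape)                  # torch.Size([5, 8]) -> 3 + 2 real steps

Even though the padded all-zero rows do pass through the linear layer (they come out as the bias vector), they only occupy the padded positions, and those are exactly the positions that packing removes.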

omega0x16 · Feb 08 '23 19:02