pytorch-MemNet

The parameter count is much larger than in the official version, why?

Open Achhhe opened this issue 6 years ago • 4 comments

Are the recursive_block parameters not shared?

Achhhe, Aug 20 '19 03:08

Yes, they are not shared, but it would be easy to modify the code into the shared version.

Vandermode, Aug 20 '19 09:08
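For context, here is a minimal sketch of why the unshared version has many more parameters (the `ResidualBlock` below is a simplified stand-in for the repo's residual unit, only meant to compare parameter counts):

    import torch.nn as nn

    # Simplified stand-in for the repo's ResidualBlock, just for counting parameters.
    class ResidualBlock(nn.Module):
        def __init__(self, channels):
            super(ResidualBlock, self).__init__()
            self.conv = nn.Conv2d(channels, channels, 3, padding=1)

        def forward(self, x):
            return x + self.conv(x)

    channels, num_resblock = 64, 6

    # Unshared (as in this repo): one ResidualBlock per recursion step,
    # i.e. num_resblock independent sets of weights per memory block.
    unshared = nn.ModuleList(ResidualBlock(channels) for _ in range(num_resblock))

    # Shared (recursive, as in the paper): a single ResidualBlock applied
    # num_resblock times, so its weights are stored only once.
    shared = ResidualBlock(channels)

    count = lambda m: sum(p.numel() for p in m.parameters())
    print(count(unshared) // count(shared))  # prints 6: num_resblock times more parameters

Only the residual units differ between the two variants; the gate units (and the rest of MemNet) are identical, so the overall parameter gap of the full model is smaller than this per-block ratio.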

> Yes, they are not shared, but it would be easy to modify the code into the shared version.

This is my modified version; I'm not sure if it's correct:

    import torch
    import torch.nn as nn

    # ResidualBlock and BNReLUConv are the building blocks defined elsewhere in this repo.

    class MemoryBlock(nn.Module):
        """Note: num_memblock denotes the number of MemoryBlocks so far (the index of this block)."""
        def __init__(self, channels, num_resblock, num_memblock):
            super(MemoryBlock, self).__init__()
            # Original (unshared) version: one ResidualBlock per recursion step.
            # self.recursive_unit = nn.ModuleList(
            #     [ResidualBlock(channels) for i in range(num_resblock)]
            # )
            self.num_resblock = num_resblock
            # Shared version: a single ResidualBlock reused at every recursion step.
            self.recur_block = ResidualBlock(channels)
            self.gate_unit = BNReLUConv((num_resblock + num_memblock) * channels, channels, 1, 1, 0)

        def forward(self, x, ys):
            """ys is a list which contains the long-term memory coming from previous memory blocks;
            xs denotes the short-term memory coming from the recursive unit.
            """
            xs = []
            residual = x
            # Original (unshared) version:
            # for layer in self.recursive_unit:
            #     x = layer(x)
            #     xs.append(x)
            for i in range(self.num_resblock):
                x = self.recur_block(x)
                xs.append(x)
            gate_out = self.gate_unit(torch.cat(xs + ys, 1))
            ys.append(gate_out)
            return gate_out

Achhhe, Aug 20 '19 11:08
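As a quick sanity check of the block above (a sketch, not from the repo; it assumes the repo's ResidualBlock and BNReLUConv are in scope, and the sizes are only illustrative):

    import torch

    # First memory block: the long-term memory list holds a single tensor (the block
    # input), matching the gate unit's (num_resblock + num_memblock) * channels width.
    block = MemoryBlock(channels=64, num_resblock=6, num_memblock=1)

    x = torch.randn(1, 64, 32, 32)
    ys = [x]                  # long-term memory list
    out = block(x, ys)        # every recursion step reuses block.recur_block
    print(out.shape)          # torch.Size([1, 64, 32, 32])

    # The parameter count no longer grows with num_resblock, since only one
    # ResidualBlock's weights are stored.
    print(sum(p.numel() for p in block.parameters()))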

The implementation seems to be correct. However, to my knowledge, this parameter-shared version might lead to a performance decline.

Vandermode, Aug 21 '19 09:08

> The implementation seems to be correct. However, to my knowledge, this parameter-shared version might lead to a performance decline.

However, if it is not done like this, what is the meaning of the recursive unit in the paper?

fantasysponge, Sep 23 '20 06:09