Loading Trained Weights to is_senet=True Model

Open szwiep opened this issue 2 years ago • 2 comments

Good afternoon! Thanks for sharing such exciting work. I've been trying to load in your trained weights to a model defined by the following (is_senet=True):

original_model = senet154(pretrained='imagenet')
Encoder = E_senet(original_model)
model = net_model(Encoder, num_features=2048, block_channel=[256, 512, 1024, 2048])

But load_state_dict() is failing on some unexpected keys:

RuntimeError: Error(s) in loading state_dict for net_model:
	Unexpected key(s) in state_dict: "E.Harm.dct", "E.Harm.weight", "E.Harm.bias".

where net_model is simply a renamed model class from your modules.py file:

class net_model(nn.Module):
    def __init__(self, Encoder, num_features, block_channel):
        super(net_model, self).__init__()

        self.E = Encoder
        self.D2 = D2(num_features=num_features)
        self.MFF = MFF(block_channel)
        self.R = R(block_channel)

    def forward(self, x):
        x_block0, x_block1, x_block2, x_block3, x_block4 = self.E(x)

        x = x_block0.view(-1, 250, 250)
        x = x.cpu().detach().numpy()

        #for idx in range(0,len(x)):
        #    x[idx] = x[idx]*100000
        #    np.clip(x[idx], 0, 50000).astype(np.uint16)
        #    filename = str(idx)+'.png'
        #    cv2.imwrite(filename, x[idx])

        x_decoder = self.D2(x_block0, x_block1, x_block2, x_block3, x_block4)
        x_mff = self.MFF(x_block0, x_block1, x_block2, x_block3, x_block4, [x_decoder.size(2), x_decoder.size(3)])

        out = self.R(torch.cat((x_decoder, x_mff), 1))
        return out

For simplicity, here is the E_senet definition too:

class E_senet(nn.Module):

    def __init__(self, original_model, num_features = 2048):
        super(E_senet, self).__init__()

        self.base = nn.Sequential(*list(original_model.children())[:-3])

        #self.conv = nn.Conv2d(3, 64 , kernel_size=5, stride=1, bias=False)
        #self.bn = nn.BatchNorm2d(64)
      
        self.pool = nn.MaxPool2d(3, stride=2,ceil_mode=True)
        self.down = _UpProjection(64,128)

    def forward(self, x):
        #conv_x = F.relu(self.conv(x))
        #conv_x = self.bn(conv_x) 

        #summary(self.base, input_size=(3, 440, 440))
        x_block0 = self.base[0][0:6](x)
        x = self.base[0][6:](x_block0)
        

        # x = self.Harm(x)
        # x = self.pool(x)
        # x = self.down(x,(110,110))

        x_block1 = self.base[1](x)
        x_block2 = self.base[2](x_block1)
        x_block3 = self.base[3](x_block2)
        x_block4 = self.base[4](x_block3)
        return x_block0,  x_block1, x_block2, x_block3, x_block4 

Because of the error above, I'm assuming the problem is in how I've defined the model: the checkpoint contains E.Harm.* keys, but the E_senet encoder has no Harm attribute (the self.Harm call in forward is commented out). That's as far as I've gotten. Am I defining the model incorrectly for the trained weights? Any help or direction would be appreciated, and thanks for your time!
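
To see exactly which keys disagree, one option is to compare the two key sets directly. The snippet below is only a sketch in plain Python: diff_state_dict_keys is a hypothetical helper, and the key names are illustrative placeholders rather than the real contents of the checkpoint.

```python
def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, similar to what a strict load checks."""
    model_keys = set(model_keys)
    checkpoint_keys = set(checkpoint_keys)
    # Keys the model expects but the checkpoint does not provide.
    missing = sorted(model_keys - checkpoint_keys)
    # Keys the checkpoint provides but the model has no parameter or buffer for.
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


# Placeholder key lists standing in for model.state_dict().keys()
# and the keys of the loaded checkpoint (assumed names, for illustration only).
model_keys = ["E.layer_a.weight", "R.layer_b.weight"]
checkpoint_keys = [
    "E.layer_a.weight",
    "R.layer_b.weight",
    "E.Harm.dct",
    "E.Harm.weight",
    "E.Harm.bias",
]

missing, unexpected = diff_state_dict_keys(model_keys, checkpoint_keys)
print("missing:", missing)
print("unexpected:", unexpected)
```

With a real model, the first argument would be model.state_dict().keys() and the second the keys of the state dict stored in the checkpoint file. An empty "missing" list together with a few "unexpected" entries suggests the checkpoint was saved from a model that had extra layers.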

szwiep, Mar 22 '22 21:03

I had this same issue. Apparently it happens because the model was saved using nn.DataParallel. You can work around it by loading the model with strict=False, like this:

original_model = senet.senet154(pretrained='imagenet')
Encoder = modules.E_senet(original_model)
model = net.model(Encoder, num_features=2048, block_channel = [256, 512, 1024, 2048])

MODEL_PATH = "./models/Block0_skip_model_110.pth.tar"
model.load_state_dict(torch.load(MODEL_PATH)['state_dict'], strict=False)
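
If the checkpoint really was saved from an nn.DataParallel wrapper, its keys would normally carry a "module." prefix, since the wrapper stores the original model under .module. A minimal sketch for stripping such a prefix before loading is shown here; strip_prefix is a hypothetical helper, and the toy dict uses plain integers in place of tensors.

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with `prefix` removed from keys that start with it."""
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        # Only touch keys that actually begin with the prefix; leave the rest as-is.
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        cleaned[new_key] = value
    return cleaned


# Toy checkpoint with made-up keys; integers stand in for tensors.
checkpoint = OrderedDict(
    [("module.E.Harm.weight", 1), ("module.R.bias", 2), ("step", 3)]
)
cleaned = strip_prefix(checkpoint)
print(list(cleaned.keys()))
```

Note that strict=False skips mismatched keys without raising an error, so it is worth inspecting the value returned by load_state_dict: it exposes missing_keys and unexpected_keys, which show what was actually left out.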

Have you managed to train any effective model using the network in the repo? No matter how much I try, I cannot get the network to learn properly.

TorsteinOtterlei, Apr 17 '22 19:04

Have you been able to run the code?

Arham222, Nov 24 '22 17:11