IMELE
Loading Trained Weights to is_senet=True Model
Good afternoon! Thanks for sharing such exciting work. I've been trying to load your trained weights into a model defined by the following (is_senet=True):
original_model = senet154(pretrained='imagenet')
Encoder = E_senet(original_model)
model = net_model(Encoder, num_features=2048, block_channel = [256, 512, 1024, 2048])
But load_state_dict() is failing on some unexpected keys:
RuntimeError: Error(s) in loading state_dict for net_model:
Unexpected key(s) in state_dict: "E.Harm.dct", "E.Harm.weight", "E.Harm.bias".
where net_model is simply a renamed model class from your modules.py file:
class net_model(nn.Module):
    def __init__(self, Encoder, num_features, block_channel):
        super(net_model, self).__init__()
        self.E = Encoder
        self.D2 = D2(num_features = num_features)
        self.MFF = MFF(block_channel)
        self.R = R(block_channel)

    def forward(self, x):
        x_block0, x_block1, x_block2, x_block3, x_block4 = self.E(x)
        x = x_block0.view(-1,250,250)
        x = x.cpu().detach().numpy()
        #for idx in range(0,len(x)):
        #    x[idx] = x[idx]*100000
        #    np.clip(x[idx], 0, 50000).astype(np.uint16)
        #    filename = str(idx)+'.png'
        #    cv2.imwrite(filename, x[idx])
        x_decoder = self.D2(x_block0, x_block1, x_block2, x_block3, x_block4)
        x_mff = self.MFF(x_block0, x_block1, x_block2, x_block3, x_block4,[x_decoder.size(2),x_decoder.size(3)])
        out = self.R(torch.cat((x_decoder, x_mff), 1))
        return out
For simplicity, here is the E_senet definition too:
class E_senet(nn.Module):
    def __init__(self, original_model, num_features = 2048):
        super(E_senet, self).__init__()
        self.base = nn.Sequential(*list(original_model.children())[:-3])
        #self.conv = nn.Conv2d(3, 64 , kernel_size=5, stride=1, bias=False)
        #self.bn = nn.BatchNorm2d(64)
        self.pool = nn.MaxPool2d(3, stride=2,ceil_mode=True)
        self.down = _UpProjection(64,128)

    def forward(self, x):
        #conv_x = F.relu(self.conv(x))
        #conv_x = self.bn(conv_x)
        #summary(self.base, input_size=(3, 440, 440))
        x_block0 = self.base[0][0:6](x)
        x = self.base[0][6:](x_block0)
        # x = self.Harm(x)
        # x = self.pool(x)
        # x = self.down(x,(110,110))
        x_block1 = self.base[1](x)
        x_block2 = self.base[2](x_block1)
        x_block3 = self.base[3](x_block2)
        x_block4 = self.base[4](x_block3)
        return x_block0, x_block1, x_block2, x_block3, x_block4
Because of the error listed above, I'm assuming there's a problem with how I've defined model, because the E_senet encoder has no Harm attribute. But that's as far as I've gotten. So, am I defining the model incorrectly for the trained weights? Any help/direction would be appreciated, and thanks for your time!
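For anyone hitting the same error, one way to see which keys differ before calling load_state_dict() is to compare the two key sets directly. The following is only a minimal sketch: diff_keys is a hypothetical helper, and the two "layer" key names are made up for illustration. Only the three E.Harm.* keys come from the error message above.

```python
def diff_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, both sorted alphabetically."""
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)     # in the model, not in the file
    unexpected = sorted(checkpoint_keys - model_keys)  # in the file, not in the model
    return missing, unexpected


# Hypothetical key names for illustration; in practice one would pass
# model.state_dict().keys() and the keys of the checkpoint's state dict.
model_keys = ["E.layer.weight", "R.layer.weight"]
checkpoint_keys = model_keys + ["E.Harm.dct", "E.Harm.weight", "E.Harm.bias"]

missing, unexpected = diff_keys(model_keys, checkpoint_keys)
print("missing:", missing)
print("unexpected:", unexpected)
```

If the unexpected keys all share one prefix, as E.Harm does here, that usually points to a submodule that existed when the checkpoint was saved but is absent from the model being loaded into.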
I had this same issue. Apparently it happens because the model was saved using the nn.DataParallel function. You can solve this by loading the model with strict=False like this:
original_model = senet.senet154(pretrained='imagenet')
Encoder = modules.E_senet(original_model)
model = net.model(Encoder, num_features=2048, block_channel = [256, 512, 1024, 2048])
MODEL_PATH = "./models/Block0_skip_model_110.pth.tar"
model.load_state_dict(torch.load(MODEL_PATH)['state_dict'], strict=False)
Have you managed to train any effective model using the network in the repo? No matter how much I try, I cannot get the network to learn properly.
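A note on strict=False: it skips mismatched keys without raising, so it can help to check what was actually skipped. load_state_dict() returns an object with missing_keys and unexpected_keys fields. The sketch below uses a tiny stand-in network rather than the real one. TinyNet, load_and_report, and the Harm.weight entry are assumptions for illustration only. It also strips the "module." prefix that nn.DataParallel adds to every key, in case the checkpoint was saved from a wrapped model.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the real network, for illustration only."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)


def load_and_report(model, state_dict):
    """Load non-strictly and return the keys that did not line up."""
    # Keys saved from an nn.DataParallel wrapper start with "module.".
    cleaned = {
        (k[len("module."):] if k.startswith("module.") else k): v
        for k, v in state_dict.items()
    }
    result = model.load_state_dict(cleaned, strict=False)
    return list(result.missing_keys), list(result.unexpected_keys)


model = TinyNet()
# Simulate a DataParallel checkpoint that carries one extra parameter.
checkpoint = {"module." + k: v.clone() for k, v in model.state_dict().items()}
checkpoint["module.Harm.weight"] = torch.zeros(1)  # not present in TinyNet

missing, unexpected = load_and_report(model, checkpoint)
print("missing:", missing)
print("unexpected:", unexpected)
```

An empty missing list means every parameter of the model received a value. Anything in the unexpected list was present in the file and simply not used.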
Have you been able to run the code?