MedicalNet
Is there any classification code?
What I'm doing is replacing the FC layers with my own classification layers, although the performance was not good. Happy to discuss more if you are interested.
I used this:
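import torch
from torch import nn

# resnet34 comes from MedicalNet's models/resnet.py (assumes running from the repository root)
from models.resnet import resnet34


class MedicalNet(nn.Module):
    def __init__(self, path_to_weights, device):
        super(MedicalNet, self).__init__()
        # 3D ResNet-34 backbone; its segmentation head (conv_seg) is replaced below
        self.model = resnet34(sample_input_D=1, sample_input_H=112, sample_input_W=112, num_seg_classes=2)
        # Global max pooling + flatten turns the last feature map into an (N, 512) vector
        self.model.conv_seg = nn.Sequential(
            nn.AdaptiveMaxPool3d(output_size=(1, 1, 1)),
            nn.Flatten(start_dim=1),
            nn.Dropout(0.1)
        )
        # Load pretrained weights: strip the "module." prefix left by DataParallel
        # and keep only keys that exist in the modified model
        net_dict = self.model.state_dict()
        pretrained_weights = torch.load(path_to_weights, map_location=torch.device(device))
        pretrain_dict = {
            k.replace("module.", ""): v for k, v in pretrained_weights['state_dict'].items() if k.replace("module.", "") in net_dict.keys()
        }
        net_dict.update(pretrain_dict)
        self.model.load_state_dict(net_dict)
        # New classification head: ResNet-34's last block outputs 512 channels; one logit out
        self.fc = nn.Linear(512, 1)

    def forward(self, x):
        features = self.model(x)
        return self.fc(features)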
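The dict comprehension that builds `pretrain_dict` keeps only keys that also exist in the new model, so any mismatch is dropped silently and the affected backbone layers keep their random initialization. Below is a minimal sketch, assuming the checkpoint layout used above (a `'state_dict'` entry whose keys may carry the `module.` prefix that `nn.DataParallel` adds), that performs the same filtering but also reports what was skipped. The function name `filter_pretrained` is hypothetical, and plain dicts stand in for real tensors so the logic can be checked without a checkpoint. `str.removeprefix` needs Python 3.9+.

```python
def filter_pretrained(pretrained_state, model_state):
    """Match checkpoint keys to model keys after stripping a leading "module." prefix.

    Returns (loadable, missing, unexpected):
      loadable   -- {key: value} entries found in both, ready for state_dict.update()
      missing    -- model keys with no pretrained value (they keep their initialization)
      unexpected -- checkpoint keys that do not exist in the model
    """
    loadable, unexpected = {}, []
    for key, value in pretrained_state.items():
        # Strip only a leading prefix; str.replace would also hit "module." mid-key.
        name = key.removeprefix("module.")
        if name in model_state:
            loadable[name] = value
        else:
            unexpected.append(name)
    missing = [k for k in model_state if k not in loadable]
    return loadable, missing, unexpected


# Toy stand-ins for real state dicts (values would be tensors in practice).
checkpoint = {
    "module.conv1.weight": "w1",
    "module.layer1.0.conv1.weight": "w2",
    "module.conv_seg.0.weight": "w3",  # old segmentation head, absent from the new model
}
model_keys = {
    "conv1.weight": None,
    "layer1.0.conv1.weight": None,
    "layer4.0.downsample.0.weight": None,  # toy key with no pretrained counterpart
}

loadable, missing, unexpected = filter_pretrained(checkpoint, model_keys)
print(sorted(loadable))  # ['conv1.weight', 'layer1.0.conv1.weight']
print(missing)           # ['layer4.0.downsample.0.weight']
print(unexpected)        # ['conv_seg.0.weight']
```

Inside `MedicalNet.__init__`, the returned `loadable` dict could take the place of `pretrain_dict`, and printing `missing` shows which backbone layers were not initialized from the checkpoint.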
Then:
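device = "cuda" if torch.cuda.is_available() else "cpu"
model = MedicalNet(path_to_weights="pretrain/resnet_34.pth", device=device)
# Parameter names carry the attribute path ("model.conv1.weight", "fc.weight", ...) and the
# new conv_seg has no parameters, so match the "fc." prefix to train only the new head
for param_name, param in model.named_parameters():
    if param_name.startswith("fc."):
        param.requires_grad = True
    else:
        param.requires_grad = False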
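Names returned by `named_parameters()` are prefixed with the attribute path inside the wrapper (`model.conv1.weight`, `fc.weight`, ...), and the replacement `conv_seg` (pooling, flatten, dropout) holds no parameters, so it is worth listing what is actually trainable after freezing. A minimal sketch, using a tiny hypothetical stand-in backbone (`TinyBackbone`, not MedicalNet's ResNet) so it runs without the pretrained weights:

```python
import torch
from torch import nn


class TinyBackbone(nn.Module):
    """Hypothetical stand-in for MedicalNet's ResNet: one 3D conv plus the replaced conv_seg."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv3d(1, 8, kernel_size=3, padding=1)
        # Same kind of replacement head: pooling, flatten and dropout have no parameters.
        self.conv_seg = nn.Sequential(
            nn.AdaptiveMaxPool3d(output_size=(1, 1, 1)),
            nn.Flatten(start_dim=1),
            nn.Dropout(0.1),
        )

    def forward(self, x):
        return self.conv_seg(self.conv1(x))


class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = TinyBackbone()
        self.fc = nn.Linear(8, 1)

    def forward(self, x):
        return self.fc(self.model(x))


net = Wrapper()
print([name for name, _ in net.named_parameters()])
# ['model.conv1.weight', 'model.conv1.bias', 'fc.weight', 'fc.bias']

# Freeze everything except the new linear head.
for name, param in net.named_parameters():
    param.requires_grad = name.startswith("fc.")

trainable = [name for name, p in net.named_parameters() if p.requires_grad]
print(trainable)  # ['fc.weight', 'fc.bias']

# Hand only the trainable parameters to the optimizer.
optimizer = torch.optim.Adam([p for p in net.parameters() if p.requires_grad], lr=1e-3)
```

If `trainable` comes back empty, `loss.backward()` fails because no tensor in the graph requires a gradient, which makes this check a cheap sanity test before training.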
Hi @JasperHG90, do you have the code for training the classifier? In train.py, some parts are tied to segmentation, for example the masks.
@Batush123 I'm not entirely sure what you're asking for. Are you asking me what my input data & training loop look like?
Hi @JasperHG90, I am looking at a similar problem and I would be glad if you could share your code including the data/training loop, if possible. Thanks!
Hello @JasperHG90, same here. Would you be able to share your data/training loop? Thank you very much :)
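MedicalNet's `train.py` is built around segmentation masks, but with a classification head like the one above, a training step only needs volumes and labels. A minimal sketch, assuming binary labels, inputs shaped `(N, 1, D, H, W)` as the 3D ResNet expects, and a single-logit head like `nn.Linear(512, 1)`, so `nn.BCEWithLogitsLoss` is used. The `VolumeDataset` class, the z-score normalization, the tiny stand-in network and the random tensors are all hypothetical placeholders; this is not code from the MedicalNet repository or from @JasperHG90.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset


class VolumeDataset(Dataset):
    """Hypothetical dataset returning (volume, label) pairs; no masks involved."""

    def __init__(self, volumes, labels):
        self.volumes = volumes  # list of tensors shaped (D, H, W)
        self.labels = labels    # list of 0/1 ints

    def __len__(self):
        return len(self.volumes)

    def __getitem__(self, idx):
        vol = self.volumes[idx].float()
        vol = (vol - vol.mean()) / (vol.std() + 1e-8)  # simple z-score normalization (assumption)
        # Add the channel dim -> (1, D, H, W); the DataLoader adds the batch dim.
        return vol.unsqueeze(0), torch.tensor(self.labels[idx], dtype=torch.float32)


def train_one_epoch(model, loader, optimizer, loss_fn, device):
    """Run one pass over the loader and return the mean loss per sample."""
    model.train()
    total = 0.0
    for volumes, labels in loader:
        volumes, labels = volumes.to(device), labels.to(device)
        optimizer.zero_grad()
        logits = model(volumes).squeeze(1)  # (N, 1) -> (N,) to match labels
        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()
        total += loss.item() * volumes.size(0)
    return total / len(loader.dataset)


torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
# Random stand-in data: 8 volumes of 16x32x32 voxels with alternating binary labels.
data = VolumeDataset([torch.randn(16, 32, 32) for _ in range(8)], [i % 2 for i in range(8)])
loader = DataLoader(data, batch_size=4, shuffle=True)
# Tiny stand-in network with the same interface: (N, 1, D, H, W) -> (N, 1) logits.
# In practice this would be the MedicalNet wrapper with its frozen backbone.
model = nn.Sequential(
    nn.Conv3d(1, 8, kernel_size=3, padding=1),
    nn.AdaptiveMaxPool3d(1),
    nn.Flatten(1),
    nn.Linear(8, 1),
).to(device)
optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=1e-3)
loss_fn = nn.BCEWithLogitsLoss()
for epoch in range(2):
    print(epoch, train_one_epoch(model, loader, optimizer, loss_fn, device))
```

Compared with the segmentation loop, the mask tensors and the per-voxel loss are replaced by one label per volume and a per-sample loss on the head's logit.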