torchxrayvision
transfer learning for resnet50-res512-all
Great library! Could you provide the transfer learning code for resnet50-res512-all as well? Thank you so much!
It should be as simple as using this line:
model = xrv.models.ResNet(weights="resnet50-res512-all")
in this script: https://github.com/mlmed/torchxrayvision/blob/master/scripts/transfer_learning.ipynb
Also, change the resizing to xrv.datasets.XRayResizer(512) so that the images are 512x512.
Hi Joseph,
Thanks for the response! I used DenseNet121 (224x224) for transfer learning and it worked great. I then tried ResNet50 and replaced the fc layers for binary classification as below:
model = xrv.models.ResNet(weights="resnet50-res512-all")
model.fc = nn.Sequential(
    nn.Linear(2048, 128),
    nn.ReLU(inplace=True),
    nn.Linear(128, 1))
But the script gets stuck at the training step: outputs = model(inputs)
The output shape is (32, 18). I know 32 is the batch size, but I don't know where 18 comes from. Shouldn't it be 1 instead?
It seems to me the ResNet setup is quite different from the DenseNet one. I am quite new to this and hope to get the ResNet working. Thank you for helping out!
Yue
In reply to Joseph Paul Cohen's comment of Mon, Apr 4, 2022, 4:22 PM: https://github.com/mlmed/torchxrayvision/issues/92#issuecomment-1088026128
Oh, sorry, I responded too fast and didn't test the code. The ResNet wrapper loads an internal ResNet model, so the fc layer is located at model.model.fc, not model.fc.
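A toy sketch of why the first attempt still produced (32, 18) outputs, assuming only what the reply above states: the wrapper keeps the real network in model.model and its forward pass calls that inner network. The classes Wrapper and InnerNet are hypothetical stand-ins, not torchxrayvision's implementation. The 18 is the width of the pretrained head, presumably one output per pathology label the weights were trained on. Assigning wrapper.fc only registers a new, unused submodule; assigning wrapper.model.fc replaces the head that forward() actually calls.

```python
import torch
import torch.nn as nn


class InnerNet(nn.Module):
    """Hypothetical stand-in for the inner backbone: features -> fc head."""

    def __init__(self, n_features: int = 2048, n_outputs: int = 18):
        super().__init__()
        self.features = nn.Flatten()  # toy "feature extractor" (inputs are already 2048-dim)
        self.fc = nn.Linear(n_features, n_outputs)  # pretrained-style 18-output head

    def forward(self, x):
        return self.fc(self.features(x))


class Wrapper(nn.Module):
    """Hypothetical wrapper that delegates everything to self.model."""

    def __init__(self):
        super().__init__()
        self.model = InnerNet()

    def forward(self, x):
        return self.model(x)  # never looks at self.fc


x = torch.randn(32, 2048)  # toy batch: 32 samples of 2048-dim features

wrapper = Wrapper()
wrapper.fc = nn.Linear(2048, 1)  # registered on the wrapper, but unused by forward()
shape_wrong = tuple(wrapper(x).shape)

wrapper.model.fc = nn.Linear(2048, 1)  # replaces the head that forward() really uses
shape_right = tuple(wrapper(x).shape)

print(shape_wrong, shape_right)  # (32, 18) (32, 1)
```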
import torch
import torchxrayvision as xrv

model = xrv.models.ResNet(weights="resnet50-res512-all")
model.op_threshs = None  # prevent pre-trained model calibration
model.model.fc = torch.nn.Linear(2048, 1)  # reinitialize classifier (single binary output)
optimizer = torch.optim.Adam(model.model.fc.parameters())  # only train classifier
criterion = torch.nn.BCEWithLogitsLoss()
I tested the above code and it seems to train correctly.
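A self-contained sketch of one training step for this setup. It uses a hypothetical TinyBackbone/Wrapper pair in place of the real ResNet-50, so it runs without downloading weights. It mirrors the model.model.fc layout, the BCEWithLogitsLoss criterion, and an optimizer that only sees the new head. The toy has no calibration step, so the op_threshs line has no counterpart here. As an optional extra not in the reply, it also sets requires_grad=False on the backbone so that unused gradients are not computed. BCEWithLogitsLoss expects float targets with the same shape as the logits, here (32, 1).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class TinyBackbone(nn.Module):
    """Hypothetical stand-in for the inner ResNet: conv body + fc head."""

    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, padding=1),  # single-channel input
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),  # global pooling -> (N, 8, 1, 1)
            nn.Flatten(),             # -> (N, 8)
        )
        self.fc = nn.Linear(8, 18)  # "pretrained" 18-output head

    def forward(self, x):
        return self.fc(self.body(x))


class Wrapper(nn.Module):
    """Hypothetical wrapper exposing the network as .model (model.model.fc layout)."""

    def __init__(self):
        super().__init__()
        self.model = TinyBackbone()

    def forward(self, x):
        return self.model(x)


model = Wrapper()
model.model.fc = nn.Linear(8, 1)  # new binary head (2048 -> 1 in the real ResNet-50)

# Optional: freeze the backbone so no gradients are computed for it.
for p in model.model.body.parameters():
    p.requires_grad = False

optimizer = torch.optim.Adam(model.model.fc.parameters(), lr=1e-3)  # head only
criterion = nn.BCEWithLogitsLoss()

inputs = torch.randn(32, 1, 64, 64)             # toy batch of single-channel images
labels = torch.randint(0, 2, (32, 1)).float()   # float targets shaped like the logits

body_before = [p.detach().clone() for p in model.model.body.parameters()]
head_before = [p.detach().clone() for p in model.model.fc.parameters()]

model.train()
optimizer.zero_grad()
outputs = model(inputs)            # raw logits, shape (32, 1)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()

body_changed = any(not torch.equal(a, b)
                   for a, b in zip(body_before, model.model.body.parameters()))
head_changed = any(not torch.equal(a, b)
                   for a, b in zip(head_before, model.model.fc.parameters()))
print(tuple(outputs.shape), body_changed, head_changed)  # (32, 1) False True
```

One caveat for the real ResNet-50: it contains BatchNorm layers, and model.train() keeps updating their running mean and variance even when the backbone weights are frozen or left out of the optimizer. If the backbone should stay exactly as pretrained, putting the model in eval mode during training is one option; the new head is a plain Linear layer, so eval mode does not change its behavior or block its gradients.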