transfer learning for resnet50-res512-all

Open dev-yue opened this issue 3 years ago • 3 comments

Great library! Would you provide the transfer learning code for resnet50-res512-all as well? Thank you so much!

dev-yue avatar Apr 04 '22 03:04 dev-yue

It should be as simple as using this line:

model = xrv.models.ResNet(weights="resnet50-res512-all")

in this script: https://github.com/mlmed/torchxrayvision/blob/master/scripts/transfer_learning.ipynb

Also change the resizing to xrv.datasets.XRayResizer(512) so the images are 512x512

ieee8023 avatar Apr 04 '22 21:04 ieee8023

Hi Joseph,

Thanks for the response! I used densenet121 (224 x 224) for transfer learning and it worked great. I then tried ResNet50 and modified the fc layer for binary classification as below:

model = xrv.models.ResNet(weights="resnet50-res512-all")
model.fc = nn.Sequential(
    nn.Linear(2048, 128),
    nn.ReLU(inplace=True),
    nn.Linear(128, 1),
)

But the script got stuck at the training step: outputs = model(inputs)

The output has shape (32, 18). I know 32 is the batch size, but I don't know where 18 comes from. Shouldn't it be 1 instead?

It seems to me the settings for ResNet are quite different from DenseNet. I am quite new to this and hope to get the ResNet to work. Thank you for helping out!

Yue

dev-yue avatar Apr 04 '22 21:04 dev-yue

Oh sorry, I responded too fast and didn't test the code. The ResNet wrapper loads an internal resnet model, so the fc layer is located at model.model.fc.

import torch
import torchxrayvision as xrv

model = xrv.models.ResNet(weights="resnet50-res512-all")
model.op_threshs = None  # prevent pre-trained model calibration
model.model.fc = torch.nn.Linear(2048, 1)  # reinitialize classifier

optimizer = torch.optim.Adam(model.model.fc.parameters())  # only train classifier
criterion = torch.nn.BCEWithLogitsLoss()

I tested the above code and it seems to train correctly.

ieee8023 avatar Apr 04 '22 21:04 ieee8023
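
For readers wondering why assigning model.fc left the output at 18 columns, here is a toy sketch in plain Python (no torch, no torchxrayvision). The classes Linear, InnerResNet, and Wrapper are hypothetical stand-ins invented for illustration, not the library's classes. They only mimic the structure described in the thread: a wrapper whose forward pass delegates to an inner network that owns the real fc layer. The 18 is taken from the output shape reported above.

```python
# Toy stand-ins (hypothetical, not the torchxrayvision classes): each "layer"
# is a callable that returns a list with one value per output unit.

class Linear:
    """Stand-in for a linear layer: maps any input to `out_features` zeros."""
    def __init__(self, in_features, out_features):
        self.in_features = in_features
        self.out_features = out_features

    def __call__(self, x):
        return [0.0] * self.out_features


class InnerResNet:
    """Stand-in for the wrapped network; its own `fc` produces the outputs."""
    def __init__(self):
        self.fc = Linear(2048, 18)  # 18 outputs, as observed in the thread

    def __call__(self, x):
        features = x  # a real backbone would compute 2048 features here
        return self.fc(features)


class Wrapper:
    """Stand-in for a wrapper whose forward pass delegates to `self.model`."""
    def __init__(self):
        self.model = InnerResNet()

    def __call__(self, x):
        return self.model(x)


wrapper = Wrapper()

# Assigning to `wrapper.fc` only adds an unused attribute: the forward
# pass never looks at it, so the output width is unchanged.
wrapper.fc = Linear(2048, 1)
width_after_outer_assignment = len(wrapper([0.0] * 2048))

# Replacing the inner head changes what the forward pass actually calls.
wrapper.model.fc = Linear(2048, 1)
width_after_inner_assignment = len(wrapper([0.0] * 2048))

print(width_after_outer_assignment, width_after_inner_assignment)
```

The same reasoning applies to a torch module: setting a new attribute on the outer object registers a layer that forward never uses, which is consistent with the fix of replacing model.model.fc instead.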