lightning-flash
Finetune `SemanticSegmentation` with `ImageEmbedder` backbone
Discussed in https://github.com/PyTorchLightning/lightning-flash/discussions/1288
Originally posted by sirtris April 11, 2022
Hi,
I have pre-trained my own network backbone using the `ImageEmbedder`, and I would now like to use this model in a downstream segmentation task. How could I do this?
For a classification task I would try:
model = ImageClassifier.load_from_checkpoint("image_classification_model.pt")
The problem is that `SemanticSegmentation` has no `load_from_checkpoint` function.
Thanks
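Whatever the final model looks like, the first step is usually to pull the backbone weights out of the `ImageEmbedder` checkpoint. The helper below is a minimal sketch, not a Flash API: `extract_backbone_state_dict` is a hypothetical name, and the `"adapter.backbone"` prefix in the demo is a placeholder, since the actual key prefix depends on how the embedder stores its backbone.

```python
from collections import OrderedDict
from typing import Any, Mapping


def extract_backbone_state_dict(state_dict: Mapping[str, Any], prefix: str) -> "OrderedDict[str, Any]":
    """Return the entries whose key starts with `prefix`, with that prefix stripped.

    Hypothetical helper: `prefix` must be whatever the ImageEmbedder checkpoint
    actually uses for its backbone parameters (inspect the keys to find it).
    """
    # Match whole module names, so "adapter.backbone" does not also match "adapter.backbone2".
    if prefix and not prefix.endswith("."):
        prefix += "."
    backbone = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            backbone[key[len(prefix):]] = value
    return backbone


# Demo with placeholder values standing in for tensors.
fake_state_dict = {
    "adapter.backbone.conv1.weight": 1,
    "adapter.backbone.bn1.bias": 2,
    "adapter.head.0.weight": 3,
}
print(dict(extract_backbone_state_dict(fake_state_dict, "adapter.backbone")))
# {'conv1.weight': 1, 'bn1.bias': 2}
```

With a real checkpoint, the input would be something like `torch.load(path, map_location="cpu")["state_dict"]`, since Lightning checkpoints store the model weights under the `"state_dict"` key. Printing a few of those keys first is the safest way to find the right prefix.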
In the meantime, are there any suggestions for a workaround? I usually use the Segmentation Models PyTorch (SMP) library. Is there a way to use the backbone in one of their models?
Hi @sirtris, the segmentation task uses SMP to create both the backbone (encoder) and the head (decoder) of the segmentation model, and in order to create a head you need to pass the backbone to the head constructor.
To use a custom backbone, you might instead have to modify the `create_model(...)` function of the SMP library.
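If the pre-trained backbone happens to be an architecture SMP already supports (for example a ResNet), one alternative to patching SMP is to build the SMP model as usual and then copy the extracted weights into its encoder. That only works if the parameter names line up, so the sketch below checks that first. `match_encoder_keys` is a hypothetical helper, and ignoring `fc.` keys assumes a ResNet-style classifier layer that the encoder does not keep.

```python
from typing import Any, Dict, Iterable, List, Mapping, Tuple


def match_encoder_keys(
    weights: Mapping[str, Any],
    expected_keys: Iterable[str],
    ignore_prefixes: Tuple[str, ...] = ("fc.",),
) -> Tuple[Dict[str, Any], List[str], List[str]]:
    """Filter `weights` down to the keys an encoder expects.

    Returns (filtered, missing, unexpected):
      filtered   - entries present in both `weights` and `expected_keys`
      missing    - expected keys with no pre-trained weight (they keep their init values)
      unexpected - weight keys the encoder lacks, excluding `ignore_prefixes`
    """
    expected = list(expected_keys)
    expected_set = set(expected)
    filtered = {k: v for k, v in weights.items() if k in expected_set}
    missing = [k for k in expected if k not in weights]
    unexpected = [
        k for k in weights
        if k not in expected_set and not k.startswith(ignore_prefixes)
    ]
    return filtered, missing, unexpected


# Demo with placeholder values; in practice `encoder_keys` would come from
# the encoder's own state_dict().keys().
backbone = {"conv1.weight": 1, "layer1.0.conv1.weight": 2, "fc.weight": 3, "fc.bias": 4}
encoder_keys = ["conv1.weight", "layer1.0.conv1.weight", "bn1.weight"]
filtered, missing, unexpected = match_encoder_keys(backbone, encoder_keys)
print(filtered)    # {'conv1.weight': 1, 'layer1.0.conv1.weight': 2}
print(missing)     # ['bn1.weight']
print(unexpected)  # []
```

With SMP, the expected keys could be taken from something like `smp.Unet(encoder_name="resnet50", encoder_weights=None, classes=...)`'s `encoder.state_dict()`, and `filtered` then passed to the encoder's `load_state_dict`. A long `missing` or `unexpected` list suggests the architectures differ, in which case modifying `create_model(...)` as described above is the more realistic route.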
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.