pretrained_weights example will not work with the timm model

Open robmarkcole opened this issue 2 years ago • 7 comments

Issue

These docs describe creation of a timm model if you do not want to use ClassificationTask, but if you replace the final cell:

trainer.fit(model=task -> trainer.fit(model=model

You will get the error:

TypeError: `model` must be a `LightningModule` or `torch._dynamo.OptimizedModule`, got `_IncompatibleKeys`

Fix

Not sure yet!

robmarkcole avatar Oct 10 '23 15:10 robmarkcole

I think the wording may be misleading, but it's not suggesting you use a timm model with PyTorch Lightning. It's simply suggesting that if you have your own custom training framework (not PL) you can load the pretrained weights into a base nn.Module.

isaaccorley avatar Oct 10 '23 17:10 isaaccorley

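It would be nice if we could figure out #996. It doesn't help that Trainer takes a model as input and we conveniently create a model in the previous cell. But the model argument that Trainer expects (a LightningModule) and the model variable created in that cell (a plain nn.Module) are unrelated constructs.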

adamjstewart avatar Oct 10 '23 19:10 adamjstewart

Could this be a possible solution? https://stackoverflow.com/a/59042066

model = timm.create_model("vit_small_patch16_224", in_chans=in_chans)  # model is created
# Bug: load_state_dict returns an _IncompatibleKeys object, which overwrites the model variable
model = model.load_state_dict(weights.get_state_dict(progress=True), strict=False)

Instead, load_state_dict should be called without assigning its return value to the model variable.

model = timm.create_model("vit_small_patch16_224", in_chans=in_chans)  # model is created
# load_state_dict modifies the model in place; its return value is not assigned to the model variable
model.load_state_dict(weights.get_state_dict(progress=True), strict=False)
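To make the difference concrete, here is a minimal sketch that runs without timm or torchgeo. It assumes a plain nn.Linear as a hypothetical stand-in for the timm model and a hand-built partial state dict in place of weights.get_state_dict(); neither comes from the docs.

```python
import torch
from torch import nn

# Hypothetical stand-in for the timm model (assumption); any nn.Module behaves the same.
model = nn.Linear(4, 2)

# Partial state dict with the bias left out, so strict=False has something to report.
state_dict = {"weight": torch.zeros(2, 4)}

# load_state_dict copies the tensors into the module in place and returns a
# report of missing/unexpected keys, not the module itself.
result = model.load_state_dict(state_dict, strict=False)

print(type(result).__name__)         # _IncompatibleKeys
print(result.missing_keys)           # ['bias']
print(isinstance(model, nn.Module))  # True, because model was not reassigned
```

Keeping the return value in its own variable, as here, lets you inspect which keys did not match while model stays a usable nn.Module.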

clkruse avatar Jan 03 '24 16:01 clkruse

Seems like it might be fixed in https://github.com/microsoft/torchgeo/pull/1503?

clkruse avatar Jan 03 '24 16:01 clkruse

I'm getting a similar error with the following snippet that uses the SemanticSegmentationTask.

from pytorch_lightning import Trainer
from torchgeo.datamodules import InriaAerialImageLabelingDataModule
from torchgeo.trainers import SemanticSegmentationTask

datamodule = InriaAerialImageLabelingDataModule(root_dir="./inria", batch_size=64, num_workers=6)
task = SemanticSegmentationTask(model="unet", backbone="resnet18", lr=0.1)
trainer = Trainer(default_root_dir="./unet_trainer", max_epochs=1)
trainer.fit(model=task, datamodule=datamodule)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[27], line 8
      6 task = SemanticSegmentationTask(model="unet", backbone="resnet18", lr=0.1)
      7 trainer = Trainer(default_root_dir="./unet_trainer", max_epochs=1)
----> 8 trainer.fit(model=task, datamodule=datamodule)

File ~/.cache/pypoetry/virtualenvs/wbc-model-SkEs5vNN-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:538, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    504 def fit(
    505     self,
    506     model: "pl.LightningModule",
   (...)
    510     ckpt_path: Optional[str] = None,
    511 ) -> None:
    512     r"""Runs the full optimization routine.
    513 
    514     Args:
   (...)
    536 
    537     """
--> 538     model = _maybe_unwrap_optimized(model)
    539     self.strategy._lightning_module = model
    540     _verify_strategy_supports_compile(model, self.strategy)

File ~/.cache/pypoetry/virtualenvs/wbc-model-SkEs5vNN-py3.10/lib/python3.10/site-packages/pytorch_lightning/utilities/compile.py:132, in _maybe_unwrap_optimized(model)
    130     return model
    131 _check_mixed_imports(model)
--> 132 raise TypeError(
    133     f"`model` must be a `LightningModule` or `torch._dynamo.OptimizedModule`, got `{type(model).__qualname__}`"
    134 )

TypeError: `model` must be a `LightningModule` or `torch._dynamo.OptimizedModule`, got `SemanticSegmentationTask`

the task is a LightningModule but task.model is not.

I can't find an example of semantic segmentation so I tried to follow the structure of the classification task example in the documentation, but use a backbone instead of specifying weights since there are no pretrained semantic segmentation weights.

rbavery avatar Jan 19 '24 23:01 rbavery

Could you retry with from lightning.pytorch import Trainer and import lightning.pytorch as pl? I've seen the switch from pytorch-lightning to lightning cause problems with saved models.
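One way to check for this kind of mismatch is to look at which package each base class of the task comes from. The sketch below is only an illustration: defining_packages is a hypothetical helper, and the three classes are stand-ins built with type() whose module paths are assumed, so that the snippet runs without either Lightning package installed.

```python
def defining_packages(obj):
    """Return the top-level package of every class in obj's MRO (hypothetical helper)."""
    return [cls.__module__.split(".")[0] for cls in type(obj).__mro__]


# Stand-ins (assumptions) that mimic same-named base classes from the two packages.
NewBase = type("LightningModule", (), {"__module__": "lightning.pytorch.core.module"})
OldBase = type("LightningModule", (), {"__module__": "pytorch_lightning.core.module"})

# Stand-in task that subclasses the lightning.pytorch flavour of the base class.
Task = type(
    "SemanticSegmentationTask",
    (NewBase,),
    {"__module__": "torchgeo.trainers.segmentation"},
)

task = Task()
print(defining_packages(task))     # ['torchgeo', 'lightning', 'builtins']
print(isinstance(task, OldBase))   # False: same class name, different package
print(isinstance(task, NewBase))   # True
```

With a real task object, seeing lightning in that list while Trainer was imported from pytorch_lightning would point to the mixed-import situation described here.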

calebrob6 avatar Jan 19 '24 23:01 calebrob6

that fixed it thanks @calebrob6

rbavery avatar Jan 19 '24 23:01 rbavery

I think this was just a user mistake, so we can close this issue.

adamjstewart avatar Feb 29 '24 11:02 adamjstewart