pretrained_weights example will not work with the timm model
Issue
These docs describe how to create a timm model if you do not want to use ClassificationTask, but if you change the final cell as follows:
trainer.fit(model=task -> trainer.fit(model=model
You will get the error:
TypeError: `model` must be a `LightningModule` or `torch._dynamo.OptimizedModule`, got `_IncompatibleKeys`
Fix
Not sure yet!
I think the wording may be misleading, but the docs aren't suggesting that you use a timm model with PyTorch Lightning. They're simply saying that if you have your own custom training framework (not PL), you can load the pretrained weights into a plain nn.Module.
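For the plain nn.Module route, a minimal sketch of a custom training loop with no Lightning involved might look like the following. The small nn.Sequential model and the random tensors are hypothetical stand-ins for the timm model and a real dataloader; pretrained weights would be loaded into the model with load_state_dict before the loop runs.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in for the timm model: any plain nn.Module is trained the same way.
model = nn.Sequential(nn.Flatten(), nn.Linear(4 * 8 * 8, 10))

# Hypothetical random batch in place of a real dataloader (16 images, 4 bands, 8x8 pixels).
images = torch.randn(16, 4, 8, 8)
labels = torch.randint(0, 10, (16,))

optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
loss_fn = nn.CrossEntropyLoss()


def train(model: nn.Module, optimizer: torch.optim.Optimizer, epochs: int = 20) -> list[float]:
    """Plain PyTorch loop: forward pass, loss, backward pass, optimizer step."""
    model.train()
    losses = []
    for _ in range(epochs):
        optimizer.zero_grad()
        loss = loss_fn(model(images), labels)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses


losses = train(model, optimizer)
print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```

Because the loop only relies on nn.Module and an optimizer, it never calls Trainer.fit, so the LightningModule type check does not apply.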
It would be nice if we could figure out #996. It doesn't help that Trainer takes a model argument and we conveniently create a variable named model in the previous cell, but the model argument and the model variable are unrelated.
Could this be a possible solution? https://stackoverflow.com/a/59042066
import timm
model = timm.create_model("vit_small_patch16_224", in_chans=in_chans)  # model is created
# Bug: load_state_dict returns an _IncompatibleKeys named tuple, which overwrites the model variable
model = model.load_state_dict(weights.get_state_dict(progress=True), strict=False)
Instead, call load_state_dict without assigning its return value to the model variable:
model = timm.create_model("vit_small_patch16_224", in_chans=in_chans)  # model is created
# load_state_dict updates the model's parameters in place; its return value only lists missing/unexpected keys
model.load_state_dict(weights.get_state_dict(progress=True), strict=False)
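The difference between the two snippets can be reproduced with plain PyTorch. In this minimal sketch, two small nn.Linear layers are hypothetical stand-ins for the timm model and the pretrained weights, and the bias is deliberately left out of the checkpoint so that strict=False has something to report.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-ins: `pretrained` supplies the weights, `model` plays the timm model.
pretrained = nn.Linear(4, 2)
model = nn.Linear(4, 2)

# Hypothetical partial checkpoint: only the weight, no bias.
state_dict = {"weight": pretrained.weight.detach().clone()}

# Correct usage: load in place and keep `model` pointing at the module.
result = model.load_state_dict(state_dict, strict=False)

print(type(result).__name__)  # _IncompatibleKeys
print(result.missing_keys)  # ['bias']
print(result.unexpected_keys)  # []
print(torch.equal(model.weight, pretrained.weight))  # True

# The buggy pattern rebinds `model` to `result`, which is not an nn.Module,
# so Trainer.fit receives an _IncompatibleKeys object instead of a model.
print(isinstance(result, nn.Module))  # False
```

Note that even a correctly loaded timm model is a plain nn.Module rather than a LightningModule, so it still would not be accepted by Trainer.fit directly.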
Seems like it might be fixed in https://github.com/microsoft/torchgeo/pull/1503?
I'm getting a similar error with the following snippet that uses the SemanticSegmentationTask.
from pytorch_lightning import Trainer
from torchgeo.datamodules import InriaAerialImageLabelingDataModule
from torchgeo.trainers import SemanticSegmentationTask
datamodule = InriaAerialImageLabelingDataModule(root_dir="./inria", batch_size=64, num_workers=6)
task = SemanticSegmentationTask(model="unet", backbone="resnet18", lr=0.1)
trainer = Trainer(default_root_dir="./unet_trainer", max_epochs=1)
trainer.fit(model=task, datamodule=datamodule)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[27], line 8
6 task = SemanticSegmentationTask(model="unet", backbone="resnet18", lr=0.1)
7 trainer = Trainer(default_root_dir="./unet_trainer", max_epochs=1)
----> 8 trainer.fit(model=task, datamodule=datamodule)
File ~/.cache/pypoetry/virtualenvs/wbc-model-SkEs5vNN-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:538, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
504 def fit(
505 self,
506 model: "pl.LightningModule",
(...)
510 ckpt_path: Optional[str] = None,
511 ) -> None:
512 r"""Runs the full optimization routine.
513
514 Args:
(...)
536
537 """
--> 538 model = _maybe_unwrap_optimized(model)
539 self.strategy._lightning_module = model
540 _verify_strategy_supports_compile(model, self.strategy)
File ~/.cache/pypoetry/virtualenvs/wbc-model-SkEs5vNN-py3.10/lib/python3.10/site-packages/pytorch_lightning/utilities/compile.py:132, in _maybe_unwrap_optimized(model)
130 return model
131 _check_mixed_imports(model)
--> 132 raise TypeError(
133 f"`model` must be a `LightningModule` or `torch._dynamo.OptimizedModule`, got `{type(model).__qualname__}`"
134 )
TypeError: `model` must be a `LightningModule` or `torch._dynamo.OptimizedModule`, got `SemanticSegmentationTask`
The task is a LightningModule, but task.model is not.
I couldn't find a semantic segmentation example, so I tried to follow the structure of the classification task example in the documentation, specifying a backbone instead of weights, since there are no pretrained semantic segmentation weights.
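Could you retry with from lightning.pytorch import Trainer and import lightning.pytorch as pl? I've seen the switch from pytorch-lightning to lightning cause problems with saved models. If torchgeo's tasks are built on lightning.pytorch.LightningModule, a Trainer imported from pytorch_lightning won't recognize them as LightningModules, which would explain the error above.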
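The mixed-import failure can be illustrated in plain Python, without either Lightning package installed: isinstance checks class identity, not class names, so two classes that are both called LightningModule but come from different packages are unrelated. The module names below are only labels on hypothetical stand-in classes, and the sketch assumes, as the suggested fix implies, that the task subclasses the lightning.pytorch version.

```python
def make_lightning_module_class(package: str) -> type:
    """Hypothetical stand-in: a class named LightningModule that claims to live in `package`."""
    return type("LightningModule", (), {"__module__": package})


OldLightningModule = make_lightning_module_class("pytorch_lightning")
NewLightningModule = make_lightning_module_class("lightning.pytorch")


class SemanticSegmentationTask(NewLightningModule):
    """Hypothetical task built on the lightning.pytorch base class."""


task = SemanticSegmentationTask()

print(OldLightningModule.__module__, NewLightningModule.__module__)  # pytorch_lightning lightning.pytorch
print(OldLightningModule.__name__ == NewLightningModule.__name__)  # True
print(isinstance(task, NewLightningModule))  # True
print(isinstance(task, OldLightningModule))  # False
```

A check written against one package's class therefore rejects instances of the other, even though the class names match.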
That fixed it, thanks @calebrob6!
I think this was just a user mistake, we can close this issue.