anomalib
Model Loader Callback
Is your feature request related to a problem? Please describe.
- Currently, the PyTorch Lightning (PL) trainer uses a callback, `LoadModelCallback`, from `anomalib/utils/callbacks/model_loader.py`, which calls PyTorch's `torch.load` function to load the best weights. This may cause device-related issues (e.g. trained on GPU, testing/predicting on CPU).
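For context, the mismatch arises because `torch.load` restores tensors to the device they were saved on unless `map_location` is given. A minimal illustration of the remapping (not anomalib's actual code; the checkpoint dict here is just a stand-in for a Lightning checkpoint):

```python
import io

import torch

# A checkpoint-like dict (stand-in for a Lightning checkpoint).
state = {"state_dict": {"weight": torch.ones(2, 2)}}

buffer = io.BytesIO()
torch.save(state, buffer)
buffer.seek(0)

# map_location="cpu" remaps any GPU-saved storages onto the CPU,
# avoiding the "trained on GPU, tested on CPU" failure mode that a
# bare torch.load(path) can hit.
loaded = torch.load(buffer, map_location="cpu")
print(loaded["state_dict"]["weight"].device.type)  # cpu
```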
Describe the solution you'd like
- I suggest removing the `torch.load` call from `LoadModelCallback` and letting the PL trainer handle checkpoint loading via its `ckpt_path` argument:
```python
trainer.test(model=model, datamodule=datamodule, ckpt_path='best')  # or
trainer.test(model=model, datamodule=datamodule, ckpt_path='$path_to_the_checkpoint_user_wish_to_test')
trainer.predict(model=model, datamodule=datamodule, ckpt_path='best')  # or
trainer.predict(model=model, datamodule=datamodule, ckpt_path='$path_to_the_checkpoint_user_wish_to_predict')
```
Reference: [PL docs]
Thanks for your suggestion! This could indeed be a better way of loading the model weights, though we will have to investigate if this would lead to any unwanted/unexpected behavior. We'll have a look and post any findings here.
Hi @shakib-root, the reason we used this approach is that there was a bug in model loading in earlier PL versions: we were unable to achieve the same performance as during training. We could try again now.
I am closing this as this change has been merged to the feature branch and will be merged to main soon.