
Rethink usage pattern for pretrained models

nateraw opened this issue on Sep 10, 2020 · 5 comments

🚀 Feature

Switch to using SomeModel.from_pretrained('pretrained-model-name') for pretrained models.

Motivation

Seems we are following torchvision's pattern of using a 'pretrained' argument in the init of our models to load pretrained weights. In my opinion, this is extremely confusing, as it makes the other init args + kwargs ambiguous or useless.

Pitch

Add a .from_pretrained classmethod to models that returns an initialized instance of the class. Pretrained checkpoints should carry any hparams needed to fill out the init, I guess.

from pl_bolts.models import VAE

model = VAE.from_pretrained('imagenet2012')
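
A rough sketch of what this could look like. This is purely illustrative: the registry contents and URL are made up, and it assumes the hosted checkpoint was saved by Lightning so that load_from_checkpoint can restore the hparams.

import os
import torch
import pytorch_lightning as pl

# Hypothetical nickname -> hosted-checkpoint-URL registry (URL is a placeholder)
PRETRAINED_URLS = {
    'imagenet2012': 'https://example.com/vae-imagenet2012.ckpt',
}

class VAE(pl.LightningModule):
    @classmethod
    def from_pretrained(cls, identifier):
        url = PRETRAINED_URLS[identifier]
        os.makedirs(torch.hub.get_dir(), exist_ok=True)
        ckpt = os.path.join(torch.hub.get_dir(), os.path.basename(url))
        if not os.path.exists(ckpt):
            torch.hub.download_url_to_file(url, ckpt)
        # load_from_checkpoint restores both the weights and the saved
        # hyperparameters, so the nickname is all the caller needs
        return cls.load_from_checkpoint(ckpt)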

Alternatives

Additional context

nateraw avatar Sep 10 '20 22:09 nateraw

yeah, agree... although this is basically just the same as load_from_checkpoint, no? sounds like we're looking for checkpoint nicknames instead?

doesn't it read better as:

VAE.pretrained_on('xyz')

williamFalcon avatar Sep 11 '20 01:09 williamFalcon

Right, I think the distinction here is that load_from_checkpoint is for checkpoints you have saved locally, but this function would be for pretrained models that we are hosting (i.e. these guys).

So, yes! We are looking for something that can point to a nickname/identifier for a pretrained model.


I think 'pretrained_on' is a limiting name: a model could be pretrained on the same dataset twice with different settings, and it would then be ambiguous which weights that function name should load. That's why I suggest something a little more open, such as from_pretrained(identifier).
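
For example, with made-up identifiers, two pretrainings on the same dataset stay distinguishable:

VAE.from_pretrained('imagenet2012-run1')
VAE.from_pretrained('imagenet2012-run2')  # same dataset, different training settings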

This is just my opinion... I could be convinced otherwise haha 😄. Let's have others weigh in to reach a consensus.

CC: @PyTorchLightning/core-contributors

nateraw avatar Sep 11 '20 01:09 nateraw

oh i see. it's an id not a dataset. yeah that works.

for instance, we can have many backbones with different datasets as well:

CPC.from_pretrained('resnet18-imagenet')
CPC.from_pretrained('resnet50-imagenet')
CPC.from_pretrained('resnet18-stl10')

williamFalcon avatar Sep 11 '20 01:09 williamFalcon

Yes, they are trained on a defined dataset; in this case the dataset name just serves as a look-up table to a specific path on the PL side...
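
For illustration, such a look-up table could be as simple as a dict; the paths below are placeholders for wherever the checkpoints end up hosted:

# Hypothetical identifier -> hosted checkpoint path look-up table
CPC_CHECKPOINT_URLS = {
    'resnet18-imagenet': 'https://<pl-hosted-weights>/cpc-resnet18-imagenet.ckpt',
    'resnet50-imagenet': 'https://<pl-hosted-weights>/cpc-resnet50-imagenet.ckpt',
    'resnet18-stl10': 'https://<pl-hosted-weights>/cpc-resnet18-stl10.ckpt',
}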

Borda avatar Sep 11 '20 13:09 Borda

@williamFalcon @Borda @nateraw I included this pattern in the latest AE and VAE commits to bolts. A few points I realized:

  1. We can move from_pretrained() into Lightning itself as a method to override.
  2. from_pretrained() needs to be an instance method, not a static method. In most cases, you will initialize the lightning module with specific params according to the weights being loaded.
vae = VAE(input_height=32, first_conv=True)
vae = vae.from_pretrained('cifar10-resnet18')

In this example, the stl10 weights have a different encoder configuration than this VAE, but the internal method loads with strict=False so that users can still load stl10 weights into the encoder configured for the cifar10 dataset (a sketch of this pattern follows below).

  3. Having this pattern allows us to test the correct loading of weights through the from_pretrained() function. @williamFalcon, cases like the corrupt ImageNet weights for CPC would be caught automatically.

I have added all of this, plus tests, for the AE and VAE classes I updated in bolts.
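
A minimal sketch of that instance-method pattern, assuming a made-up identifier table and URL. Note that strict=False only tolerates missing and unexpected keys; layers that do match by name still need compatible shapes.

import torch
import pytorch_lightning as pl

class VAE(pl.LightningModule):
    def __init__(self, input_height, first_conv=False):
        super().__init__()
        self.save_hyperparameters()
        # ... build encoder/decoder for this configuration ...

    def from_pretrained(self, identifier):
        # Hypothetical identifier -> hosted checkpoint URL table
        urls = {'cifar10-resnet18': 'https://example.com/vae-cifar10-resnet18.ckpt'}
        ckpt = torch.hub.load_state_dict_from_url(urls[identifier], map_location='cpu')
        # strict=False skips missing/unexpected keys, so weights from a slightly
        # different configuration can still populate the layers that match
        self.load_state_dict(ckpt['state_dict'], strict=False)
        return self

vae = VAE(input_height=32, first_conv=True)
vae = vae.from_pretrained('cifar10-resnet18')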

ananyahjha93 avatar Sep 12 '20 19:09 ananyahjha93