
Support model in_chans not equal to pre-trained weights in_chans

Open keves1 opened this issue 1 year ago • 6 comments

This PR addresses #2289.

I've made a good start on this and am creating this draft PR to get some feedback before continuing. I have created a load_pretrained function that will copy weights as done by timm, but which supports any number of in_chans of the weights, not just 3. So far I just implemented this new functionality in resnet50 and will expand that to other models if this is a good approach.

With this change, you can create a model like the one below. In this case, the two channels of the pre-trained weights are copied into four input channels:

model = resnet50(weights=ResNet50_Weights.SENTINEL1_ALL_MOCO, in_chans=4)
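To make the "two channels copied into 4" behaviour concrete, here is a minimal sketch of a repeat-and-rescale plan in the spirit of timm's RGB strategy, generalized to any number of source channels. The helper name `repeat_plan` is hypothetical and is not part of torchgeo or timm; the actual load_pretrained in this PR may order or scale the channels differently.

```python
import math


def repeat_plan(src_chans: int, dst_chans: int) -> tuple[list[int], float]:
    """Return (source channel index per target channel, scale factor).

    Hypothetical helper for illustration only. It tiles the source
    channels until the target count is reached, without assuming that
    the pre-trained weights have 3 input channels.
    """
    if src_chans < 1 or dst_chans < 1:
        raise ValueError('channel counts must be positive')
    repeat = math.ceil(dst_chans / src_chans)
    # Tile the source channel indices, then truncate to the target count.
    indices = (list(range(src_chans)) * repeat)[:dst_chans]
    # Rescale so the summed response to a constant input stays comparable.
    scale = src_chans / dst_chans
    return indices, scale


indices, scale = repeat_plan(2, 4)
print(indices, scale)  # [0, 1, 0, 1] 0.5
```

Assuming a first-layer weight tensor `w` of shape (out_chans, src_chans, kH, kW), the adapted weight under this plan would be `w[:, indices] * scale`. Truncating when the target has fewer channels than the source is an assumption of this sketch, not something the thread confirms.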

I created a new file torchgeo/models/utils.py where I put the load_pretrained function since I didn't see any suitable existing file. Is there a better place for it?

keves1 avatar Sep 27 '24 16:09 keves1

@microsoft-github-policy-service agree

keves1 avatar Sep 27 '24 16:09 keves1

I have created a load_pretrained function that will copy weights as done by timm, but which supports any number of in_chans of the weights, not just 3.

Why can't we use timm.models.helpers.load_pretrained instead of writing our own custom code?

adamjstewart avatar Sep 28 '24 13:09 adamjstewart

Why can't we use timm.models.helpers.load_pretrained instead of writing our own custom code?

Because that function only copies the weights to additional input channels if the first convolution layer of the weights has 3 input channels. Otherwise a NotImplementedError is raised and the weights of that layer are randomly initialized instead. I linked to the timm implementation in my comment on issue #2289.

keves1 avatar Sep 28 '24 15:09 keves1

Will try to review next week, this week is quite busy. Apologies for the wait!

adamjstewart avatar Oct 01 '24 20:10 adamjstewart

I recently discovered timm.models.adapt_input_conv, which makes it easy to change the number of channels without losing pretrained weights. See #2602 for an example. There seems to be a lot of other builtin stuff in timm we might be able to make use of. If it doesn't support everything we need, I would be willing to try to add support for it in timm.

adamjstewart avatar Feb 22 '25 12:02 adamjstewart

I actually looked into timm.models.adapt_input_conv (this is what timm.models.helpers.load_pretrained calls), and it only works if the first convolution layer of the weights has 3 input channels (see my comment above). So we couldn't use adapt_input_conv with weights trained on multispectral or other non-3-channel imagery.

Here's an example:

from timm.models import adapt_input_conv
from torchgeo.models import ResNet18_Weights

weights = ResNet18_Weights.SENTINEL2_ALL_MOCO.get_state_dict()
adapt_input_conv(26, weights['conv1.weight'])
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In[17], line 1
----> 1 adapt_input_conv(26, weights['conv1.weight'])

File ~/miniconda3/envs/torchgeo-env/lib/python3.11/site-packages/timm/models/_manipulate.py:270, in adapt_input_conv(in_chans, conv_weight)
    268 elif in_chans != 3:
    269     if I != 3:
--> 270         raise NotImplementedError('Weight format not supported by conversion.')
    271     else:
    272         # NOTE this strategy should be better than random init, but there could be other combinations of
    273         # the original RGB input layer weights that'd work better for specific cases.
    274         repeat = int(math.ceil(in_chans / 3))

NotImplementedError: Weight format not supported by conversion.

There may be other methods in timm that we could use, as you mentioned, or we may need to add support ourselves. I'm focusing on getting a PR ready for adding the Autoregression trainer, so I won't be working on this more in the short term.

keves1 avatar Feb 24 '25 22:02 keves1