torchgeo

Support model in_chans not equal to pre-trained weights in_chans

Open • adamjstewart opened this issue 5 months ago • 1 comment

Summary

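If a user specifies both in_chans and weights, and weights.meta['in_chans'] differs from in_chans, the user-specified in_chans should take precedence. The model should be built with in_chans input channels, and the pre-trained weights of the first convolutional layer should be repeated along the channel dimension to fit, similar to how timm handles pre-trained weights.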
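A minimal sketch of what "repeating" could mean for the first convolution, assuming PyTorch weights laid out as (out_channels, in_channels, kH, kW). The helper name repeat_in_chans is hypothetical, not existing torchgeo or timm API. It tiles the pre-trained kernels along the input-channel dimension, truncates to in_chans, and rescales by old/new channel count, which mirrors the rescaling timm applies when it adapts 3-channel weights.

```python
import math

import torch


def repeat_in_chans(weight: torch.Tensor, in_chans: int) -> torch.Tensor:
    """Adapt a conv weight of shape (out, old_in, kH, kW) to (out, in_chans, kH, kW).

    Hypothetical helper: tiles the pre-trained kernels along the input-channel
    dimension, truncates to in_chans, and rescales by old_in / in_chans.
    """
    old_in = weight.shape[1]
    if in_chans == old_in:
        return weight
    repeats = math.ceil(in_chans / old_in)
    # Tile along dim 1 (input channels), then drop any surplus channels.
    new_weight = weight.repeat(1, repeats, 1, 1)[:, :in_chans]
    # Each output now sums over in_chans inputs instead of old_in. Rescaling makes
    # the output match the original conv when the extra channels are copies of the
    # original bands (exactly so when in_chans is a multiple of old_in).
    return new_weight * (old_in / in_chans)


# Example: a 2-channel (HH, HV) stem adapted to 4 stacked channels (HH, HV, HH, HV).
w = torch.randn(64, 2, 7, 7)
w4 = repeat_in_chans(w, 4)
print(w4.shape)  # torch.Size([64, 4, 7, 7])
```

Note that timm special-cases in_chans == 1 by summing over the input channels rather than slicing; the sketch above does not cover that case.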

Rationale

When working on change detection, it is common to take two images and stack them along the channel dimension. However, the stacked input has twice as many channels as our pre-trained weights expect, which currently makes it impossible to use them. Ideally, I would like to support something like:

from torchgeo.models import ResNet50_Weights, resnet50

model = resnet50(in_chans=4, weights=ResNet50_Weights.SENTINEL1_ALL_MOCO)

Here, the weights have 2 channels (HH and HV), while the dataset and model will have 4 channels (HH, HV, HH, HV).
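For reference, the channel stacking described above amounts to a single torch.cat along the channel dimension. This sketch assumes batched tensors of shape (B, C, H, W) and uses random data in place of real Sentinel-1 scenes; the variable names and the stand-alone stem layer are illustrative only.

```python
import torch

# Two co-registered Sentinel-1 scenes (before/after), each with 2 bands (HH, HV).
# Random tensors stand in for real imagery.
before = torch.randn(2, 2, 64, 64)
after = torch.randn(2, 2, 64, 64)

# Stack along the channel dimension (dim=1): bands become HH, HV, HH, HV.
stacked = torch.cat([before, after], dim=1)
print(stacked.shape)  # torch.Size([2, 4, 64, 64])

# A ResNet-style stem built for in_chans=4 accepts this layout directly,
# whereas a 2-channel stem would raise a shape error on this input.
stem = torch.nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)
print(stem(stacked).shape)  # torch.Size([2, 64, 32, 32])
```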

Implementation

https://timm.fast.ai/models#Case-2:-When-the-number-of-input-channels-is-not-1 describes how timm adapts pre-trained weights when the number of input channels differs. The loader that implements this can be imported as:

from timm.models.helpers import load_pretrained

We should use this in all of our model definitions instead of calling model.load_state_dict directly.
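One caveat worth checking before relying on timm directly: as far as I can tell, timm's input-conv conversion only supports pre-trained stems with 3 input channels, which would not cover the 2-channel Sentinel-1 weights. A hedged sketch of a wrapper around model.load_state_dict that generalizes the repeat-and-rescale step to any source channel count follows. The helper name load_state_dict_any_chans, the first_conv parameter, and the TinyNet toy model are hypothetical, not existing torchgeo or timm API.

```python
import math

import torch
from torch import nn


def load_state_dict_any_chans(
    model: nn.Module, state_dict: dict[str, torch.Tensor], first_conv: str = "conv1"
) -> None:
    """Load state_dict into model, adapting the first conv to the model's in_chans.

    Hypothetical helper: tiles the pre-trained kernels along the input-channel
    dimension, truncates, and rescales by old / new channel count.
    """
    key = f"{first_conv}.weight"
    weight = state_dict[key]
    target_in = model.state_dict()[key].shape[1]
    old_in = weight.shape[1]
    if target_in != old_in:
        repeats = math.ceil(target_in / old_in)
        weight = weight.repeat(1, repeats, 1, 1)[:, :target_in] * (old_in / target_in)
        # Copy the mapping so the caller's state_dict is left untouched.
        state_dict = {**state_dict, key: weight}
    model.load_state_dict(state_dict)


class TinyNet(nn.Module):
    """Toy stand-in for a ResNet stem plus classification head."""

    def __init__(self, in_chans: int) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(in_chans, 8, kernel_size=3, padding=1, bias=False)
        self.fc = nn.Linear(8, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(self.conv1(x).mean(dim=(2, 3)))


pretrained = TinyNet(in_chans=2).state_dict()  # stands in for 2-channel SSL weights
model = TinyNet(in_chans=4)
load_state_dict_any_chans(model, pretrained)
print(model.conv1.weight.shape)  # torch.Size([8, 4, 3, 3])
```

The rescaling keeps the response identical when the extra channels are copies of the original bands, as in the HH, HV, HH, HV example; for two different acquisitions it only provides a starting point for fine-tuning.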

Alternatives

There is some ongoing work to add a ChangeDetectionTask that may store each image under a separate sample key. However, there will always be models that require images stacked along the channel dimension, so I don't think we can avoid supporting this use case.

Additional information

No response

adamjstewart • Sep 09 '24 10:09