
[FEATURE] timm.models.adapt_input_conv: beyond RGB weights

Open adamjstewart opened this issue 9 months ago • 6 comments

Is your feature request related to a problem? Please describe.

TorchGeo provides a number of model weights pre-trained on non-RGB imagery (e.g., Sentinel-2, 13 channels). Oftentimes, when dealing with time-series data, we would like to stack images along the channel dimension so that we end up with $B \times TC \times H \times W$ inputs. However, we don't yet have an easy way to adapt our pre-trained weights to match.
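For reference, the stacking described above looks roughly like this (tensor sizes are illustrative):

```python
import torch

# Hypothetical Sentinel-2 time series: batch of 2, T=4 timesteps,
# C=13 bands each, stacked along the channel dimension.
x = torch.randn(2, 4, 13, 64, 64)  # B x T x C x H x W
B, T, C, H, W = x.shape
x = x.reshape(B, T * C, H, W)      # B x TC x H x W, i.e. 52 input channels
```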

Describe the solution you'd like

timm.models.adapt_input_conv provides a powerful tool for repeating and scaling weights to adapt to changing in_chans, but only seems to support 3-channel weights if in_chans > 1. I would like to extend this to support any number of channels. Would this be as simple as replacing 3 with I throughout the function?
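To illustrate, here is a toy sketch (not timm's exact code) of the repeat-and-scale strategy adapt_input_conv currently applies to 3-channel weights, here adapting an RGB stem weight to a 5-channel input:

```python
import math
import torch

# Toy conv weight: O=8 output channels, I=3 input channels, 7x7 kernel.
w = torch.randn(8, 3, 7, 7)
in_chans = 5
repeat = int(math.ceil(in_chans / 3))         # 2 copies along the input dim
w5 = w.repeat(1, repeat, 1, 1)[:, :in_chans]  # repeat, then drop the surplus
w5 = w5 * (3 / float(in_chans))               # rescale to preserve activation magnitude
```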

Describe alternatives you've considered

We could write our own functionality in TorchGeo, but figured this would be useful to the broader timm community.

Additional context

@isaaccorley @keves1 may also be interested in this.

adamjstewart avatar Feb 25 '25 09:02 adamjstewart

@adamjstewart So if you had a model with 13 channels, you might want to convert it so those 3 channels are repeated several times and pass it say a 13*4=52 channel image?

rwightman avatar Feb 25 '25 15:02 rwightman

Correct, other than a typo (3 -> 13).

Another use case is when a model is trained on top of atmosphere data (13 channels) but we want to use it for surface reflectance data (11 channels) or vice versa. In this case, we know which specific channels were removed, but would be satisfied with a solution that just naively drops the last 2 channels and rescales things.
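The naive drop-and-rescale version of that, as a sketch (channel ordering here is illustrative, not the actual Sentinel-2 band layout):

```python
import torch

# Hypothetical 13-channel (top-of-atmosphere) stem reused on 11-channel
# (surface reflectance) data: drop the last two input channels, scale up.
w13 = torch.randn(64, 13, 3, 3)
in_chans = 11
w11 = w13[:, :in_chans] * (13 / float(in_chans))  # drop, then rescale by 13/11
```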

adamjstewart avatar Feb 25 '25 16:02 adamjstewart

@adamjstewart k, I think it's pretty straightforward to support that with an extra arg that covers the 'base' or default channels. Below I added a base_chans arg... if you set it to 13 you should get the behaviour described. The 'naive' dropping is the default. For dropping specific channels, an arg could be added with indices to drop.

Conceivably you could drop before or after the repeat, say if you had 11 of 13 channels you wanted to keep, AND wanted to repeat to get a 33-channel input... hmm.

import math

from torch import Tensor


def adapt_input_conv(in_chans: int, conv_weight: Tensor, base_chans: int = 3) -> Tensor:
    conv_type = conv_weight.dtype
    conv_weight = conv_weight.float()  # Some weights are in torch.half, ensure it's float for sum on CPU
    O, I, J, K = conv_weight.shape
    if in_chans == 1:
        if I > base_chans:
            assert conv_weight.shape[1] % base_chans == 0
            # For models with space2depth stems
            conv_weight = conv_weight.reshape(O, I // base_chans, base_chans, J, K)
            conv_weight = conv_weight.sum(dim=2, keepdim=False)
        else:
            conv_weight = conv_weight.sum(dim=1, keepdim=True)
    elif in_chans != base_chans:
        if I != base_chans:
            raise NotImplementedError('Weight format not supported by conversion.')
        else:
            # NOTE this strategy should be better than random init, but there could be other combinations of
            # the original RGB input layer weights that'd work better for specific cases.
            repeat = int(math.ceil(in_chans / base_chans))
            scale = base_chans / float(in_chans)
            if repeat > 1:
                conv_weight = conv_weight.repeat(1, repeat, 1, 1)
            conv_weight = conv_weight[:, :in_chans, :, :]  # drops last channels
            conv_weight *= scale
    conv_weight = conv_weight.to(conv_type)
    return conv_weight
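A quick shape check of the sketch above, using a hypothetical 13-channel stem adapted to a stacked 2*13=26-channel input (the function is repeated here so the snippet runs standalone):

```python
import math
import torch
from torch import Tensor

def adapt_input_conv(in_chans: int, conv_weight: Tensor, base_chans: int = 3) -> Tensor:
    conv_type = conv_weight.dtype
    conv_weight = conv_weight.float()  # ensure float for sum on CPU
    O, I, J, K = conv_weight.shape
    if in_chans == 1:
        if I > base_chans:
            assert I % base_chans == 0
            # For models with space2depth stems
            conv_weight = conv_weight.reshape(O, I // base_chans, base_chans, J, K)
            conv_weight = conv_weight.sum(dim=2, keepdim=False)
        else:
            conv_weight = conv_weight.sum(dim=1, keepdim=True)
    elif in_chans != base_chans:
        if I != base_chans:
            raise NotImplementedError('Weight format not supported by conversion.')
        repeat = int(math.ceil(in_chans / base_chans))
        scale = base_chans / float(in_chans)
        if repeat > 1:
            conv_weight = conv_weight.repeat(1, repeat, 1, 1)
        conv_weight = conv_weight[:, :in_chans, :, :]  # drops last channels
        conv_weight *= scale
    return conv_weight.to(conv_type)

# Hypothetical 13-channel Sentinel-2 stem, tiled twice and halved for 26 channels.
w13 = torch.randn(32, 13, 3, 3)
w26 = adapt_input_conv(26, w13, base_chans=13)
```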

rwightman avatar Feb 25 '25 17:02 rwightman

Would it be simpler to use base_chans = I? I know your implementation better supports 13 -> 11 -> 33, but I think that specific use case will be uncommon. I'm mostly concerned about 13 -> 26 or 13 -> 11 or 11 -> 13.

adamjstewart avatar Feb 25 '25 18:02 adamjstewart

@adamjstewart removing the base_chans arg would require adding a space2depth multiplier arg to resolve that ambiguity and support monochrome use with tresnet models...

rwightman avatar Feb 25 '25 19:02 rwightman

I don't know what most of those words mean so I'll trust you. However, your implementation doesn't seem to support anything other than in_chans == 1 or base_chans == I, which is why I'm wondering why we even need a base_chans parameter at all. If base_chans = I, then most if-statements and raises disappear.
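The simplification being suggested would look roughly like this (a sketch: base_chans inferred as I, which drops the mismatch checks but also the space2depth handling mentioned above):

```python
import math
import torch
from torch import Tensor

def adapt_input_conv_simple(in_chans: int, conv_weight: Tensor) -> Tensor:
    conv_type = conv_weight.dtype
    conv_weight = conv_weight.float()
    O, I, J, K = conv_weight.shape
    if in_chans == 1:
        conv_weight = conv_weight.sum(dim=1, keepdim=True)
    elif in_chans != I:
        repeat = int(math.ceil(in_chans / I))
        conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans]
        conv_weight = conv_weight * (I / float(in_chans))
    return conv_weight.to(conv_type)

# Hypothetical 13-channel weights, adapted both down (11) and up (26).
w = torch.randn(8, 13, 3, 3)
w11 = adapt_input_conv_simple(11, w)
w26 = adapt_input_conv_simple(26, w)
```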

adamjstewart avatar Feb 26 '25 10:02 adamjstewart