torchgeo

Multi-Weight Support API

Open adamjstewart opened this issue 3 years ago • 0 comments

Summary

We should consider moving towards a multi-weight support API for pre-trained model weights.

Rationale

Our pre-trained weight support was modeled after how torchvision used to handle weights:

model = resnet50(pretrained=True)

However, pre-trained weights are not a boolean choice. A single architecture can have many different sets of weights, varying by:

  • Different satellites (e.g., Sentinel-2, Landsat 8)
  • Band sets (e.g., all bands, RGB-only, false color)
  • Downstream tasks (e.g., land cover mapping, object detection)
  • Training methods (e.g., supervised, self-supervised)

Currently, we use the following API:

model = resnet50(sensor="sentinel2", bands="all", pretrained=True)

However, this has several drawbacks:

  1. It's not extensible if we want to add a different downstream task or training strategy
  2. It's unclear and undocumented which pre-trained models are available

Implementation

Torchvision recently added a multi-weight support API: https://pytorch.org/blog/introducing-torchvision-new-multi-weight-support-api/

We should consider emulating torchvision by using enums to list all available pre-trained model weights.
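
As a rough sketch of what an enum-based API could look like (not a proposal for the final design): the Weights dataclass, the ResNet50_Weights members, the example.com URLs, and the stand-in resnet50 builder below are all hypothetical names made up for illustration, and the builder returns a plain dict instead of a real model so the example runs with only the standard library.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional


@dataclass(frozen=True)
class Weights:
    """Metadata for one set of pre-trained weights (hypothetical structure)."""

    url: str
    sensor: str
    bands: str
    training: str


class ResNet50_Weights(Enum):
    """All weights available for one architecture, one member per checkpoint.

    Member names and URLs are placeholders, not real torchgeo weights.
    """

    SENTINEL2_ALL_SSL = Weights(
        url="https://example.com/resnet50_sentinel2_all_ssl.pth",
        sensor="sentinel2",
        bands="all",
        training="self-supervised",
    )
    SENTINEL2_RGB_SUPERVISED = Weights(
        url="https://example.com/resnet50_sentinel2_rgb_supervised.pth",
        sensor="sentinel2",
        bands="rgb",
        training="supervised",
    )


def resnet50(weights: Optional[ResNet50_Weights] = None) -> dict:
    """Stand-in model builder: describes what would be loaded.

    A real builder would construct the network and, if weights is given,
    download and load the state dict from weights.value.url.
    """
    if weights is not None and not isinstance(weights, ResNet50_Weights):
        raise TypeError("weights must be a ResNet50_Weights member or None")
    return {
        "arch": "resnet50",
        "weights": None if weights is None else weights.name,
        "url": None if weights is None else weights.value.url,
    }


# Random initialization: no weights requested.
untrained = resnet50()

# Pre-trained: the choice of sensor, bands, and training method is explicit.
pretrained = resnet50(weights=ResNet50_Weights.SENTINEL2_ALL_SSL)

# The enum doubles as documentation of what is available.
available = [w.name for w in ResNet50_Weights]
```

With this shape, adding weights for a new downstream task or training strategy means adding an enum member rather than a new keyword argument, and iterating over the enum answers the question of which pre-trained models exist.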

Alternatives

The alternative is "business as usual": keep the current API.

Additional information

I don't think this would require us to drop support for torchvision 0.12 and older, although dropping it may make things simpler.

adamjstewart, Sep 06 '22 16:09