torchgeo
Multi-Weight Support API
Summary
We should consider moving towards a multi-weight support API for pre-trained model weights.
Rationale
Our pre-trained weight support was modeled after how torchvision used to handle weights:
model = resnet50(pretrained=True)
However, model weights are not boolean; there are many different potential weights for:
- Different satellites (e.g., Sentinel-2, Landsat 8)
- Band sets (e.g., all bands, RGB-only, false color)
- Downstream tasks (e.g., land cover mapping, object detection)
- Training methods (e.g., supervised, self-supervised)
Currently, we use the following API:
model = resnet50(sensor="sentinel2", bands="all", pretrained=True)
However, this has several drawbacks:
- It isn't extensible: supporting a new downstream task or training strategy would mean adding yet another keyword argument
- It's unclear, and undocumented, which pre-trained models are actually available
Implementation
Torchvision recently added a multi-weight support API: https://pytorch.org/blog/introducing-torchvision-new-multi-weight-support-api/
We should consider emulating torchvision by using enums to represent all available pre-trained model weights.
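Below is a minimal sketch of how this could look, assuming we mirror torchvision's pattern (in torchvision, for example, `resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)` replaces `resnet50(pretrained=True)`). All names here (`WeightsMeta`, `ResNet50_Weights`, the member names, and the example.com URLs) are hypothetical placeholders, not existing torchgeo or torchvision APIs. Each enum member bundles the metadata that a boolean flag cannot express (sensor, bands, training method, downstream task), and iterating the enum lists every available option.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Tuple


@dataclass(frozen=True)
class WeightsMeta:
    """Metadata describing one set of pre-trained weights (hypothetical)."""

    url: str  # placeholder download location, not a real checkpoint
    sensor: str  # e.g. "sentinel2", "landsat8"
    bands: Tuple[str, ...]  # band names the first conv layer expects
    method: str  # e.g. "supervised", "moco"
    task: Optional[str] = None  # downstream task, if fine-tuned


class ResNet50_Weights(Enum):
    """Hypothetical enum listing every available set of ResNet-50 weights."""

    SENTINEL2_ALL_MOCO = WeightsMeta(
        url="https://example.com/resnet50_sentinel2_all_moco.pth",
        sensor="sentinel2",
        bands=("B01", "B02", "B03", "B04", "B05", "B06", "B07",
               "B08", "B8A", "B09", "B10", "B11", "B12"),
        method="moco",
    )
    SENTINEL2_RGB_SUPERVISED = WeightsMeta(
        url="https://example.com/resnet50_sentinel2_rgb_supervised.pth",
        sensor="sentinel2",
        bands=("B04", "B03", "B02"),
        method="supervised",
        task="land_cover",
    )


def resnet50(weights: Optional[ResNet50_Weights] = None) -> dict:
    """Stand-in model builder that returns a config dict instead of a model.

    A real builder would create the network with one input channel per band
    and then load the state dict referenced by ``weights.value.url``.
    """
    in_chans = len(weights.value.bands) if weights is not None else 3
    return {"arch": "resnet50", "in_chans": in_chans, "weights": weights}


if __name__ == "__main__":
    # Discoverability: every available option is listed by iterating the enum.
    for w in ResNet50_Weights:
        print(w.name, w.value.sensor, w.value.method)

    model = resnet50(weights=ResNet50_Weights.SENTINEL2_ALL_MOCO)
    print(model["in_chans"])
```

Adding a new satellite, band set, task, or training method then means adding an enum member rather than a new keyword argument, and the enum itself doubles as documentation of what exists.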
Alternatives
The alternative is "business as usual": keep the current keyword-argument API.
Additional information
I don't think this would require us to drop support for torchvision 0.12 and older, although doing so may make things simpler.
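One possible way to keep supporting older torchvision is for torchgeo to define its own weight enums, only ask torchvision for *untrained* models, and load the state dict itself. The only version-sensitive step is requesting an untrained model: releases before 0.13 take `pretrained=False`, while 0.13 introduced `weights=None` and deprecated `pretrained`. The helpers below (`parse_version`, `untrained_model_kwargs`) are hypothetical and parse the version string by hand so the sketch has no dependencies.

```python
import re
from typing import Any, Dict, Tuple


def parse_version(version: str) -> Tuple[int, int]:
    """Extract (major, minor) from strings like "0.12.0" or "0.13.1+cu113"."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def untrained_model_kwargs(torchvision_version: str) -> Dict[str, Any]:
    """Keyword arguments that ask a torchvision builder for random init.

    torchvision 0.13 added the ``weights`` argument and deprecated
    ``pretrained``; older releases only understand ``pretrained``.
    """
    if parse_version(torchvision_version) >= (0, 13):
        return {"weights": None}
    return {"pretrained": False}
```

In practice the version string would come from `torchvision.__version__`, e.g. `resnet50(**untrained_model_kwargs(torchvision.__version__))`, followed by `model.load_state_dict(...)` with the checkpoint referenced by the chosen enum member.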