DOFA
Add more examples using torchgeo for segmentation
Hi, I'm very interested in this project. Could you add a more detailed example that uses TorchGeo to load the pretrained weights and use them for segmentation? I've run some experiments on a custom dataset, but I'm not getting good results.
Currently I'm doing something like this:
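import torch.nn as nn
from torchgeo.models import DOFALarge16_Weights, dofa_large_patch16_224


class SegmentationNet(nn.Module):
    def __init__(self):
        super().__init__()
        # DOFA-Large backbone (embed_dim=1024) with the MAE-pretrained weights
        self.backbone = dofa_large_patch16_224(img_size=224, weights=DOFALarge16_Weights.DOFA_MAE)
        # CustomDecoder: user-defined decoder module (definition not shown)
        self.decoder = CustomDecoder(in_channels=1024, num_classes=1)

    def forward(self, x):
        # Sentinel-2 wavelengths in micrometers (B12, B11, B8A);
        # the order must match the channel order of x
        wavelengths = [2.20, 1.61, 0.865]
        x = self.backbone.forward_features(x, wavelengths)
        x = self.decoder(x)
        return x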
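One likely cause of the poor results: in torchgeo's implementation, `DOFA.forward_features` appears to end with global pooling (or CLS-token selection), so it returns a single `(B, 1024)` vector per image rather than a spatial feature map, and a decoder fed that output has nothing spatial to work with. Segmentation heads on ViT backbones usually start from the per-patch tokens instead. The sketch below assumes you already have those tokens as a `(B, N, 1024)` tensor with the CLS token removed, in the standard row-major patch order (for example, by reproducing `forward_features` up to the transformer blocks and skipping the pooling step, which relies on torchgeo internals that may change between versions). `tokens_to_feature_map` and `SimpleSegHead` are hypothetical names, not torchgeo APIs, and random tokens stand in for real backbone output.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def tokens_to_feature_map(tokens: torch.Tensor, patch_size: int, height: int, width: int) -> torch.Tensor:
    """Reshape ViT patch tokens (B, N, C) into a feature map (B, C, H/p, W/p).

    Assumes the CLS token has already been dropped, so N == (H/p) * (W/p).
    """
    b, n, c = tokens.shape
    gh, gw = height // patch_size, width // patch_size
    if n != gh * gw:
        raise ValueError(f"expected {gh * gw} patch tokens, got {n}")
    # (B, N, C) -> (B, C, N) -> (B, C, gh, gw); token n maps to patch (n // gw, n % gw)
    return tokens.transpose(1, 2).reshape(b, c, gh, gw)


class SimpleSegHead(nn.Module):
    """Hypothetical lightweight decoder: conv projection, per-patch classifier, bilinear upsampling."""

    def __init__(self, in_channels: int = 1024, hidden: int = 256, num_classes: int = 1, patch_size: int = 16):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Sequential(
            nn.Conv2d(in_channels, hidden, kernel_size=1),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
        )
        self.classifier = nn.Conv2d(hidden, num_classes, kernel_size=1)

    def forward(self, tokens: torch.Tensor, height: int, width: int) -> torch.Tensor:
        feats = tokens_to_feature_map(tokens, self.patch_size, height, width)
        logits = self.classifier(self.proj(feats))  # (B, num_classes, H/p, W/p)
        # Upsample to the input resolution so a per-pixel loss can be applied
        return F.interpolate(logits, size=(height, width), mode="bilinear", align_corners=False)


# Shape check: random tokens stand in for DOFA-Large patch outputs (embed_dim=1024)
tokens = torch.randn(2, (224 // 16) ** 2, 1024)
head = SimpleSegHead(in_channels=1024, num_classes=1)
out = head(tokens, 224, 224)
print(out.shape)  # torch.Size([2, 1, 224, 224])
```

Upsampling 16x from a single 14x14 token grid gives coarse masks; decoders that combine tokens from several transformer blocks are a common refinement, but which one suits a given dataset is a design choice beyond this sketch.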