Problem training with standard dataloader.

Open lcoandrade opened this issue 2 years ago • 5 comments

Description

I've just learnt about TorchGeo and wanted to try it out, so I created a Kaggle notebook to test it with NAIP and Chesapeake data (TorchGeo 101). When I try to train a segmentation task, I get the following error: ValueError: A frozen dataclass was passed to `apply_to_collection` but this is not allowed.

Steps to reproduce

  1. Create a dataset with NAIP and Chesapeake data:
# Imports needed by this step
import os
from torchgeo.datasets import NAIP, ChesapeakeDE

# Creating the NAIP dataset
naip_root = os.path.join(INPUT_DIR, 'naip')
naip = NAIP(naip_root)

# Creating the CHESAPEAKE dataset
chesapeake_root = os.path.join(INPUT_DIR, "chesapeake")
chesapeake = ChesapeakeDE(
    chesapeake_root, 
    crs=naip.crs, 
    res=naip.res, 
    download=False
)
  2. Make an intersection, create a sampler and a dataloader:
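from torch.utils.data import DataLoader
from torchgeo.datasets import stack_samples
from torchgeo.samplers import RandomGeoSampler
dataset = naip & chesapeake
sampler = RandomGeoSampler(dataset, size=IMG_SIZE, length=SAMPLE_SIZE)
dataloader = DataLoader(dataset, sampler=sampler, collate_fn=stack_samples)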
  3. Define a trainer:
DEVICE, NUM_DEVICES = ("cuda", torch.cuda.device_count()) if torch.cuda.is_available() else ("cpu", mp.cpu_count())
WORKERS = mp.cpu_count()
print(f'Running on {NUM_DEVICES} {DEVICE}(s)')

trainer = pl.Trainer(
    accelerator=DEVICE,
    devices=NUM_DEVICES,
    max_epochs=EPOCHS,
    callbacks=[checkpoint_callback],
    logger=logger,
)
  4. Define the logger, the checkpoint callback, and a segmentation task:
ssl._create_default_https_context = ssl._create_unverified_context

test_dir = os.path.join(OUTPUT_DIR, "test")
os.makedirs(test_dir, exist_ok=True)
    
logger = CSVLogger(
    test_dir, 
    name='torchgeo_logs'
)

checkpoint_callback = ModelCheckpoint(
    every_n_epochs=1,
    dirpath=test_dir,
    filename='torchgeo_trained'
)

task = SemanticSegmentationTask(
    model=SEGMENTATION_MODEL,
    backbone=BACKBONE,
    weights=WEIGHTS,
    in_channels=IN_CHANNELS,
    num_classes=NUM_CLASSES,
    loss=LOSS,
    ignore_index=None,
    learning_rate=LR,
    learning_rate_schedule_patience=PATIENCE,
)
  5. Start training:
trainer.fit(
    model=task,
    train_dataloaders=dataloader,
)

Version

0.4.1

lcoandrade avatar Jun 19 '23 03:06 lcoandrade

Duplicate of #1056 and #1418

The issue is that some of the sample values returned by GeoDataset (BoundingBox, CRS) can't be automatically collated by PyTorch. Our solution for our builtin data modules is to remove these values before loading: https://github.com/microsoft/torchgeo/blob/v0.4.1/torchgeo/datamodules/geo.py#L280
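For illustration, here is a minimal sketch of the situation described above, using only the standard library. `Box` is a hypothetical stand-in for a frozen dataclass such as BoundingBox, plain lists stand in for tensors, and the key names "crs" and "bbox" are an assumption about what GeoDataset samples contain. It shows that fields of a frozen dataclass cannot be reassigned, which is presumably what the error message refers to, and that dropping those entries leaves only tensor-like values in the batch.

```python
from dataclasses import FrozenInstanceError, dataclass


@dataclass(frozen=True)
class Box:
    """Hypothetical stand-in for a frozen dataclass such as BoundingBox."""

    minx: float
    maxx: float


# Toy batch: plain lists stand in for the image and mask tensors.
batch = {
    "image": [[0.0, 1.0]],
    "mask": [[0, 1]],
    "crs": ["EPSG:26918"],
    "bbox": [Box(0.0, 1.0)],
}

# Fields of a frozen dataclass cannot be reassigned after construction.
try:
    batch["bbox"][0].minx = 5.0
    reassign_failed = False
except FrozenInstanceError:
    reassign_failed = True

# Workaround sketched here: drop the non-tensor entries from the batch.
for key in ("crs", "bbox"):
    batch.pop(key, None)

print(reassign_failed)  # True
print(sorted(batch))    # ['image', 'mask']
```

The trade-off is that the geographic metadata is no longer available downstream, so this only suits steps that need the image and mask alone.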

My suggestion would be to write a simple data module (there are dozens of builtin examples) and use that instead of directly using a data loader. Maybe this is something we could add to our collation functions...
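As a rough sketch of what handling this in a collation function might look like, the wrapper below runs an existing collate function and then drops selected keys from the resulting batch. The name `drop_keys_after` and the toy `list_collate` are hypothetical and not part of TorchGeo; the default key names "crs" and "bbox" are an assumption about the sample layout.

```python
from typing import Any, Callable, Dict, Iterable, List

Sample = Dict[str, Any]
Collate = Callable[[List[Sample]], Sample]


def drop_keys_after(collate: Collate, keys: Iterable[str] = ("crs", "bbox")) -> Collate:
    """Wrap a collate function so the listed keys are removed from its output.

    Hypothetical helper for illustration; not a TorchGeo API.
    """
    drop = frozenset(keys)

    def wrapped(samples: List[Sample]) -> Sample:
        batch = collate(samples)
        return {k: v for k, v in batch.items() if k not in drop}

    return wrapped


def list_collate(samples: List[Sample]) -> Sample:
    """Toy collate function: gather each key's values into a list."""
    return {key: [s[key] for s in samples] for key in samples[0]}


collate_fn = drop_keys_after(list_collate)
batch = collate_fn([{"image": 1, "bbox": "a"}, {"image": 2, "bbox": "b"}])
print(batch)  # {'image': [1, 2]}
```

In the reproduction above, the analogous wiring would presumably be `DataLoader(dataset, sampler=sampler, collate_fn=drop_keys_after(stack_samples))`, though that combination is untested here.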

adamjstewart avatar Jun 19 '23 17:06 adamjstewart

Is this still an issue or can this be closed?

adamjstewart avatar Sep 06 '23 20:09 adamjstewart

I've made a CustomGeoDataModule like this:

class CustomGeoDataModule(GeoDataModule):
    def setup(self, stage: str) -> None:
        """Set up datasets.

        Args:
            stage: Either 'fit', 'validate', 'test', or 'predict'.
        """
        self.dataset = self.dataset_class(**self.kwargs)
        
        generator = torch.Generator().manual_seed(0)
        (
            self.train_dataset,
            self.val_dataset,
            self.test_dataset,
        ) = random_bbox_assignment(self.dataset, [0.6, 0.2, 0.2], generator)
        
        if stage in ["fit"]:
            self.train_batch_sampler = RandomBatchGeoSampler(
                self.train_dataset, self.patch_size, self.batch_size, self.length
            )
        if stage in ["fit", "validate"]:
            self.val_sampler = GridGeoSampler(
                self.val_dataset, self.patch_size, self.patch_size
            )
        if stage in ["test"]:
            self.test_sampler = GridGeoSampler(
                self.test_dataset, self.patch_size, self.patch_size
            )

This solved my problem.

lcoandrade avatar Sep 06 '23 20:09 lcoandrade

Quoting adamjstewart: "My suggestion would be to write a simple data module (there are dozens of builtin examples) and use that instead of directly using a data loader. Maybe this is something we could add to our collation functions..."

Hi @adamjstewart

I've also encountered this problem, and it's taken me a while to find the solution. Definitely +1 for adding this as a feature of torchgeo to make this as seamless as possible for the end-users using GeoDatasets.

Cheers, Tom

trchudley avatar Nov 06 '23 11:11 trchudley

Reopening as a reminder to try to upstream some of our changes to PyTorch.

adamjstewart avatar Nov 06 '23 13:11 adamjstewart