
Easier way to use Data Processing steps outside of datamodule

nilsleh opened this issue · 7 comments

Summary

Normalization and augmentations are defined in the on_after_batch_transfer() method of the datamodules so that they run on the GPU, as recommended by Lightning. However, a downside of this is that you always have to pass the datamodule into .fit() and .test(). Especially for the latter, it can be convenient to test on separate dataloaders, but those are then just "raw" dataloaders without normalization etc. applied. It took me a minute to find that this was the reason for some funky test results. Currently I am writing a custom collate_fn and setting it on the dataloader I get from a datamodule, but it would be nice if this could be handled more easily. I am open to thoughts on this, or to suggestions for an easier way to handle it than what I am doing at the moment.
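
To illustrate the pitfall (a minimal sketch using ETCI2021DataModule as the example):

from torchgeo.datamodules import ETCI2021DataModule

datamodule = ETCI2021DataModule(root=".", download=True, batch_size=32)
datamodule.setup("validate")

# This batch is raw: the normalization defined in on_after_batch_transfer()
# never runs here, because Lightning only invokes that hook when the
# datamodule itself is passed to trainer.fit/validate/test.
batch = next(iter(datamodule.val_dataloader()))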

Rationale

Sometimes I would like to test a model on different datasets, and if a torchgeo datamodule is available, it is convenient to just retrieve a configured dataloader from it.

Implementation

Maybe a flag could be added that returns a dataloader with a collate function based on the on_after_batch_transfer() augmentations.
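
For illustration only, it might look something like this (apply_aug is a hypothetical flag name, not an existing torchgeo argument):

# Hypothetical API: the returned dataloader's collate_fn would wrap
# on_after_batch_transfer(), so batches come out already normalized/augmented.
val_dataloader = datamodule.val_dataloader(apply_aug=True)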

Alternatives

Currently I am doing something like this:

import torch

from torchgeo.datamodules import ETCI2021DataModule

datamodule = ETCI2021DataModule(root=".", download=True, num_workers=4, batch_size=32)
datamodule.setup("fit")


def collate(batch: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
    """Collate fn that also applies the datamodule's normalization/augmentations."""
    # ETCI2021 samples are dicts with "image" and "mask" keys.
    images = [item["image"] for item in batch]
    masks = [item["mask"] for item in batch]

    inputs = torch.stack(images)
    targets = torch.stack(masks)
    # Invoke the hook manually; note it now runs on CPU in the dataloader
    # workers rather than on GPU after batch transfer.
    return datamodule.on_after_batch_transfer(
        {"image": inputs, "mask": targets}, dataloader_idx=0
    )


val_dataloader = datamodule.val_dataloader()
val_dataloader.collate_fn = collate
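
With the collate patched in, the dataloader can then be passed to Lightning directly (a usage sketch; model stands for any LightningModule from context):

from lightning.pytorch import Trainer

trainer = Trainer(accelerator="auto", devices=1)
trainer.validate(model, dataloaders=val_dataloader)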

nilsleh avatar Dec 18 '23 08:12 nilsleh

I can understand why you would want to be able to use a dataset if a data module doesn't exist, but why would you want to use a dataset if a data module does exist?

adamjstewart avatar Dec 18 '23 10:12 adamjstewart

In order to do trainer.validate(model, dataloaders=datamodule.val_dataloader()) without having to implement my own normalization scheme as a collate fn for every dataloader from a datamodule I want to use. For example, say I train one model and want to validate it on a bunch of datasets; then I could pass multiple dataloaders from different datasets or datamodules to trainer.validate(), as in the sketch below.
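
For instance, something like this (a sketch; make_collate is a hypothetical factory that builds the collate fn above closed over a given datamodule, and trainer/model are assumed from context):

from torchgeo.datamodules import ETCI2021DataModule, LandCoverAIDataModule

loaders = []
for dm in [
    ETCI2021DataModule(root="etci2021", batch_size=32),
    LandCoverAIDataModule(root="landcoverai", batch_size=32),
]:
    dm.setup("validate")
    loader = dm.val_dataloader()
    # Hypothetical factory: returns a collate fn that applies
    # dm.on_after_batch_transfer, as in the snippet above.
    loader.collate_fn = make_collate(dm)
    loaders.append(loader)

trainer.validate(model, dataloaders=loaders)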

nilsleh avatar Dec 18 '23 10:12 nilsleh

But why not use trainer.validate(model, datamodule=datamodule) for all data modules?

adamjstewart avatar Dec 18 '23 11:12 adamjstewart

If you pass a datamodule, it will only select the predefined validation loader and validate on that, but maybe I would like to validate on both the train set and the validation set, for example when taking a pre-trained model and checking performance without training. It might also be relevant for something like cross validation, where you re-split your train/val sets. In my case, I am trying conformal prediction, where you need to take a subset of the validation set to create a separate calibration set and use the model with that, so you need to control which split the validation dataloader is applied to. A sketch of that split follows below.
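
For the conformal-prediction case, a rough sketch of the split (assuming the collate workaround from above, and that the datamodule exposes a val_dataset attribute as torchgeo's NonGeoDataModule does; sizes are illustrative):

import torch
from torch.utils.data import DataLoader, random_split

datamodule.setup("validate")
val_dataset = datamodule.val_dataset

# Hold out half of the validation set for conformal calibration.
n_cal = len(val_dataset) // 2
cal_dataset, holdout_dataset = random_split(
    val_dataset,
    [n_cal, len(val_dataset) - n_cal],
    generator=torch.Generator().manual_seed(0),
)

# Both loaders still need the datamodule's normalization, hence the same
# collate fn as in the workaround above.
cal_loader = DataLoader(cal_dataset, batch_size=32, collate_fn=collate)
holdout_loader = DataLoader(holdout_dataset, batch_size=32, collate_fn=collate)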

nilsleh avatar Dec 18 '23 12:12 nilsleh

I think at a minimum we should improve the documentation to state when augmentations are and are not applied. For example, I assumed they are performed in the dataset's __getitem__, but they are not.

robmarkcole avatar Jun 20 '24 15:06 robmarkcole

Just to clarify: for the majority of datamodules no augmentations are applied, only a normalization of the images. We try not to prescribe which augmentations should be used for which dataset, as this should be left to the user.

There are 2 options:

  • override the datamodule's train/val/test/predict aug attributes to define your own Kornia AugmentationSequential pipeline (this performs batched GPU augmentations; see option 1 in the sketch after this list)
  • override the datamodule's setup method to pass your own transforms to the train/val/test datasets (this performs augmentations per sample, at the __getitem__ level; see option 2 in the sketch). This option works when using the dataloader directly, e.g. datamodule.train_dataloader(), without passing the datamodule to trainer.fit/test
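
A sketch of both options, using ETCI2021 as the example and torchgeo's AugmentationSequential wrapper around Kornia (the normalization values are illustrative, my_transforms is a placeholder for your own per-sample callable, and self.kwargs is assumed to hold the dataset kwargs as in torchgeo's NonGeoDataModule):

import kornia.augmentation as K
import torch

from torchgeo.datamodules import ETCI2021DataModule
from torchgeo.datasets import ETCI2021
from torchgeo.transforms import AugmentationSequential


# Option 1: override the aug attributes. These pipelines run in
# on_after_batch_transfer(), on whole batches, on the GPU.
class AugmentedETCI2021DataModule(ETCI2021DataModule):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.train_aug = AugmentationSequential(
            K.Normalize(mean=torch.tensor(0.0), std=torch.tensor(255.0)),  # illustrative
            K.RandomHorizontalFlip(p=0.5),
            data_keys=["image", "mask"],
        )


# Option 2: override setup to inject per-sample transforms into the
# datasets. These run in __getitem__, so they also apply when the
# dataloader is used directly, without trainer.fit/test.
class TransformedETCI2021DataModule(ETCI2021DataModule):
    def setup(self, stage: str) -> None:
        # my_transforms: your per-sample callable on the sample dict.
        if stage == "fit":
            self.train_dataset = ETCI2021(
                split="train", transforms=my_transforms, **self.kwargs
            )
        if stage in ("fit", "validate"):
            self.val_dataset = ETCI2021(
                split="val", transforms=my_transforms, **self.kwargs
            )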

isaaccorley avatar Jun 20 '24 19:06 isaaccorley

Thanks for the clarification - it makes sense that there are basic augmentations to always apply during training which we don't need to inspect (i.e. normalisation), and others to experiment with and visualise for sanity checking. Is it therefore a reasonable workflow to pass the latter kind to the dataset in setup, and still have the former applied at the datamodule level?

robmarkcole avatar Jun 20 '24 19:06 robmarkcole