
Pass datamodule kwargs to datasets

adamjstewart opened this issue 3 years ago · 0 comments

Summary

We should remove dataset-specific arguments from datamodules and instead pass them directly to the dataset through kwargs.

Rationale

As an example, we'll look at TropicalCycloneWindEstimation and CycloneDataModule.

TropicalCycloneWindEstimation has a number of options that users may want to configure (see the sketch after this list):

  • root: root directory containing data
  • split: train or test
  • transforms: data augmentations to apply
  • download: whether or not to download the dataset
  • api_key: MLHUB_API_KEY
  • checksum: whether or not to checksum the download
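
For reference, here is a minimal sketch of constructing the dataset directly with these options. The root path and API key are placeholders, and the argument values are only illustrative:

from torchgeo.datasets import TropicalCycloneWindEstimation

# Build the training split, downloading the data if it is not already present.
# Downloading requires a Radiant MLHub API key.
dataset = TropicalCycloneWindEstimation(
    root="data/cyclone",
    split="train",
    transforms=None,
    download=True,
    api_key="YOUR_MLHUB_API_KEY",
    checksum=False,
)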

However, CycloneDataModule only exposes a subset of these:

  • root_dir: the dataset's root directory, for some reason under a different name
  • api_key: MLHUB_API_KEY

If a user wants to, for example, automatically download the dataset, they have to modify the source code of CycloneDataModule. If we instead pass the **kwargs from CycloneDataModule directly to TropicalCycloneWindEstimation, we end up with less code duplication, more features, and greater consistency.

Implementation

Here is what the change would look like for CycloneDataModule:

diff --git a/torchgeo/datamodules/cyclone.py b/torchgeo/datamodules/cyclone.py
index 9a97242..32db799 100644
--- a/torchgeo/datamodules/cyclone.py
+++ b/torchgeo/datamodules/cyclone.py
@@ -26,11 +26,9 @@ class CycloneDataModule(pl.LightningDataModule):
 
     def __init__(
         self,
-        root_dir: str,
         seed: int,
         batch_size: int = 64,
         num_workers: int = 0,
-        api_key: Optional[str] = None,
         **kwargs: Any,
     ) -> None:
         """Initialize a LightningDataModule for NASA Cyclone based DataLoaders.
@@ -45,11 +43,10 @@ class CycloneDataModule(pl.LightningDataModule):
                 downloaded
         """
         super().__init__()  # type: ignore[no-untyped-call]
-        self.root_dir = root_dir
         self.seed = seed
         self.batch_size = batch_size
         self.num_workers = num_workers
-        self.api_key = api_key
+        self.kwargs = kwargs
 
     def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
         """Transform a single sample from the Dataset.
@@ -76,12 +73,7 @@ class CycloneDataModule(pl.LightningDataModule):
         This includes optionally downloading the dataset. This is done once per node,
         while :func:`setup` is done once per GPU.
         """
-        TropicalCycloneWindEstimation(
-            self.root_dir,
-            split="train",
-            download=self.api_key is not None,
-            api_key=self.api_key,
-        )
+        TropicalCycloneWindEstimation(split="train", **self.kwargs)
 
     def setup(self, stage: Optional[str] = None) -> None:
         """Create the train/val/test splits based on the original Dataset objects.

This should be done for all other datamodules.
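
As a rough sketch, here is how a user might then configure the datamodule under this proposal; the argument values are illustrative and the API key is a placeholder:

from torchgeo.datamodules import CycloneDataModule

# Dataset-specific options such as root, download, api_key, and checksum are
# now simply forwarded to TropicalCycloneWindEstimation via **kwargs.
datamodule = CycloneDataModule(
    seed=0,
    batch_size=64,
    num_workers=4,
    root="data/cyclone",  # dataset kwarg (previously root_dir on the datamodule)
    download=True,        # dataset kwarg (previously not exposed at all)
    api_key="YOUR_MLHUB_API_KEY",
)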

Alternatives

The alternative is to keep adding new parameters to each datamodule as we need them. This results in an inconsistent API, since each datamodule ends up exposing a different, arbitrary subset of its dataset's options.

Additional information

No response

adamjstewart · Jul 10, 2022