torchgeo
Pass datamodule kwargs to datasets
Summary
We should remove dataset-specific arguments from datamodules and instead pass them directly to the dataset through kwargs.
Rationale
As an example, we'll look at TropicalCycloneWindEstimation and CycloneDataModule.
TropicalCycloneWindEstimation has a number of options that users may want to configure:
- root: root directory containing data
- split: train or test
- transforms: data augmentations to apply
- download: whether or not to download the dataset
- api_key: Radiant MLHub API key (MLHUB_API_KEY)
- checksum: whether or not to checksum the download
However, CycloneDataModule only exposes a subset of these:
- root_dir: the same as root, but under an inconsistent name
- api_key: Radiant MLHub API key (MLHUB_API_KEY)
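If a user wants to, for example, verify checksums or control downloading independently of whether an API key is given (the datamodule currently sets download=self.api_key is not None), they have to modify the source code of CycloneDataModule. If we instead pass the **kwargs from CycloneDataModule directly to TropicalCycloneWindEstimation, we end up with less code duplication, access to every dataset option, and greater consistency between datamodules and their datasets.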
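A minimal, self-contained sketch of the proposed pattern. It uses hypothetical stand-in classes (ToyCycloneDataset, ToyCycloneDataModule) instead of the real torchgeo and Lightning classes so that it runs without either library; the default values shown are illustrative assumptions. The datamodule keeps only its own loader arguments explicit and forwards everything else to the dataset unchanged:

```python
from typing import Any, Callable, Dict, Optional


class ToyCycloneDataset:
    """Hypothetical stand-in for TropicalCycloneWindEstimation."""

    def __init__(
        self,
        root: str = "data",  # illustrative default, not torchgeo's
        split: str = "train",
        transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
        download: bool = False,
        api_key: Optional[str] = None,
        checksum: bool = False,
    ) -> None:
        # Record the options so the example can inspect what was received.
        self.root = root
        self.split = split
        self.transforms = transforms
        self.download = download
        self.api_key = api_key
        self.checksum = checksum


class ToyCycloneDataModule:
    """Hypothetical stand-in for CycloneDataModule after the proposed change."""

    def __init__(
        self, seed: int, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
    ) -> None:
        # Datamodule-specific arguments stay explicit...
        self.seed = seed
        self.batch_size = batch_size
        self.num_workers = num_workers
        # ...and everything else is stored untouched for the dataset.
        self.kwargs = kwargs

    def prepare_data(self) -> ToyCycloneDataset:
        # Every dataset option is reachable without the datamodule naming it.
        # (Returned here only so the example can inspect it.)
        return ToyCycloneDataset(split="train", **self.kwargs)


dm = ToyCycloneDataModule(seed=0, root="cyclone", download=True, checksum=True)
ds = dm.prepare_data()
print(ds.root, ds.split, ds.download, ds.checksum)  # cyclone train True True
```

Note that checksum reaches the dataset even though the datamodule never mentions it; adding a new dataset option later requires no datamodule change.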
Implementation
Here is what the change would look like for CycloneDataModule:
```diff
diff --git a/torchgeo/datamodules/cyclone.py b/torchgeo/datamodules/cyclone.py
index 9a97242..32db799 100644
--- a/torchgeo/datamodules/cyclone.py
+++ b/torchgeo/datamodules/cyclone.py
@@ -26,11 +26,9 @@ class CycloneDataModule(pl.LightningDataModule):
 
     def __init__(
         self,
-        root_dir: str,
         seed: int,
         batch_size: int = 64,
         num_workers: int = 0,
-        api_key: Optional[str] = None,
         **kwargs: Any,
     ) -> None:
         """Initialize a LightningDataModule for NASA Cyclone based DataLoaders.
@@ -45,11 +43,10 @@ class CycloneDataModule(pl.LightningDataModule):
                 downloaded
         """
         super().__init__()  # type: ignore[no-untyped-call]
-        self.root_dir = root_dir
         self.seed = seed
         self.batch_size = batch_size
         self.num_workers = num_workers
-        self.api_key = api_key
+        self.kwargs = kwargs
 
     def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
         """Transform a single sample from the Dataset.
@@ -76,12 +73,7 @@ class CycloneDataModule(pl.LightningDataModule):
         This includes optionally downloading the dataset. This is done once per node,
         while :func:`setup` is done once per GPU.
         """
-        TropicalCycloneWindEstimation(
-            self.root_dir,
-            split="train",
-            download=self.api_key is not None,
-            api_key=self.api_key,
-        )
+        TropicalCycloneWindEstimation(split="train", **self.kwargs)
 
     def setup(self, stage: Optional[str] = None) -> None:
         """Create the train/val/test splits based on the original Dataset objects.
```
The same change should be made to all other datamodules.
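Two behavioral consequences of the change are worth noting. First, download is no longer inferred from api_key, so it falls back to the dataset's own default unless the user passes it explicitly. Second, because the datamodule sets split itself, a user who also passes split through kwargs would get a "multiple values for keyword argument" TypeError. The sketch below shows one possible way a datamodule could reject such reserved keys up front with a clearer message; ToyDataset, ToyDataModule, and the reserved attribute are hypothetical names, not part of torchgeo:

```python
from typing import Any, Dict


class ToyDataset:
    """Hypothetical dataset that accepts a split plus arbitrary options."""

    def __init__(self, split: str = "train", **options: Any) -> None:
        self.split = split
        self.options = options


class ToyDataModule:
    """Hypothetical datamodule that forwards **kwargs but owns `split`."""

    # Arguments the datamodule sets itself when building each dataset.
    reserved = frozenset({"split"})

    def __init__(self, seed: int, **kwargs: Any) -> None:
        clash = self.reserved.intersection(kwargs)
        if clash:
            # Fail at construction time with a clear message, instead of a
            # duplicate-keyword TypeError when the dataset is created later.
            raise ValueError(f"{sorted(clash)} is set by the datamodule; do not pass it")
        self.seed = seed
        self.kwargs = kwargs

    def setup(self) -> Dict[str, ToyDataset]:
        # Each split gets the same user options; only `split` differs.
        return {s: ToyDataset(split=s, **self.kwargs) for s in ("train", "test")}


splits = ToyDataModule(seed=0, root="data", download=True).setup()
print(splits["test"].split, splits["test"].options)  # test {'root': 'data', 'download': True}

try:
    ToyDataModule(seed=0, split="val")
except ValueError as err:
    print(err)  # ['split'] is set by the datamodule; do not pass it
```

Whether such a guard is worth adding, or whether the plain TypeError is clear enough, is a design choice left open here.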
Alternatives
The alternative is to add new parameters to each datamodule as we need them. This duplicates each dataset's signature in its datamodule and results in an API that is inconsistent across datamodules.