darts
add `DataLoader`-related parameters to `fit()` and `predict()`
Checklist before merging this PR:
- [x] Mentioned all issues that this PR fixes or addresses.
- [x] Summarized the updates of this PR under Summary.
- [x] Added an entry under Unreleased in the Changelog.
Summary
Add `torch.utils.data.DataLoader`-related parameters to `fit()` and `predict()` of `TorchForecastingModel`.
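For context, a hedged usage sketch of what this could look like from the user side, assuming the `dataloader_kwargs` interface that the discussion below converges on (the model choice and series are purely illustrative):

```python
from darts.models import TiDEModel
from darts.utils.timeseries_generation import sine_timeseries

series = sine_timeseries(length=400)
model = TiDEModel(input_chunk_length=24, output_chunk_length=12, n_epochs=1)

# extra keyword arguments are forwarded to torch.utils.data.DataLoader
model.fit(series, dataloader_kwargs={"num_workers": 2, "pin_memory": True})
pred = model.predict(n=12, dataloader_kwargs={"num_workers": 2})
```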
Codecov Report
All modified and coverable lines are covered by tests :white_check_mark:
Project coverage is 93.74%. Comparing base (a0cc279) to head (1b149e4).
Additional details and impacted files
```diff
@@            Coverage Diff             @@
##           master    #2295      +/-   ##
==========================================
- Coverage   93.75%   93.74%   -0.02%
==========================================
  Files         138      138
  Lines       14352    14341      -11
==========================================
- Hits        13456    13444      -12
- Misses        896      897       +1
```
:umbrella: View full report in Codecov by Sentry.
Hi @BohdanBilonoh,
It looks great. However, to make it easier to maintain and more exhaustive, I think it would be better to just add a single argument called dataloader_kwargs, check that the arguments explicitly used by Darts are not redundant/overwritten, and then pass this argument down to the DataLoader constructor.
This will allow users to specify more than just prefetch_factor, persistent_workers, and pin_memory, while limiting copy-pasting from another library's documentation (adding a link to the torch.utils.data.DataLoader docs for this argument does sound like a good idea, however).
PS: Apologies for taking so long with the review of this PR.
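A minimal sketch of the redundancy check suggested above; the helper name `_validate_dataloader_kwargs` and the reserved-key set are illustrative assumptions, not Darts' actual internals:

```python
from typing import Any, Dict, Optional

# keys that Darts sets internally; letting users pass them would silently
# conflict with the values set in _setup_for_train (illustrative set)
_RESERVED_DATALOADER_KEYS = {"batch_size", "shuffle", "drop_last", "collate_fn"}


def _validate_dataloader_kwargs(kwargs: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Hypothetical helper: reject kwargs that clash with Darts-managed ones."""
    kwargs = dict(kwargs or {})
    clashes = _RESERVED_DATALOADER_KEYS & kwargs.keys()
    if clashes:
        raise ValueError(
            f"These DataLoader parameters are set by Darts and cannot be "
            f"overridden: {sorted(clashes)}"
        )
    return kwargs
```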
Hi @BohdanBilonoh, would you please add the multiprocessing_context parameter for the DataLoader? It is useful when using multiple workers. Thanks!
@BohdanBilonoh refer to #2375
@BohdanBilonoh My bad, @madtoinou's idea of adding dataloader_kwargs is a good one: it lets users freely pass any DataLoader parameters they wish, so there is no need to support a dedicated multiprocessing_context parameter.
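For illustration, a generic dataloader_kwargs would cover the request above without a dedicated parameter (a hedged sketch, reusing the model and series from the usage example earlier):

```python
# `multiprocessing_context` is a standard torch.utils.data.DataLoader argument,
# so it can simply be forwarded along with the other worker settings
model.fit(
    series,
    dataloader_kwargs={"num_workers": 4, "multiprocessing_context": "spawn"},
)
```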
@madtoinou what do you think about the hardcoded parameters like

```python
batch_size=self.batch_size,
shuffle=True,
drop_last=False,
collate_fn=self._batch_collate_fn,
```

Should they stay hardcoded when the new `dataloader_kwargs` is introduced, like
```python
def _setup_for_train(
    self,
    train_dataset: TrainingDataset,
    val_dataset: Optional[TrainingDataset] = None,
    trainer: Optional[pl.Trainer] = None,
    verbose: Optional[bool] = None,
    epochs: int = 0,
    dataloader_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[pl.Trainer, PLForecastingModule, DataLoader, Optional[DataLoader]]:
    ...
    if dataloader_kwargs is None:
        dataloader_kwargs = {}
    dataloader_kwargs["shuffle"] = True
    dataloader_kwargs["batch_size"] = self.batch_size
    # Setting drop_last to False makes the model see each sample at least once,
    # and guarantees the presence of at least one batch no matter the chosen batch size
    dataloader_kwargs["drop_last"] = False
    dataloader_kwargs["collate_fn"] = self._batch_collate_fn
    train_loader = DataLoader(
        train_dataset,
        **dataloader_kwargs,
    )
    # Prepare validation data (never shuffled)
    dataloader_kwargs["shuffle"] = False
    val_loader = (
        None
        if val_dataset is None
        else DataLoader(
            val_dataset,
            **dataloader_kwargs,
        )
    )
    ...
```
or should the user get full control over `dataloader_kwargs`?
You could extend your suggestion to allow overrides while still populating the defaults:

```python
defaults = dict(
    shuffle=True,
    batch_size=self.batch_size,
    drop_last=False,
    collate_fn=self._batch_collate_fn,
)
# combine with the defaults, but let user-provided kwargs override them
dataloader_kwargs_train = {**defaults, **dataloader_kwargs}
# override shuffle for validation
dataloader_kwargs_val = {**dataloader_kwargs_train, **dict(shuffle=False)}
```
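For completeness, the same merge pattern as a self-contained snippet with a plain PyTorch DataLoader, outside of the Darts class (all names here are illustrative):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(100, 8))
user_kwargs = {"num_workers": 0, "pin_memory": False}  # what a user might pass

defaults = {"shuffle": True, "batch_size": 32, "drop_last": False}
train_kwargs = {**defaults, **user_kwargs}       # user values override defaults
val_kwargs = {**train_kwargs, "shuffle": False}  # validation is never shuffled

train_loader = DataLoader(dataset, **train_kwargs)
val_loader = DataLoader(dataset, **val_kwargs)
```

Note that with this pattern a user could still override collate_fn or batch_size; whether that should raise instead (as in the validation sketch earlier in the thread) is exactly the design choice being discussed.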