
add `DataLoader` related parameters to `fit()` and `predict()`

Open • BohdanBilonoh opened this pull request 1 year ago • 5 comments

Checklist before merging this PR:

  • [x] Mentioned all issues that this PR fixes or addresses.
  • [x] Summarized the updates of this PR under Summary.
  • [x] Added an entry under Unreleased in the Changelog.

Summary

Add `torch.utils.data.DataLoader`-related parameters to `fit()` and `predict()` of `TorchForecastingModel`
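
For illustration, a minimal sketch of how this could look from the user side, assuming the `dataloader_kwargs` design that the discussion below converges on (the exact `fit()` signature here is an assumption, not a quote of the merged API):

from darts.datasets import AirPassengersDataset
from darts.models import NBEATSModel

series = AirPassengersDataset().load()
model = NBEATSModel(input_chunk_length=24, output_chunk_length=12, n_epochs=1)

# hypothetical call: forward torch.utils.data.DataLoader options through
# fit() instead of them being fixed internally
model.fit(series, dataloader_kwargs={"num_workers": 2, "pin_memory": True})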

BohdanBilonoh avatar Mar 27 '24 17:03 BohdanBilonoh

Codecov Report

All modified and coverable lines are covered by tests :white_check_mark:

Project coverage is 93.74%. Comparing base (a0cc279) to head (1b149e4).

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #2295      +/-   ##
==========================================
- Coverage   93.75%   93.74%   -0.02%     
==========================================
  Files         138      138              
  Lines       14352    14341      -11     
==========================================
- Hits        13456    13444      -12     
- Misses        896      897       +1     

:umbrella: View full report in Codecov by Sentry.

codecov[bot] avatar Apr 09 '24 12:04 codecov[bot]

Hi @BohdanBilonoh,

It looks great! However, to make it easier to maintain and more exhaustive, I think it would be better to just add a single argument called dataloader_kwargs, check that the arguments explicitly set by Darts are not redundant/overwritten, and then pass this argument down to the DataLoader constructor.

It would allow users to specify more than just prefetch_factor, persistent_workers, and pin_memory, while limiting copy-pasting from another library's documentation (linking to the torch.DataLoader documentation for this argument does sound like a good idea, however).

PS: Apologies for taking so long with the review of this PR.
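
For concreteness, a minimal sketch of the redundancy check suggested above, assuming a plain ValueError on clashes; the names _RESERVED_DATALOADER_KWARGS and _check_dataloader_kwargs are illustrative, not the merged implementation:

from typing import Any, Dict, Optional

# Keys that Darts sets internally; user-supplied values for these would be
# silently overwritten or conflict, so reject them up front.
_RESERVED_DATALOADER_KWARGS = {"batch_size", "shuffle", "drop_last", "collate_fn"}

def _check_dataloader_kwargs(dataloader_kwargs: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    dataloader_kwargs = dict(dataloader_kwargs or {})
    clashes = _RESERVED_DATALOADER_KWARGS & dataloader_kwargs.keys()
    if clashes:
        raise ValueError(
            f"`dataloader_kwargs` must not contain {sorted(clashes)}; "
            "these values are set internally by Darts."
        )
    return dataloader_kwargs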

madtoinou avatar May 06 '24 15:05 madtoinou

Hi @BohdanBilonoh, would you please add the multiprocessing_context parameter for the DataLoader? It is useful when using multiple workers with the dataloader. Thanks!
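
For reference, multiprocessing_context is an existing torch.utils.data.DataLoader argument, so it would be covered by a pass-through dataloader_kwargs; a small self-contained example of what it controls (the toy TensorDataset is only for illustration):

import multiprocessing

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(100, dtype=torch.float32).unsqueeze(1))

# With num_workers > 0, workers are started using the given context;
# "spawn" avoids fork-related issues (e.g. CUDA re-initialization in workers).
# When using "spawn", iterate the loader under an `if __name__ == "__main__":` guard.
loader = DataLoader(
    dataset,
    num_workers=2,
    multiprocessing_context=multiprocessing.get_context("spawn"),
)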

joshua-xia avatar May 07 '24 01:05 joshua-xia

> Hi @BohdanBilonoh, would you please add the multiprocessing_context parameter for the DataLoader? It is useful when using multiple workers with the dataloader. Thanks!

@BohdanBilonoh refer to #2375

joshua-xia avatar May 07 '24 01:05 joshua-xia

> Hi @BohdanBilonoh, would you please add the multiprocessing_context parameter for the DataLoader? It is useful when using multiple workers with the dataloader. Thanks!

> @BohdanBilonoh refer to #2375

@BohdanBilonoh My bad, @madtoinou's idea of adding dataloader_kwargs so users can freely pass whatever dataloader parameters they wish is a good one; there is no need to force support for a special multiprocessing_context parameter.

joshua-xia avatar May 07 '24 02:05 joshua-xia

@madtoinou what do you think about hardcoded parameters like

batch_size=self.batch_size,
shuffle=True,
drop_last=False,
collate_fn=self._batch_collate_fn,

should they stay hard-coded on top of the new dataloader_kwargs, like

def _setup_for_train(
    self,
    train_dataset: TrainingDataset,
    val_dataset: Optional[TrainingDataset] = None,
    trainer: Optional[pl.Trainer] = None,
    verbose: Optional[bool] = None,
    epochs: int = 0,
    dataloader_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[pl.Trainer, PLForecastingModule, DataLoader, Optional[DataLoader]]:

    ...

    if dataloader_kwargs is None:
        dataloader_kwargs = {}

    # the hard-coded values below take precedence over anything the user
    # passed in dataloader_kwargs
    dataloader_kwargs["shuffle"] = True
    dataloader_kwargs["batch_size"] = self.batch_size
    dataloader_kwargs["drop_last"] = False
    dataloader_kwargs["collate_fn"] = self._batch_collate_fn

    # Setting drop_last to False makes the model see each sample at least once,
    # and guarantees the presence of at least one batch no matter the chosen
    # batch size
    train_loader = DataLoader(
        train_dataset,
        **dataloader_kwargs,
    )

    # reuse the same kwargs for validation, but without shuffling
    dataloader_kwargs["shuffle"] = False

    # Prepare validation data
    val_loader = (
        None
        if val_dataset is None
        else DataLoader(
            val_dataset,
            **dataloader_kwargs,
        )
    )

    ...

        ...

or should the user be given full control over dataloader_kwargs?

BohdanBilonoh avatar May 28 '24 09:05 BohdanBilonoh

> @madtoinou what do you think about the hard-coded parameters (batch_size, shuffle, drop_last, collate_fn)? Should they stay hard-coded on top of the new dataloader_kwargs [...] or should the user be given full control over dataloader_kwargs?

You could extend your suggestion to allow overrides while still populating the defaults:

defaults = dict(
    shuffle=True,
    batch_size=self.batch_size,
    drop_last=False,
    collate_fn=self._batch_collate_fn,
)
# combine with the defaults, letting user-supplied values override them
dataloader_kwargs_train = {**defaults, **dataloader_kwargs}
# override shuffle for the validation loader
dataloader_kwargs_val = {**dataloader_kwargs_train, "shuffle": False}
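
In this merge pattern the right-hand dict wins, so user-supplied values override the populated defaults; a tiny self-contained check (the values here are made up for illustration):

defaults = {"shuffle": True, "batch_size": 32, "drop_last": False}
user_kwargs = {"batch_size": 64, "num_workers": 2}

merged = {**defaults, **user_kwargs}
# the user-supplied batch_size wins; defaults fill in the rest
assert merged == {"shuffle": True, "batch_size": 64, "drop_last": False, "num_workers": 2}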

tRosenflanz avatar May 30 '24 17:05 tRosenflanz