
add **kwargs to dataloader

HernandoR opened this issue 1 year ago • 6 comments

**Is your feature request related to a problem? Please describe.**
Consider setting `persistent_workers=True` in `train_dataloader` to speed up dataloader worker initialization.

**Describe the solution you'd like**
Pass `**kwargs` through to the `DataLoader` creation.

**Describe alternatives you've considered**
Add a `persistent_workers` option to `DataConfig`.

**Additional context**
NA
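For context, `persistent_workers` is a standard `torch.utils.data.DataLoader` argument; a minimal sketch (with a toy dataset, not the actual pytorch_tabular internals) of what forwarding extra kwargs would enable:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset standing in for pytorch_tabular's internal tabular dataset.
dataset = TensorDataset(torch.randn(100, 6), torch.randint(0, 2, (100,)))

# The request: let callers forward extra keyword arguments like these
# when pytorch_tabular builds its DataLoaders internally.
loader = DataLoader(
    dataset,
    batch_size=32,
    num_workers=2,
    persistent_workers=True,  # keep workers alive across epochs (needs num_workers > 0)
)
```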

HernandoR avatar Mar 06 '24 05:03 HernandoR

This absolutely makes sense and is a pretty simple feature request as well.

manujosephv avatar Mar 09 '24 02:03 manujosephv

I'd also like to emphasize the importance of this feature. I was using pytorch_tabular models for a benchmark paper for VLDB, and for some datasets I got the following error while training the models:

  File "/ext3/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/ext3/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/ext3/miniconda3/lib/python3.9/site-packages/torch/nn/modules/batchnorm.py", line 176, in forward
    return F.batch_norm(
  File "/ext3/miniconda3/lib/python3.9/site-packages/torch/nn/functional.py", line 2510, in batch_norm
    _verify_batch_size(input.size())
  File "/ext3/miniconda3/lib/python3.9/site-packages/torch/nn/functional.py", line 2478, in _verify_batch_size
    raise ValueError(f"Expected more than 1 value per channel when training, got input size {size}")
ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 6])

The solution was to pass a `drop_last` argument to the `DataLoader` in the `tabular_datamodule.py` module. Therefore, adding `**kwargs` support to the `DataLoader` creation (or to `DataConfig`) would help.
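To illustrate why this error occurs: BatchNorm needs more than one sample per batch to compute a variance during training, and the final batch of an epoch can contain a single leftover row. A minimal pure-Python sketch of how `drop_last` changes the batch sizes:

```python
def batch_sizes(n_samples, batch_size, drop_last=False):
    """Return the per-batch sizes a DataLoader would produce."""
    full, rem = divmod(n_samples, batch_size)
    sizes = [batch_size] * full
    if rem and not drop_last:
        sizes.append(rem)  # the short final batch that breaks BatchNorm
    return sizes

# 101 samples with batch_size=20 leaves a final batch of 1 row, which
# triggers "Expected more than 1 value per channel when training".
print(batch_sizes(101, 20))                  # [20, 20, 20, 20, 20, 1]
print(batch_sizes(101, 20, drop_last=True))  # [20, 20, 20, 20, 20]
```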

denysgerasymuk799 avatar Sep 20 '24 20:09 denysgerasymuk799

Hi, I'd like to work on this as my first issue!

Based on your description, I plan to modify src/pytorch_tabular/tabular_datamodule.py, specifically within the train_loader, val_loader, and prepare_inference_dataloader functions, by adding persistent_workers=self.config.persistent_workers to the DataLoader. Additionally, I will modify the src/pytorch_tabular/config/config.py file to add a boolean variable named persistent_workers to the DataConfig class.

Please let me know if this approach sounds good!

snehilchatterjee avatar Sep 30 '24 19:09 snehilchatterjee

Awesome, thank you!

The first comment in this thread was about persistent_workers. My comment was more about the drop_last parameter. Therefore, I think adding something like dataloader_kwargs to DataConfig would be very useful, similar to how trainer_kwargs works in TrainerConfig (reference). That way, users can leverage the full functionality of the PyTorch DataLoader, including any future features it gains.

Alternatively, you can add both persistent_workers and drop_last to DataConfig.
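The dataloader_kwargs idea could look something like the following sketch (the field and function names are illustrative, not the actual pytorch_tabular API):

```python
from dataclasses import dataclass, field

@dataclass
class DataConfig:
    batch_size: int = 64
    num_workers: int = 0
    # Extra keyword arguments forwarded verbatim to torch.utils.data.DataLoader,
    # e.g. {"persistent_workers": True, "drop_last": True}.
    dataloader_kwargs: dict = field(default_factory=dict)

def make_train_loader(dataset, config):
    # In tabular_datamodule.py, the real call would forward the extras:
    #   return DataLoader(dataset, batch_size=config.batch_size,
    #                     num_workers=config.num_workers,
    #                     **config.dataloader_kwargs)
    # Here we just return the merged kwargs to show the forwarding pattern.
    return {"batch_size": config.batch_size,
            "num_workers": config.num_workers,
            **config.dataloader_kwargs}

cfg = DataConfig(dataloader_kwargs={"persistent_workers": True, "drop_last": True})
print(make_train_loader(None, cfg))
```

The advantage of a pass-through dict over individual boolean fields is that DataConfig never needs updating when DataLoader gains new arguments.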

denysgerasymuk799 avatar Sep 30 '24 20:09 denysgerasymuk799

Thanks for the clarification! I'll proceed by adding dataloader_kwargs to DataConfig, allowing users to pass any DataLoader parameters they need.

Would it be alright if I implement these changes and open a pull request?

snehilchatterjee avatar Sep 30 '24 23:09 snehilchatterjee

Yes!! Just raise a PR and I'll review and merge... 😊

manujosephv avatar Oct 01 '24 07:10 manujosephv