add **kwargs to dataloader
Is your feature request related to a problem? Please describe.
When training, PyTorch Lightning warns: "Consider setting persistent_workers=True in 'train_dataloader' to speed up the dataloader worker initialization." There is currently no way to pass this option through to the DataLoader that pytorch_tabular creates.
Describe the solution you'd like
Pass **kwargs through to the DataLoader creation.
Describe alternatives you've considered
Add a 'persistent_workers' option to DataConfig.
Additional context
N/A
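Roughly, the request boils down to forwarding arbitrary keyword arguments to the underlying DataLoader. A minimal sketch of the idea (make_train_dataloader is a hypothetical helper, not part of pytorch_tabular):

```python
from torch.utils.data import DataLoader

def make_train_dataloader(dataset, batch_size, num_workers, **kwargs):
    # Hypothetical helper: any extra keyword arguments (persistent_workers,
    # drop_last, pin_memory, ...) are forwarded verbatim to DataLoader.
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        **kwargs,
    )

# e.g. make_train_dataloader(train_ds, batch_size=1024, num_workers=4,
#                            persistent_workers=True)
```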
This absolutely makes sense and is a pretty simple feature request as well.
I also want to underline the importance of this feature. I was using pytorch_tabular models for a VLDB benchmark paper, and for some datasets I got the following error while training the models:
File "/ext3/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/ext3/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/ext3/miniconda3/lib/python3.9/site-packages/torch/nn/modules/batchnorm.py", line 176, in forward
return F.batch_norm(
File "/ext3/miniconda3/lib/python3.9/site-packages/torch/nn/functional.py", line 2510, in batch_norm
_verify_batch_size(input.size())
File "/ext3/miniconda3/lib/python3.9/site-packages/torch/nn/functional.py", line 2478, in _verify_batch_size
raise ValueError(f"Expected more than 1 value per channel when training, got input size {size}")
ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 6])
The solution was to pass drop_last=True to the DataLoader in the tabular_datamodule.py module, so the incomplete final batch (a single sample here) is dropped before it reaches BatchNorm. Therefore, adding **kwargs to the DataLoader or to DataConfig would help.
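For context, here is a minimal reproduction of that error and the drop_last workaround, using synthetic data rather than the benchmark datasets:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# 33 samples with batch_size=32 leave a trailing batch of a single sample,
# which BatchNorm1d rejects in training mode with the error above.
dataset = TensorDataset(torch.randn(33, 6), torch.randn(33, 1))
bn = nn.BatchNorm1d(6).train()

# drop_last=True discards the incomplete final batch; removing it reproduces
# "Expected more than 1 value per channel when training".
loader = DataLoader(dataset, batch_size=32, drop_last=True)
for features, _ in loader:
    bn(features)
```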
Hi, I'd like to work on this as my first issue!
Based on your description, I plan to modify src/pytorch_tabular/tabular_datamodule.py, specifically within the train_loader, val_loader, and prepare_inference_dataloader functions, by adding persistent_workers=self.config.persistent_workers to the DataLoader. Additionally, I will modify the src/pytorch_tabular/config/config.py file to add a boolean variable named persistent_workers to the DataConfig class.
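A rough sketch of that plan, assuming a plain boolean field; the exact field name, default, and placement would be settled in review:

```python
from dataclasses import dataclass

@dataclass
class DataConfig:
    # ...existing fields elided...
    num_workers: int = 0
    persistent_workers: bool = False  # proposed new flag

# Inside the dataloader-building methods mentioned above (sketch):
# DataLoader(
#     dataset,
#     batch_size=self.batch_size,
#     num_workers=self.config.num_workers,
#     persistent_workers=self.config.persistent_workers,
# )
```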
Please let me know if this approach sounds good!
Awesome, thank you!
The first comment in this thread was about persistent_workers; my comment was more about the drop_last parameter. So I think adding something like dataloader_kwargs to DataConfig would be very useful, similarly to how trainer_kwargs works in TrainerConfig (reference). That way users can leverage all the functionality of the PyTorch DataLoader, including any parameters it gains in the future.
Alternatively, you can add both persistent_workers and drop_last to DataConfig.
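Something along these lines, mirroring how trainer_kwargs is handled; the field name and plumbing here are assumptions, not the final implementation:

```python
from dataclasses import dataclass, field
from typing import Any, Dict

from torch.utils.data import DataLoader

@dataclass
class DataConfig:
    # ...existing fields elided...
    num_workers: int = 0
    # Extra keyword arguments forwarded verbatim to every DataLoader.
    dataloader_kwargs: Dict[str, Any] = field(default_factory=dict)

def _make_loader(dataset, batch_size, shuffle, config):
    # Explicit arguments first, then whatever the user supplied in the config,
    # so persistent_workers, drop_last, pin_memory, etc. all work without
    # further library changes.
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=config.num_workers,
        **config.dataloader_kwargs,
    )
```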
Thanks for the clarification! I’ll proceed by adding dataloader_kwargs to DataConfig, allowing users to pass any parameters as per their requirement.
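From the user side that could look roughly like this (column names are placeholders, and dataloader_kwargs is the proposed field, not yet part of the released DataConfig):

```python
from pytorch_tabular.config import DataConfig

data_config = DataConfig(
    target=["target"],
    continuous_cols=["num_1", "num_2"],
    categorical_cols=["cat_1"],
    num_workers=4,
    # Proposed: forwarded to every DataLoader the datamodule builds.
    dataloader_kwargs={"persistent_workers": True, "drop_last": True},
)
```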
Would it be alright if I implement these changes and open a pull request?
Yes!! Just raise a PR and I'll review and merge... 😊