Unable to use dataset with PyTorch dataloader
Describe the bug
When using `.with_format("torch")`, an arrow table is returned and I am unable to use it by passing it to a PyTorch DataLoader; please see the code below.
Steps to reproduce the bug
```python
from datasets import load_dataset
from torch.utils.data import DataLoader

ds = load_dataset(
    "para_crawl",
    name="enfr",
    cache_dir="/tmp/test/",
    split="train",
    keep_in_memory=True,
)
dataloader = DataLoader(ds.with_format("torch"), num_workers=32)
print(next(iter(dataloader)))
```
Is there something I am doing wrong? The documentation does not say much about the behavior of `.with_format()`, so I feel like I am a bit stuck here :-/
Thanks in advance for your help!
Expected results
The code should run without errors.
Actual results
AttributeError: 'str' object has no attribute 'dtype'
Environment info
- `datasets` version: 2.3.2
- Platform: Linux-4.18.0-348.el8.x86_64-x86_64-with-glibc2.28
- Python version: 3.10.4
- PyArrow version: 8.0.0
- Pandas version: 1.4.3
Hi! `para_crawl` has a single column of type `Translation`, which stores translation dictionaries. These dictionaries can be stored in a NumPy array (as an array of Python objects) but not in a PyTorch tensor, since PyTorch only supports numeric types. In `datasets`, the conversion to `torch` works as follows:
- convert the PyArrow table to NumPy arrays
- convert the NumPy arrays to PyTorch tensors.
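To make the two steps concrete, here is a minimal NumPy-only sketch, assuming a batch of `Translation` values arrives as a list of Python dicts (the sample sentences are made up). Step 1 yields an `object`-dtype array of dicts. The hypothetical `is_tensor_convertible` helper stands in for step 2 and only approximates the rule that PyTorch tensors need numeric or boolean dtypes; it is not the actual `datasets` formatter code.

```python
import numpy as np

# Made-up rows shaped like the `translation` column of para_crawl/enfr.
rows = [
    {"en": "Hello", "fr": "Bonjour"},
    {"en": "Thank you", "fr": "Merci"},
]

# Step 1: a column of dicts can only become a 1-D NumPy array of Python objects.
arr = np.array(rows, dtype=object)


def is_tensor_convertible(a: np.ndarray) -> bool:
    # Hypothetical check approximating what PyTorch accepts:
    # b=bool, i=signed int, u=unsigned int, f=float, c=complex.
    return a.dtype.kind in "biufc"


print(arr.dtype)  # object
print(is_tensor_convertible(arr))  # False: step 2 has nothing numeric to work with
print(is_tensor_convertible(np.array([[1, 2], [3, 4]])))  # True
```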
The second step is the problematic one in your case, because `datasets` attempts to convert the array of dictionaries to a PyTorch tensor. One way to work around this is to use the preprocessing logic from the Transformers translation script, which tokenizes the text so that only numeric columns remain. On our side, I think we can replace the NumPy array of dicts with a dict of NumPy arrays (one array per language) when the feature type is `Translation`/`TranslationVariableLanguages`, so that in such cases you get PyTorch's own error message for strings.
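Both suggestions start by turning each `Translation` dict into one column per language. The sketch below assumes batches shaped like what `Dataset.map(..., batched=True)` passes to a function (a dict mapping column names to lists). `split_by_language` illustrates the proposed dict-of-arrays layout, whose string arrays still cannot become tensors. `preprocess` loosely mirrors the idea behind the Transformers translation script, but `toy_tokenize` is a hypothetical whitespace tokenizer standing in for a real one, and the `input_ids`/`labels` names are assumptions for illustration.

```python
import numpy as np


def split_by_language(translations):
    """Turn a list of {lang: text} dicts into {lang: array of strings}."""
    if not translations:
        return {}
    langs = translations[0].keys()
    return {lang: np.array([t[lang] for t in translations]) for lang in langs}


def toy_tokenize(texts, vocab, max_length=8, pad_id=0):
    """Hypothetical tokenizer: whitespace split, ids from a growing vocab, padded/truncated."""
    ids = []
    for text in texts:
        # New tokens get the next free id; id 0 is reserved for padding.
        row = [vocab.setdefault(tok, len(vocab) + 1) for tok in text.split()][:max_length]
        ids.append(row + [pad_id] * (max_length - len(row)))
    return np.array(ids, dtype=np.int64)


def preprocess(batch, source_lang="en", target_lang="fr", vocab=None):
    """Replace the string-valued translation column with integer id matrices."""
    vocab = {} if vocab is None else vocab
    columns = split_by_language(batch["translation"])
    return {
        "input_ids": toy_tokenize(columns[source_lang], vocab),
        "labels": toy_tokenize(columns[target_lang], vocab),
    }


batch = {
    "translation": [
        {"en": "Hello world", "fr": "Bonjour le monde"},
        {"en": "Thank you", "fr": "Merci"},
    ]
}
cols = split_by_language(batch["translation"])
print(cols["en"].dtype.kind)  # U: a string array, still not tensor-convertible
out = preprocess(batch)
print(out["input_ids"].dtype, out["input_ids"].shape)  # int64 (2, 8)
```

With a real dataset, a function like this would typically be applied with `ds.map(preprocess, batched=True, remove_columns=["translation"])` before calling `with_format("torch")`, so that only numeric columns reach the DataLoader. The toy vocabulary is rebuilt on every call, so in practice it should be replaced by a real tokenizer with a fixed vocabulary.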