`with_format` behavior is inconsistent on different datasets
Describe the bug
I found a case where with_format does not transform the dataset to the requested format.
Steps to reproduce the bug
Run:
from transformers import AutoTokenizer, AutoFeatureExtractor
from datasets import load_dataset
raw = load_dataset("glue", "sst2", split="train")
raw = raw.select(range(100))
tokenizer = AutoTokenizer.from_pretrained("philschmid/tiny-bert-sst2-distilled")
def preprocess_func(examples):
    return tokenizer(examples["sentence"], padding=True, max_length=256, truncation=True)
data = raw.map(preprocess_func, batched=True)
print(type(data[0]["input_ids"]))
data = data.with_format("torch", columns=["input_ids"])
print(type(data[0]["input_ids"]))
This prints, as expected:
<class 'list'>
<class 'torch.Tensor'>
Then run:
raw = load_dataset("beans", split="train")
raw = raw.select(range(100))
preprocessor = AutoFeatureExtractor.from_pretrained("nateraw/vit-base-beans")
def preprocess_func(examples):
    imgs = [img.convert("RGB") for img in examples["image"]]
    return preprocessor(imgs)
data = raw.map(preprocess_func, batched=True)
print(type(data[0]["pixel_values"]))
data = data.with_format("torch", columns=["pixel_values"])
print(type(data[0]["pixel_values"]))
This unexpectedly prints:
<class 'list'>
<class 'list'>
Expected results
with_format should convert the selected columns to the requested format: type(data[0]["pixel_values"]) should be torch.Tensor in the example above.
Actual results
The pixel_values column is not converted; type(data[0]["pixel_values"]) is still list.
Environment info
- datasets version: dev version, commit 44af3fafb527302282f6b6507b952de7435f0979
- Platform: Linux
- Python version: 3.9.12
- PyArrow version: 7.0.0
Hi! You can get a torch.Tensor if you do the following:
raw = load_dataset("beans", split="train")
raw = raw.select(range(100))
preprocessor = AutoFeatureExtractor.from_pretrained("nateraw/vit-base-beans")
from datasets import Array3D
features = raw.features.copy()
features["pixel_values"] = Array3D(shape=(3, 224, 224), dtype="float32")
def preprocess_func(examples):
    imgs = [img.convert("RGB") for img in examples["image"]]
    return preprocessor(imgs)
data = raw.map(preprocess_func, batched=True, features=features)
print(type(data[0]["pixel_values"]))
data = data.with_format("torch", columns=["pixel_values"])
print(type(data[0]["pixel_values"]))
The reason for this "inconsistency" in the default case is the way PyArrow infers the type of multi-dimensional arrays (here, the pixel_values column). If the type is not specified manually, PyArrow assumes a dynamic-length sequence: it needs to know the type before writing the first batch to a cache file, and it can't be sure ahead of time that the array dimensions are fixed. ArrayXD is our way of telling it that the dims are fixed. Without that information, converting the column to NumPy already fails to produce a proper multi-dimensional array (you get an object-dtype array of arrays). And since with_format("torch") replaces NumPy arrays with Torch tensors, the bad formatting propagates.
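To make the difference concrete, here is a minimal NumPy-only sketch. It is an illustration of the two array layouts described above, not the actual datasets or PyArrow code path; the helper names as_variable_length and as_fixed_shape and the tiny (3, 2, 2) shape are made up for the example.

```python
import numpy as np

# Hypothetical stand-in for a "pixel_values" column:
# 4 rows, each a nested Python list of shape (3, 2, 2).
rows = [np.full((3, 2, 2), i, dtype=np.float32).tolist() for i in range(4)]


def as_variable_length(rows):
    """Mimic a column treated as a dynamic-length sequence.

    Each row is stored in its own object slot, so the outer array
    is 1-D with dtype=object.
    """
    out = np.empty(len(rows), dtype=object)
    for i, row in enumerate(rows):
        out[i] = np.asarray(row, dtype=np.float32)
    return out


def as_fixed_shape(rows, shape):
    """Mimic a column whose per-row shape is declared up front.

    All rows fit into one contiguous numeric array.
    """
    arr = np.asarray(rows, dtype=np.float32)
    if arr.shape[1:] != tuple(shape):
        raise ValueError(f"expected rows of shape {shape}, got {arr.shape[1:]}")
    return arr


ragged = as_variable_length(rows)
fixed = as_fixed_shape(rows, (3, 2, 2))

print(ragged.dtype, ragged.shape)  # object (4,)
print(fixed.dtype, fixed.shape)    # float32 (4, 3, 2, 2)
```

The object-dtype array has no single numeric dtype for a tensor library to map onto, whereas the fixed-shape array does. Declaring the column as Array3D, as in the snippet above, is what puts the real dataset in the second situation.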