Precision being changed from float16 to float32 unexpectedly
Describe the bug
I'm loading a HuggingFace Dataset of images.
I run a preprocessing step (a map operation) that applies a few transforms, one of them a conversion to float16. The Dataset features also report that 'img' is of type float16. But whenever I take an image from that HuggingFace Dataset instance, its dtype turns out to be float32.
Steps to reproduce the bug
import torch
import torchvision.transforms.v2 as transforms
from datasets import load_dataset

dataset = load_dataset('cifar10', split='test')
dataset = dataset.with_format("torch")

data_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToDtype(torch.float16, scale=True),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

def _preprocess(examples):
    # Permute from (BS x H x W x C) to (BS x C x H x W)
    images = torch.permute(examples['img'], (0, 3, 1, 2))
    examples['img'] = data_transform(images)
    return examples

dataset = dataset.map(_preprocess, batched=True, batch_size=8)
At this point dataset.features shows float16, which is what I want:
print(dataset.features['img'])
Sequence(feature=Sequence(feature=Sequence(feature=Value(dtype='float16', id=None), length=-1, id=None), length=-1, id=None), length=-1, id=None)
But when I sample an image from this dataset, I get a float32 image where I'm expecting float16:
print(next(iter(dataset))['img'].dtype)
torch.float32
Expected behavior
I'm expecting the images loaded after the transformation to stay in float16.
Environment info
- datasets version: 2.18.0
- Platform: Linux-5.15.146.1-microsoft-standard-WSL2-x86_64-with-glibc2.31
- Python version: 3.10.9
- huggingface_hub version: 0.21.4
- PyArrow version: 14.0.2
- Pandas version: 2.0.3
- fsspec version: 2023.10.0
This is caused by the formatter ("torch" in this case), which defaults to float32 for floating-point columns.
You can keep the data in float16 by passing the dtype to the formatter: dataset.set_format("torch", dtype=torch.float16).