
Precision being changed from float16 to float32 unexpectedly

Open gcervantes8 opened this issue 1 year ago • 2 comments

Describe the bug

I'm loading a HuggingFace Dataset for images.

I'm running a preprocessing step (a map operation) that applies a few transforms, one of which converts the images to float16. The Dataset features also report that 'img' is float16. But whenever I take an image from that Dataset, its dtype turns out to be float32.

Steps to reproduce the bug

import torch
import torchvision.transforms.v2 as transforms
from datasets import load_dataset

dataset = load_dataset('cifar10', split='test')
dataset = dataset.with_format("torch")

data_transform = transforms.Compose([transforms.Resize((32, 32)), 
                                     transforms.ToDtype(torch.float16, scale=True),
                                     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
                                     ])
def _preprocess(examples):
    # Permutes from (BS x H x W x C) to (BS x C x H x W)
    images = torch.permute(examples['img'], (0, 3, 1, 2))
    examples['img'] = data_transform(images)
    return examples

dataset = dataset.map(_preprocess, batched=True, batch_size=8)

At this point dataset.features shows float16, which is great because that's what I want.

print(dataset.features['img'])

Sequence(feature=Sequence(feature=Sequence(feature=Value(dtype='float16', id=None), length=-1, id=None), length=-1, id=None), length=-1, id=None)

But when I sample an image from this dataset, I get a float32 image when I'm expecting float16:

print(next(iter(dataset))['img'].dtype)

torch.float32

Expected behavior

I'm expecting the images loaded after the transformation to stay in float16.

Environment info

  • datasets version: 2.18.0
  • Platform: Linux-5.15.146.1-microsoft-standard-WSL2-x86_64-with-glibc2.31
  • Python version: 3.10.9
  • huggingface_hub version: 0.21.4
  • PyArrow version: 14.0.2
  • Pandas version: 2.0.3
  • fsspec version: 2023.10.0

gcervantes8 avatar Mar 23 '24 20:03 gcervantes8

This is because of the formatter (torch in this case): it defaults to float32 for floating-point data, regardless of the dtype stored in the underlying Arrow table.

You can load it in float16 using dataset.set_format("torch", dtype=torch.float16).

Modexus avatar Apr 10 '24 15:04 Modexus