ffcv
Using `ToTorchImage(channels_last=True)` raises an error when the loader uses `drop_last=False`
A minimal reproducible example is below. My guess is that some size is pre-computed in a way that is incompatible with `drop_last=False`. With that setting, the final batch is smaller than the others: 50 images with `batch_size=15` leave a final batch of 5.
import numpy as np
import torch
from PIL import Image
# Create a dummy dataset with 50 64x64 RGB images and 10 classes.
n_classes = 10
n_images = 50
n_channels = 3
image_size = 64
# Image.fromarray(..., mode='RGB') (used by DummyDataset) expects uint8 pixel
# data laid out as (height, width, channels). The previous channels-first
# int64 array did not match that layout, so the written images were not the
# intended 64x64 RGB images. randint's upper bound is exclusive, so 256 is
# needed to cover the full 0-255 pixel range.
images = np.random.randint(
    0, 256, (n_images, image_size, image_size, n_channels), dtype=np.uint8
)
# n_images // n_classes = 5 samples per class: [0]*5 + [1]*5 + ... + [9]*5.
labels = np.repeat(np.arange(n_classes), n_images // n_classes)
class DummyDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing each image array with its label.

    Indexing returns a ``(PIL image, label)`` tuple; this script passes the
    dataset to ``DatasetWriter.from_indexed_dataset``.
    """

    def __init__(self, images, labels):
        # Keep references to the backing sequences; nothing is copied.
        self.images, self.labels = images, labels

    def __getitem__(self, index):
        """Return the image at ``index`` as an RGB PIL image, plus its label."""
        pixels = self.images[index]
        picture = Image.fromarray(pixels, mode='RGB')
        return picture, self.labels[index]

    def __len__(self):
        """Return the number of samples, i.e. the number of images."""
        total = len(self.images)
        return total
dataset = DummyDataset(images, labels)

from ffcv.writer import DatasetWriter
from ffcv.fields import RGBImageField, IntField

# Per-sample schema of the .beton file: one RGB image and one integer label.
schema = {
    'image': RGBImageField(),
    'label': IntField(),
}
writer = DatasetWriter('dummy.beton', schema)
# Serialize the indexed dataset into dummy.beton.
writer.from_indexed_dataset(dataset)
from ffcv.loader import Loader, OrderOption
from ffcv.fields.decoders import SimpleRGBImageDecoder, FloatDecoder, IntDecoder
from ffcv.transforms import ToDevice, ToTorchImage, ToTensor, Convert

CHANNELS_LAST = False  # Change me to True to see the error!

# Image: decode, convert to a tensor, reorder into a torch image, move to device 0.
image_pipeline = [
    SimpleRGBImageDecoder(),
    ToTensor(),
    ToTorchImage(channels_last=CHANNELS_LAST, convert_back_int16=False),
    ToDevice(0),
]
# Label: decode the integer, convert to a tensor, move to device 0.
label_pipeline = [IntDecoder(), ToTensor(), ToDevice(0)]
PIPELINES = {'image': image_pipeline, 'label': label_pipeline}

# With 50 samples, batch_size=15 and drop_last=False, the last batch holds 5.
train_loader = Loader(
    'dummy.beton',
    batch_size=15,
    num_workers=10,
    order=OrderOption.RANDOM,
    pipelines=PIPELINES,
    drop_last=False,
)

BATCH_DIM, CHANNEL_DIM, HEIGHT_DIM, WIDTH_DIM = range(4)
# NOTE: `device` is not used below; the pipelines hard-code ToDevice(0).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Run 50 epochs, casting each batch's images to float.
for _ in range(50):
    for data, _ in train_loader:
        data = data.float()
As a follow-up: changing the image pipeline to `[SimpleRGBImageDecoder(), NormalizeImage(np.array([0, 0, 0]), np.array([255, 255, 255]), np.float16), ToTensor(), ToTorchImage(channels_last=CHANNELS_LAST, convert_back_int16=False), Convert(torch.float16), ToDevice(0)]`
makes the error go away. The only changes are the added `NormalizeImage` step and the added `Convert(torch.float16)` step.