cutmix
Getting an empty dataset when using the CutMix DataLoader
I used the following class to create a dataset for my flower data.
Defining the dataset:
import io

import albumentations
import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
class FlowerDataset(Dataset):
    """Map-style dataset of encoded flower images with integer class labels.

    Each sample is decoded from raw bytes, converted to RGB, resized to
    ``(img_height, img_width)``, normalized, and -- for the training split --
    randomly shifted/scaled/rotated.

    Args:
        id: Per-sample identifiers; only its length is used (``__len__``).
            The name is kept for compatibility with existing keyword callers.
        classes: Per-sample integer class labels, indexed like ``image``.
        image: Per-sample encoded image bytes (any format PIL can open).
        img_height: Output image height in pixels.
        img_width: Output image width in pixels.
        mean: Per-channel mean passed to ``albumentations.Normalize``.
        std: Per-channel std passed to ``albumentations.Normalize``.
        is_valid: ``1`` selects the deterministic validation pipeline; any
            other value adds ``ShiftScaleRotate`` augmentation.

    ``__getitem__`` returns a dict ``{'image': float32 tensor (3, H, W),
    'class': int64 scalar tensor}``. Wrappers that expect ``(image, label)``
    tuples need an adapter in front of this dataset.
    """

    def __init__(self, id, classes, image, img_height, img_width, mean, std, is_valid):
        self.id = id
        self.classes = classes
        self.image = image
        # Resize here (cubic) so the output size really follows
        # img_height/img_width instead of a fixed 128x128.
        transforms = [
            albumentations.Resize(img_height, img_width,
                                  interpolation=cv2.INTER_CUBIC, p=1.0),
            albumentations.Normalize(mean, std, p=1.0),
        ]
        if is_valid != 1:
            # Training split: light random geometric augmentation.
            transforms.append(
                albumentations.ShiftScaleRotate(shift_limit=0.0625,
                                                scale_limit=0.1,
                                                rotate_limit=5,
                                                p=0.9))
        self.aug = albumentations.Compose(transforms)

    def __len__(self):
        return len(self.id)

    def __getitem__(self, index):
        # Force 3-channel RGB: grayscale (H, W) or RGBA (H, W, 4) inputs would
        # otherwise break the (2, 0, 1) transpose and a 3-channel mean/std.
        with Image.open(io.BytesIO(self.image[index])) as pil_img:
            img = np.array(pil_img.convert("RGB"))
        img = self.aug(image=img)['image']
        # HWC -> CHW.
        img = np.transpose(img, (2, 0, 1)).astype(np.float32)
        return {
            'image': torch.tensor(img, dtype=torch.float),
            'class': torch.tensor(self.classes[index], dtype=torch.long),
        }
Then I ran a sanity check to make sure it's good to go:
# sanity check for FlowerDataset class created
train_dataset = FlowerDataset(id = train_ids, classes = train_class, image = train_images,
img_height = 128 , img_width = 128,
mean = (0.485, 0.456, 0.406),
std = (0.229, 0.224, 0.225) , is_valid = 0)
import matplotlib.pyplot as plt
%matplotlib inline
idx = 0
img = train_dataset[idx]['image']
print(train_dataset[idx]['class'])
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1,2,0)))
And then I did:
# setting up the dataloader with cutmix data agumentation
!pip install git+https://github.com/ildoonet/cutmix
# setting up the train data loader
from cutmix.cutmix import CutMix
train_dataloader = CutMix(train_dataset,
num_class=104,
beta=1.0,
prob=0.5,
num_mix=2)
It ran without errors, but when I ran this sanity check:
# Pull a single batch from the loader as a quick smoke test; the notebook
# displays the value of the final expression.
loader_iter = iter(train_dataloader)
batch = next(loader_iter)
len(batch)
it returned an empty result, and therefore I am unable to train the model.