cutmix

Getting an empty dataset when using the CutMix DataLoader

Open · IamSparky opened this issue 4 years ago · 0 comments

I used the following class to create a dataset for my flower data:

Defining the dataset:

import io

import albumentations
import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset

class FlowerDataset(Dataset):
    """Dataset of encoded flower images with integer class labels.

    Each item is a dict ``{'image': float32 tensor (C, H, W), 'class': long tensor}``.

    Args:
        id: per-sample ids; only its length is used (it defines ``len``).
            NOTE: the name shadows the builtin but is part of the public
            keyword interface (callers pass ``id=...``), so it is kept.
        classes: per-sample labels, indexed in parallel with ``image``;
            each must be convertible to a ``torch.long`` tensor.
        image: per-sample encoded image bytes (any format PIL can open).
        img_height, img_width: output spatial size of every image.
        mean, std: per-channel statistics passed to ``albumentations.Normalize``.
        is_valid: ``1`` selects the deterministic validation pipeline; any
            other value adds random ShiftScaleRotate augmentation for training.
    """

    def __init__(self, id, classes, image, img_height, img_width, mean, std, is_valid):
        self.id = id
        self.classes = classes
        self.image = image

        # Shared steps: resize to the requested size (cubic interpolation,
        # matching the cv2.resize this class previously hard-coded to 128x128)
        # and normalise.
        transforms = [
            albumentations.Resize(img_height, img_width,
                                  interpolation=cv2.INTER_CUBIC,
                                  always_apply=True),
            albumentations.Normalize(mean, std, always_apply=True),
        ]
        if is_valid != 1:
            # Training only: light random geometric augmentation.
            transforms.append(
                albumentations.ShiftScaleRotate(shift_limit=0.0625,
                                                scale_limit=0.1,
                                                rotate_limit=5,
                                                p=0.9))
        self.aug = albumentations.Compose(transforms)

    def __len__(self):
        return len(self.id)

    def __getitem__(self, index):
        # Force three channels: the 3-element mean/std in Normalize and the
        # (2, 0, 1) transpose below both assume an (H, W, 3) array, which a
        # grayscale, RGBA or palette image would not provide.
        img = np.array(Image.open(io.BytesIO(self.image[index])).convert('RGB'))
        # Resizing happens inside self.aug, so img_height/img_width are honoured.
        img = self.aug(image=img)['image']
        # HWC -> CHW, the channel-first layout PyTorch models expect.
        img = np.transpose(img, (2, 0, 1)).astype(np.float32)

        return {
            'image': torch.tensor(img, dtype=torch.float),
            'class': torch.tensor(self.classes[index], dtype=torch.long),
        }

Then I ran a sanity check to make sure it works:

# sanity check for FlowerDataset class created

train_dataset = FlowerDataset(id = train_ids, classes = train_class, image = train_images, 
                        img_height = 128 , img_width = 128, 
                        mean = (0.485, 0.456, 0.406),
                        std = (0.229, 0.224, 0.225) , is_valid = 0)

import matplotlib.pyplot as plt
%matplotlib inline

idx = 0
img = train_dataset[idx]['image']

print(train_dataset[idx]['class'])

npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1,2,0)))

[image: screenshot of the sample plotted by the sanity check]

Then I did:

# setting up the dataloader with cutmix data agumentation
!pip install git+https://github.com/ildoonet/cutmix

# setting up the train data loader

from cutmix.cutmix import CutMix

train_dataloader = CutMix(train_dataset, 
                          num_class=104, 
                          beta=1.0, 
                          prob=0.5, 
                          num_mix=2)

That ran without errors, but when I ran this sanity check:

# Pull the first item from the training loader to inspect its structure.
# NOTE(review): CutMix unpacks its wrapped dataset as `img, label = dataset[i]`
# and yields `(image, soft_target)` pairs, so a working pipeline should give
# len(batch) == 2. A wrapped dataset that returns a dict breaks that
# unpacking. If train_dataloader is the bare CutMix object rather than a
# torch DataLoader, this is a single sample, not a batch — confirm.
batch = next(iter(train_dataloader))
len(batch)

it returned the output shown in the attached screenshot: [image]

and as a result I am unable to train the model.

— IamSparky, Jul 24 '20, 10:07