Glow-PyTorch

Not able to make it work for MNIST images

Status: Open · tayalkshitij opened this issue 3 years ago · 0 comments

Hi contributors,

I am trying to make the code work for MNIST images, but the reconstructions are just random black and white pixels (with both additive and affine coupling), even though the loss is decreasing. I ran it for more than 1000 epochs and the results are still the same. I have defined a class for loading MNIST images, and the final images have shape (32, 32, 1). Does the code work for single-channel images, or are there other parameters I need to change? Below is the code I am using to load the MNIST images.

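import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset


class mnist_image(Dataset):
    """Wraps raw MNIST arrays (N x 28 x 28, uint8) as a PyTorch Dataset."""

    def __init__(self, data, targets, transform=None):
        self.data = data
        self.targets = torch.LongTensor(targets)
        self.transform = transform

    def __getitem__(self, index):
        x = self.data[index]
        y = self.targets[index]

        if self.transform:
            # Convert the 2-D uint8 array into a single-channel ('L') PIL image
            x = Image.fromarray(x.astype(np.uint8)).convert('L')
            x = self.transform(x)

        return x, y

    def __len__(self):
        return len(self.data)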
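To look at the single-channel question in isolation, here is a NumPy-only sketch of what each sample should look like after the transform: a (1, 32, 32) array with the channel axis first, even though `image_shape` is written as (32, 32, 1). It assumes zero padding by 2 pixels (the effect of torchvision's `Pad(2)`) in place of the `Resize(32)` used in the loader; both produce a 32×32 image, but padding keeps the original pixel values instead of interpolating. `to_chw_single_channel` is a hypothetical helper, not part of Glow-PyTorch.

```python
import numpy as np


def to_chw_single_channel(img_u8: np.ndarray, size: int = 32) -> np.ndarray:
    """Hypothetical helper: H x W uint8 grayscale -> (1, size, size) float32 in [0, 1]."""
    if img_u8.ndim != 2:
        raise ValueError("expected a 2-D grayscale image")
    pad_h, pad_w = size - img_u8.shape[0], size - img_u8.shape[1]
    # Zero (black) border split evenly on both sides, like Pad(2) on a 28x28 image
    padded = np.pad(img_u8, ((pad_h // 2, pad_h - pad_h // 2),
                             (pad_w // 2, pad_w - pad_w // 2)))
    # Scale to [0, 1] as ToTensor does, then add a leading channel axis: (C, H, W)
    return (padded.astype(np.float32) / 255.0)[np.newaxis, :, :]


img = np.random.randint(0, 256, size=(28, 28), dtype=np.uint8)  # fake MNIST digit
x = to_chw_single_channel(img)
print(x.shape)  # (1, 32, 32)
```

With the real loader, the analogous check is `train_dataset[0][0].shape`, which should be `torch.Size([1, 32, 32])`.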



from keras.datasets import mnist
from torchvision import transforms

# `preprocess` is not defined here: it is assumed to be the helper in
# Glow-PyTorch's datasets.py (used by get_CIFAR10), so this belongs in that file.


def get_mnist_images(augment, dataroot, download):
    # dataroot and download are unused: Keras downloads MNIST itself
    image_shape = (32, 32, 1)  # MNIST is 28x28; Resize(32) upsamples to 32x32
    num_classes = 10

    if augment:
        transformations = [transforms.RandomAffine(0, translate=(0.1, 0.1))]
    else:
        transformations = []

    transformations.extend([transforms.Resize(32), transforms.ToTensor(), preprocess])
    train_transform = transforms.Compose(transformations)

    test_transform = transforms.Compose([transforms.Resize(32), transforms.ToTensor(), preprocess])

    (x_train, y_train), (x_test, y_test) = mnist.load_data()  # get MNIST images from Keras

    # Load the train dataset, using the real MNIST labels instead of random ones
    train_dataset = mnist_image(list(x_train), y_train.tolist(), transform=train_transform)

    # Load the test dataset
    test_dataset = mnist_image(list(x_test), y_test.tolist(), transform=test_transform)

    return image_shape, num_classes, train_dataset, test_dataset
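For reference, here is a NumPy sketch of the quantization that `preprocess` is assumed to perform, following the usual Glow recipe: rescale to 0..255, optionally drop low-order bits, divide by 2^n_bits and shift into [-0.5, 0.5). It also includes the inverse `postprocess`, which maps model-space values back to [0, 1] before they are viewed as images. Both functions are re-implementations for illustration, not the repository's code; `N_BITS = 8` and the rounding step are assumptions.

```python
import numpy as np

N_BITS = 8  # assumption: 8-bit pixels, matching MNIST's uint8 data


def preprocess(x: np.ndarray, n_bits: int = N_BITS) -> np.ndarray:
    """Map ToTensor-style input in [0, 1] to 2**n_bits bins in [-0.5, 0.5)."""
    x = np.round(x * 255.0)  # undo the /255 scaling; round away float error
    n_bins = 2 ** n_bits
    if n_bits < 8:
        x = np.floor(x / 2 ** (8 - n_bits))  # drop the low-order bits
    return x / n_bins - 0.5


def postprocess(x: np.ndarray) -> np.ndarray:
    """Map model-space values back to [0, 1] for saving or plotting."""
    return np.clip(x, -0.5, 0.5) + 0.5


# Usage: every 8-bit value survives the round trip exactly
pixels = np.arange(256, dtype=np.uint8).reshape(1, 16, 16)  # (C, H, W)
z = preprocess(pixels.astype(np.float32) / 255.0)
restored = postprocess(z) * 2 ** N_BITS
```

Because `postprocess` maps back to multiples of 1/2^n_bits, multiplying by 2^n_bits (not 255) recovers the original integer pixel values.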

tayalkshitij · Jun 18 '21 14:06