Anime-Super-Resolution
Anime-Super-Resolution copied to clipboard
你这代码怎么乱写,可以学习一下pytorch的数据读取dataloader类,你写的惨不忍睹。
不是骂你啊,误会了就不好了。
` def len(self): return len(self.images)
def __getitem__(self, index):
image = self.images[index]
return self.pair(image)
def batches(self, batch_size=8):
images = self.images
while True:
batch_index = np.random.randint(len(images), size=batch_size)
x, y = zip(*[self.pair(self.images[index]) for index in batch_index])
x = np.stack(x, axis=0) # Low Res
y = np.stack(y, axis=0) # High Res
yield x, y`
额,我的数据集图片size大小不一样,有时图片大小比裁剪大小小会报错,所以写的麻烦了一点。。然后我不是一次性把所有图片读进来的,所以没做队列。。我是存地址筛batch然后降采样的,数据集比较大的话一次性读的话内存吃不消
还有就是,piar的话我是随机用不同的方法裁剪做降采样做数据增强,就是同一个batch里的每张图片是用的不同的增强方法,上面的方法是同一个batch只能一种pair方法,所以还是得把batch拆开对每张图片用不同方法做裁剪和降采样
然后就是上面的np.random.randint虽然比较简洁,但是会使送入网络的数据比较随机(样本被迭代的次数不均,有随机性),个人感觉还是用np.random.shuffle打乱然后for循环会好一些,可以保证所有数据都能平均的送入网络进行训练
这位老哥真的一点都不客气啊 哈哈哈哈哈 笑死我了
有没有实验数据表格?r_blur-scale-PSNR 以及有没有考虑加jpeg噪声