learning-notes
learning-notes copied to clipboard
在 PyTorch 中加载数据时,如何让不同的数据集共享同样的随机变换?
例如,有 A、B 两个域的数据,两个域的数据是不成对的,数目有可能不一致。在加载数据的时候,我们从 A、B 两个域中各随机取出一幅图像,组成一对。之后需要对这两幅图像做一些随机变换操作,我们希望对取出的这两幅图像做相同的随机变换。如何实现比较好呢?
思路 1
开始的思路是定义两个 Dataset
,然后让它们共享同一个 random transform 的单例。给单例的 __call__()
方法设置一个参数,Dataset A 调用时更新随机变换参数,Dataset B 调用时不更新随机变换参数。
实现时,开始尝试用单例实现,不太对(可能是我单例没写好)。后来用传入同一个对象来实现,结果也不对。还没找到原因。
思路 2
用同一个 Dataset
,传入两个数据集的路径,__getitem__
函数返回两幅图像。
使用这个思路的好处是,对随机变换的处理更加容易。不过需要使用一个唯一的 index
来从两个列表中随机选取图像,稍微有点别扭,应该不会影响随机性。