learn2learn
learn2learn copied to clipboard
The code example from the union dataset (UnionMetaDataset) documentation doesn't work
# Reproduction snippet for the union dataset example: wrap three CIFARFS
# splits in MetaDataset, merge them with UnionMetaDataset, and check the
# total number of labels.
# NOTE(review): `torchvision` is used below but is not imported in this
# snippet; it must already be in scope for the reported error to appear.
import learn2learn as l2l
# The first `torchvision.datasets.CIFARFS` access raises the AttributeError
# quoted below the snippet.
# NOTE(review): CIFARFS is presumably provided by learn2learn itself
# (l2l.vision.datasets.CIFARFS) rather than by torchvision -- confirm.
train = torchvision.datasets.CIFARFS(root="/tmp/mnist", mode="train")
train = l2l.data.MetaDataset(train)
valid = torchvision.datasets.CIFARFS(root="/tmp/mnist", mode="validation")
valid = l2l.data.MetaDataset(valid)
test = torchvision.datasets.CIFARFS(root="/tmp/mnist", mode="test")
test = l2l.data.MetaDataset(test)
from learn2learn.data import UnionMetaDataset
# Merge the three splits; the example expects 100 labels across the union.
union = UnionMetaDataset([train, valid, test])
assert len(union.labels) == 100
Running it raises this error:
AttributeError: module 'torchvision.datasets' has no attribute 'CIFARFS'
The following version runs, but the assertion fails:
# Variant of the example that swaps CIFARFS for torchvision's CIFAR100 so
# that the dataset construction succeeds; only the train and test splits
# are built here.
# NOTE(review): `torchvision` is used below but is not imported in this
# snippet; it must already be in scope for the snippet to run.
from pathlib import Path
import learn2learn as l2l
# Dataset root under the user's home directory ("~" expanded explicitly).
root = Path("~/data/").expanduser()
# root = Path(".").expanduser()
train = torchvision.datasets.CIFAR100(root=root, train=True, download=True)
train = l2l.data.MetaDataset(train)
# Print the label count of the train split for comparison with the union.
print(f'{len(train.labels)=}')
# The validation split is left commented out: CIFAR100 is constructed here
# only with train=True / train=False.
# valid = torchvision.datasets.CIFAR100(root="/tmp/mnist", mode="validation")
# valid = l2l.data.MetaDataset(valid)
test = torchvision.datasets.CIFAR100(root=root, train=False, download=True)
test = l2l.data.MetaDataset(test)
# Print the label count of the test split for comparison with the union.
print(f'{len(test.labels)=}')
from learn2learn.data import UnionMetaDataset
# union = UnionMetaDataset([train, valid, test])
# Union of the two available splits only.
union = UnionMetaDataset([train, test])
# This assertion is the one reported as failing; the message shows the
# actual count.
# NOTE(review): the CIFAR100 train and test splits presumably each contain
# all 100 classes, so the union may report more than 100 labels rather than
# fewer -- confirm against the printed values above.
assert len(union.labels) == 100, f'Error, got instead: {len(union.labels)=}.'
It fails because the validation set is also needed for the label counts to sum to 100.