learn2learn

Code example for UnionMetaDataset doesn't work

Open • brando90 opened this issue 3 years ago • 1 comment

    import torchvision
    import learn2learn as l2l
    train = torchvision.datasets.CIFARFS(root="/tmp/mnist", mode="train")
    train = l2l.data.MetaDataset(train)
    valid = torchvision.datasets.CIFARFS(root="/tmp/mnist", mode="validation")
    valid = l2l.data.MetaDataset(valid)
    test = torchvision.datasets.CIFARFS(root="/tmp/mnist", mode="test")
    test = l2l.data.MetaDataset(test)
    from learn2learn.data import UnionMetaDataset
    union = UnionMetaDataset([train, valid, test])
    assert len(union.labels) == 100

Error:

AttributeError: module 'torchvision.datasets' has no attribute 'CIFARFS'

brando90, Sep 28 '22 18:09
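The `AttributeError` is expected: torchvision does not ship a CIFAR-FS dataset. learn2learn has its own loader, `learn2learn.vision.datasets.CIFARFS`, which the example most likely intended. Below is a minimal sketch of the corrected code. It assumes that class accepts `root`, `mode` (`"train"`, `"validation"` or `"test"`, as in the original example) and `download`; `build_cifarfs_union` is a hypothetical helper name, not part of learn2learn.

```python
import os


def build_cifarfs_union(root: str = "~/data"):
    """Return the union of the three CIFAR-FS splits (hypothetical helper).

    Assumes learn2learn.vision.datasets.CIFARFS(root, mode, download=...) exists.
    learn2learn is imported lazily, so defining this function does not require it;
    calling it requires learn2learn and network access for the first download.
    """
    from learn2learn.data import MetaDataset, UnionMetaDataset
    from learn2learn.vision.datasets import CIFARFS  # not torchvision.datasets

    root = os.path.expanduser(root)
    # Wrap each split in a MetaDataset, then take their union.
    splits = [
        MetaDataset(CIFARFS(root=root, mode=mode, download=True))
        for mode in ("train", "validation", "test")
    ]
    return UnionMetaDataset(splits)
```

Calling `union = build_cifarfs_union()` and then checking `len(union.labels) == 100` should reproduce the documented assertion, because the three CIFAR-FS splits partition the 100 CIFAR-100 classes.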

This runs, but the assert fails:

    from pathlib import Path
    import torchvision
    import learn2learn as l2l

    root = Path("~/data/").expanduser()
    # root = Path(".").expanduser()
    train = torchvision.datasets.CIFAR100(root=root, train=True, download=True)
    train = l2l.data.MetaDataset(train)
    print(f'{len(train.labels)=}')
    # valid = torchvision.datasets.CIFAR100(root="/tmp/mnist", mode="validation")
    # valid = l2l.data.MetaDataset(valid)
    test = torchvision.datasets.CIFAR100(root=root, train=False, download=True)
    test = l2l.data.MetaDataset(test)
    print(f'{len(test.labels)=}')

    from learn2learn.data import UnionMetaDataset
    # union = UnionMetaDataset([train, valid, test])
    union = UnionMetaDataset([train, test])
    assert len(union.labels) == 100, f'Error, got instead: {len(union.labels)=}.'

brando90, Sep 28 '22 18:09
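Suppose `UnionMetaDataset` remaps labels to consecutive ids so that the same label in two input datasets becomes two distinct labels in the union. That is an assumption here and worth checking against the installed version. Under it, CIFAR-100's train and test splits, which share all 100 classes, would give 200 labels rather than 100. The sketch below models that remapping in plain Python. `union_label_count` is a hypothetical helper, not a learn2learn function.

```python
from typing import Dict, List


def union_label_count(label_sets: List[List[int]]) -> int:
    """Count labels in a union that offsets each dataset's labels (hypothetical model).

    Label l of dataset k is mapped to offset_k + rank of l, where offset_k is the
    number of distinct labels in datasets 0..k-1, so labels never collide.
    """
    remapped = set()
    offset = 0
    for labels in label_sets:
        distinct = sorted(set(labels))
        # Assign consecutive ids starting at the current offset.
        mapping: Dict[int, int] = {lab: offset + i for i, lab in enumerate(distinct)}
        remapped.update(mapping.values())
        offset += len(distinct)
    return len(remapped)


cifar100_train = list(range(100))  # both CIFAR-100 splits contain all 100 classes
cifar100_test = list(range(100))
print(union_label_count([cifar100_train, cifar100_test]))  # 200
print(len(set(cifar100_train) | set(cifar100_test)))       # 100 (plain set union)
```

If the library behaves like this model, the assertion in the CIFAR-100 snippet cannot hold with train and test alone, and adding another split would only increase the count further.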

It fails because you also need the validation set for the label count to sum to 100.

seba-1511, Oct 30 '22 04:10
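For CIFAR-FS, the arithmetic behind this reply is short. The benchmark splits the 100 CIFAR-100 classes into disjoint sets of 64 (train), 16 (validation) and 20 (test) classes, so any union covers only the sum of the splits it includes. The dictionary and `union_classes` below are illustrative names, not learn2learn APIs.

```python
# CIFAR-FS class splits: disjoint subsets of the 100 CIFAR-100 classes.
CIFARFS_CLASSES = {"train": 64, "validation": 16, "test": 20}


def union_classes(*modes: str) -> int:
    """Number of classes in the union of the given (disjoint) CIFAR-FS splits."""
    return sum(CIFARFS_CLASSES[m] for m in modes)


print(union_classes("train", "test"))                # 84
print(union_classes("train", "validation", "test"))  # 100
```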