ValueError: sampler should be an instance of torch.utils.data.Sampler,

Open bbrattoli opened this issue 6 years ago • 5 comments

Dear vadimkantorov,

Thank you for publishing this nice repo, it is very well written. I'm running "python train.py --dataset cub2011 --model margin --base resnet50" with PyTorch 1.0.1 and Python 3.6, but it crashes with this error:

Traceback (most recent call last):
  File "train.py", line 71, in <module>
    loader_train = torch.utils.data.DataLoader(dataset_train, sampler = adapt_sampler(opts.batch, dataset_train, opts.sampler), num_workers = opts.threads, batch_size = opts.batch, drop_last = True, pin_memory = True)
  File "/export/home/bbrattol/anaconda2/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 805, in __init__
    batch_sampler = BatchSampler(sampler, batch_size, drop_last)
  File "/export/home/bbrattol/anaconda2/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/sampler.py", line 146, in __init__
    .format(sampler))
ValueError: sampler should be an instance of torch.utils.data.Sampler, but got sampler=<__main__. object at 0x7f199b08a9b0>

I guess it has something to do with the new PyTorch version. Could you help me make it run correctly?

Thanks

bbrattoli • Mar 11 '19 10:03

Could you try to modify:

adapt_sampler = lambda batch, dataset, sampler, **kwargs: type('', (), dict(__len__ = dataset.__len__, __iter__ = lambda _: itertools.chain.from_iterable(sampler(batch, dataset, **kwargs))))()

to read instead:

adapt_sampler = lambda batch, dataset, sampler, **kwargs: type('', (torch.utils.data.Sampler,), dict(__len__ = dataset.__len__, __iter__ = lambda _: itertools.chain.from_iterable(sampler(batch, dataset, **kwargs))))()

?

Please let me know if it works and don't hesitate to send a PR.

vadimkantorov • Mar 11 '19 11:03

I get this error now:

Traceback (most recent call last):
  File "train.py", line 77, in <module>
    loader_train = torch.utils.data.DataLoader(dataset_train, sampler = adapt_sampler(opts.batch, dataset_train, opts.sampler), num_workers = opts.threads, batch_size = opts.batch, drop_last = True, pin_memory = True)
  File "train.py", line 71, in <lambda>
    dict(__len__ = dataset.__len__, __iter__ = lambda _: itertools.chain.from_iterable(
TypeError: type.__new__() argument 2 must be tuple, not type

I know it's just a missing parameter but I don't understand what's happening in this piece of code, so please help! :D

bbrattoli • Mar 11 '19 12:03

Just checking, are you using (torch.utils.data.Sampler,) and not (torch.utils.data.Sampler)? (the difference is the comma, but it's important)
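For anyone following along, here is a small sketch of why that comma matters. It only uses built-in Python; the class names Dynamic and Broken are made up for illustration and are not part of the repo. Parentheses alone just group an expression, so (torch.utils.data.Sampler) is the class itself, while the trailing comma produces the one-element tuple that the three-argument form of type() expects for its bases.

```python
# Parentheses only group an expression; the trailing comma is what makes a tuple.
not_a_tuple = (int)    # the class int itself
one_tuple = (int,)     # a tuple containing one class
print(type(not_a_tuple))
print(type(one_tuple))

# type(name, bases, namespace) builds a class on the fly; bases must be a tuple.
# 'Dynamic' is a hypothetical name used only for this illustration.
Dynamic = type('Dynamic', (object,), {'__len__': lambda self: 3})
print(len(Dynamic()))

# Passing the class without the comma triggers the TypeError from the traceback.
try:
    type('Broken', (object), {})
    message = None
except TypeError as error:
    message = str(error)
print(message)
```

If the same TypeError shows up even with the comma in place, it is worth double-checking that the edited line is the one actually being executed.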

vadimkantorov • Mar 11 '19 13:03

It gives me the same error in both cases.

bbrattoli • Mar 11 '19 14:03

Sorry @bbrattoli, I don't have time to look at this in detail these days. I'll update here once I've checked what's going on. Meanwhile, the way to go is to define your own Sampler subclass instead of relying on my hacky adapt_sampler dynamic class creation.
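A minimal sketch of what such a subclass could look like, assuming the sampler passed in behaves as in adapt_sampler (called as sampler(batch, dataset, **kwargs) and yielding batches of indices). The names FlattenedBatchSampler and toy_sampler are hypothetical and not part of the repo, and this has not been run against train.py. The sketch does not call the base-class constructor, since its signature appears to differ between PyTorch versions; please check that against the version you have installed.

```python
import itertools

import torch.utils.data


class FlattenedBatchSampler(torch.utils.data.Sampler):
    """Hypothetical explicit replacement for the dynamic adapt_sampler class."""

    def __init__(self, batch, dataset, sampler, **kwargs):
        # Assumption: super().__init__() is skipped on purpose, because the
        # base constructor's arguments are not the same in every PyTorch version.
        self.batch = batch
        self.dataset = dataset
        self.sampler = sampler
        self.kwargs = kwargs

    def __len__(self):
        # Same length as the dataset, mirroring __len__ = dataset.__len__.
        return len(self.dataset)

    def __iter__(self):
        # Flatten the batches of indices into one stream of indices.
        batches = self.sampler(self.batch, self.dataset, **self.kwargs)
        return itertools.chain.from_iterable(batches)


def toy_sampler(batch, dataset):
    # Stand-in for opts.sampler: yields consecutive index batches.
    indices = list(range(len(dataset)))
    for start in range(0, len(indices), batch):
        yield indices[start:start + batch]


dataset = list(range(10))
sampler = FlattenedBatchSampler(4, dataset, toy_sampler)
loader = torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=4, drop_last=True)
for indices in loader:
    print(indices.tolist())
```

In train.py the call adapt_sampler(opts.batch, dataset_train, opts.sampler) would then presumably become FlattenedBatchSampler(opts.batch, dataset_train, opts.sampler), with the rest of the DataLoader arguments unchanged.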

vadimkantorov • Mar 12 '19 20:03