Expected a 'cuda' device type for generator (related to speed issues?)
Heya, thanks for the great paper(s) :)
Initially I had to fix a few things to make your code run, but now I find it very slow and I'm wondering whether I broke anything. The cls_mdeq_LARGE_reg.yaml experiment runs at 130 samples/s post pretraining on a GTX 2080, which means it takes hours to reach ~90% test accuracy (whereas a WideResNet takes about 10 minutes to reach that performance).
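(For context, a rough sketch of how a samples/s figure like this can be measured around the training loop; `train_loader` and `train_step` below are stand-ins, not the repo's code.)

```python
import time
import torch

# Rough throughput measurement: count samples per wall-clock second over a
# training loop. train_loader and train_step are stand-ins for the repo's code.
def measure_throughput(train_loader, train_step, warmup=10):
    n_samples, start = 0, None
    for i, (inputs, targets) in enumerate(train_loader):
        train_step(inputs, targets)
        if i == warmup:
            torch.cuda.synchronize()   # make sure pending GPU work is finished
            start = time.time()
        elif i > warmup:
            n_samples += inputs.size(0)
    torch.cuda.synchronize()
    return n_samples / (time.time() - start)   # samples/s after warm-up
```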
The main error I had to fix was this:
Traceback (most recent call last):
File "/afs/inf.ed.ac.uk/user/s17/s1771851/git/official_untouched/MDEQ-Vision/tools/cls_train.py", line 257, in <module>
main()
File "/afs/inf.ed.ac.uk/user/s17/s1771851/git/official_untouched/MDEQ-Vision/tools/cls_train.py", line 220, in main
final_output_dir, tb_log_dir, writer_dict, topk=topk)
File "/afs/inf.ed.ac.uk/user/s17/s1771851/git/official_untouched/MDEQ-Vision/tools/../lib/core/cls_function.py", line 42, in train
for i, (input, target) in enumerate(train_loader):
File "/afs/inf.ed.ac.uk/user/s17/s1771851/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 359, in __iter__
return self._get_iterator()
File "/afs/inf.ed.ac.uk/user/s17/s1771851/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 305, in _get_iterator
return _MultiProcessingDataLoaderIter(self)
File "/afs/inf.ed.ac.uk/user/s17/s1771851/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 944, in __init__
self._reset(loader, first_iter=True)
File "/afs/inf.ed.ac.uk/user/s17/s1771851/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 975, in _reset
self._try_put_index()
File "/afs/inf.ed.ac.uk/user/s17/s1771851/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1209, in _try_put_index
index = self._next_index()
File "/afs/inf.ed.ac.uk/user/s17/s1771851/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 512, in _next_index
return next(self._sampler_iter) # may raise StopIteration
File "/afs/inf.ed.ac.uk/user/s17/s1771851/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/sampler.py", line 229, in __iter__
for idx in self.sampler:
File "/afs/inf.ed.ac.uk/user/s17/s1771851/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/sampler.py", line 126, in __iter__
yield from torch.randperm(n, generator=generator).tolist()
RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'
which, according to this issue, seems to be caused by this line in your code: torch.set_default_tensor_type('torch.cuda.FloatTensor'), which I removed. After manually moving everything that needs it to .cuda(), I get the performance mentioned above. Is this normal, or did I break something? Thanks!
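(A minimal sketch of that kind of workaround, with a stand-in model and loader rather than the actual MDEQ training code: instead of relying on the global default tensor type, the model, loss, and each batch are placed on the GPU explicitly.)

```python
import torch
import torch.nn as nn

# Explicit device placement instead of
# torch.set_default_tensor_type('torch.cuda.FloatTensor').
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = nn.Linear(32, 10).to(device)        # stand-in for the MDEQ model
criterion = nn.CrossEntropyLoss().to(device)

# Dummy loader; in the repo this would be the CIFAR-10 train_loader.
loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(torch.randn(128, 32),
                                    torch.randint(0, 10, (128,))),
    batch_size=16, shuffle=True)

for inputs, targets in loader:
    # move each batch to the GPU; the DataLoader's sampler stays on CPU,
    # so the randperm/generator device mismatch never comes up
    inputs = inputs.to(device, non_blocking=True)
    targets = targets.to(device, non_blocking=True)
    loss = criterion(model(inputs), targets)
```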
Specs: PyTorch 1.10, tried on both Windows (RTX 3070) and Ubuntu 20 (GTX 2080).
Hi @polo5 ,
Thanks for your interest in our paper! Yes, MDEQ-Large does take a few hours to finish all epochs, so your calculation is correct. However, I should note that most (about 80%?) of the time was actually spent on boosting accuracy from 90% to ~93.5%. If you are just looking for a 90% accuracy, an MDEQ-Large should achieve that within 1.5-2 hours.
If you use an MDEQ-Tiny model and set AUGMENT to True in the config file (while slightly increasing the # of epochs), you should reach near-90% accuracy even more quickly, though of course you are then using a smaller model.
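(A hedged sketch of flipping that flag programmatically; the file path and the exact placement of the AUGMENT key in the YAML are assumptions about the repo layout, so check the actual config file.)

```python
import yaml  # pyyaml

# Hypothetical path to the Tiny config; adjust to wherever it lives in the repo.
cfg_path = 'experiments/cifar/cls_mdeq_TINY.yaml'

with open(cfg_path) as f:
    cfg = yaml.safe_load(f)

cfg['AUGMENT'] = True   # flag mentioned above; its nesting in the YAML is an assumption
# Also bump the epoch count slightly, per the suggestion above, e.g.:
# cfg['TRAIN']['END_EPOCH'] = cfg['TRAIN']['END_EPOCH'] + 10

with open(cfg_path, 'w') as f:
    yaml.safe_dump(cfg, f)
```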
Also, re: the error: I haven't encountered this error in this repo before, but I'll check PyTorch 1.10 for sure!
I got this error as well when I used PyTorch 1.10. After switching to PyTorch 1.8.1, everything was fine. You can take a look at this issue; it seems related to a bug in PyTorch 1.10.
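A minimal reproduction of the mismatch (assuming the cause is the default-tensor-type line mentioned above; whether it actually raises depends on the PyTorch version):

```python
import torch

# With the default tensor type forced onto the GPU, torch.randperm tries to
# produce a CUDA tensor, while the DataLoader's RandomSampler passes in a
# plain (CPU) torch.Generator(), hence the device-type error in the traceback.
torch.set_default_tensor_type('torch.cuda.FloatTensor')

gen = torch.Generator()                    # CPU generator, as created inside RandomSampler
perm = torch.randperm(10, generator=gen)   # raises RuntimeError on affected versions
```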
Thanks @liu-jc.
This is an important issue then, since PyTorch 1.10 is the recommended version for this repo. The other issue is that earlier PyTorch versions (<1.10) work on Ubuntu but somehow don't work on Windows for this repo (I get a strange errorless interruption that looks like a segmentation fault). Oh well, if setting devices manually hasn't slowed down the code, I'm happy with that solution.
The issue with PyTorch <1.10 is that the hook implementation currently used to implement O(1) memory (see e.g. https://github.com/locuslab/deq/blob/master/DEQ-Sequence/models/deq_transformer.py#L380) did not work before: it hit a bug that PyTorch only fixed in 1.10 (see this issue). I will check this again soon and update this thread.
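Roughly, the pattern at that line looks like the following simplified sketch (not the exact repo code; simple_solver stands in for the Broyden/Anderson solvers we actually use): the equilibrium is computed without tracking the solver, then one extra application of f is hooked so that the backward pass solves the implicit gradient equation rather than backpropagating through the solver iterations, which is what keeps training memory O(1).

```python
import torch

def simple_solver(g, z0, iters=30):
    # naive fixed-point iteration z <- g(z); the repo uses Broyden/Anderson instead
    z = z0
    for _ in range(iters):
        z = g(z)
    return z

def deq_apply(f, x, z0, solver=simple_solver):
    # 1) find the equilibrium without building a graph through the solver
    with torch.no_grad():
        z_star = solver(lambda z: f(z, x), z0)

    # 2) one extra application of f so autograd only sees a single step
    z_star = z_star.detach().requires_grad_()
    f_z = f(z_star, x)

    # 3) at backward time, solve g = (df/dz*)^T g + grad instead of
    #    backpropagating through the forward solver iterations
    def backward_hook(grad):
        return solver(
            lambda g: torch.autograd.grad(f_z, z_star, g, retain_graph=True)[0] + grad,
            torch.zeros_like(grad))

    f_z.register_hook(backward_hook)
    return f_z

# Toy usage: f is a contraction here, so the naive solver converges.
W = 0.1 * torch.randn(8, 8)
f = lambda z, x: torch.tanh(z @ W + x)
x = torch.randn(4, 8, requires_grad=True)
z = deq_apply(f, x, torch.zeros(4, 8))
z.sum().backward()   # triggers the hook; x.grad now holds the implicit gradient
```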
I've never tried a Windows environment, but I suspect it has something to do with WSL?