
Expected a 'cuda' device type for generator (related to speed issues?)

Open · polo5 opened this issue 3 years ago · 5 comments

Heya, thanks for the great paper(s) :)

Initially I had to fix a few things to make your code run, but now it seems very slow and I'm wondering whether I broke anything. The cls_mdeq_LARGE_reg.yaml experiment runs at 130 samples/s after pretraining on a GTX 2080, which means it takes hours to reach ~90% test accuracy (whereas a WideResNet gets there in about 10 minutes).
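As a rough sanity check on those numbers, throughput converts to wall-clock time as training-set size divided by samples per second. The sketch below assumes the CIFAR-10 training split of 50,000 images (the dataset is not named above) and ignores evaluation and data-loading stalls; the epoch counts are hypothetical values, not taken from the config.

```python
def epoch_seconds(num_samples: int, samples_per_sec: float) -> float:
    """Wall-clock seconds for one pass over the training set."""
    return num_samples / samples_per_sec


def total_hours(num_samples: int, samples_per_sec: float, epochs: int) -> float:
    """Training time in hours for `epochs` passes, ignoring eval and I/O overhead."""
    return epochs * epoch_seconds(num_samples, samples_per_sec) / 3600.0


CIFAR10_TRAIN = 50_000  # assumption: CIFAR-10 training split
THROUGHPUT = 130        # samples/s, as reported above

per_epoch = epoch_seconds(CIFAR10_TRAIN, THROUGHPUT)
print(f"{per_epoch:.0f} s/epoch (~{per_epoch / 60:.1f} min)")
for epochs in (10, 20, 50):  # hypothetical epoch counts
    print(f"{epochs} epochs -> {total_hours(CIFAR10_TRAIN, THROUGHPUT, epochs):.1f} h")
```

At 130 samples/s that is roughly 385 s (about 6.4 minutes) per epoch, so every 10 epochs costs a little over an hour.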

The main error I had to fix was this:

Traceback (most recent call last):
  File "/afs/inf.ed.ac.uk/user/s17/s1771851/git/official_untouched/MDEQ-Vision/tools/cls_train.py", line 257, in <module>
    main()
  File "/afs/inf.ed.ac.uk/user/s17/s1771851/git/official_untouched/MDEQ-Vision/tools/cls_train.py", line 220, in main
    final_output_dir, tb_log_dir, writer_dict, topk=topk)
  File "/afs/inf.ed.ac.uk/user/s17/s1771851/git/official_untouched/MDEQ-Vision/tools/../lib/core/cls_function.py", line 42, in train
    for i, (input, target) in enumerate(train_loader):
  File "/afs/inf.ed.ac.uk/user/s17/s1771851/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 359, in __iter__
    return self._get_iterator()
  File "/afs/inf.ed.ac.uk/user/s17/s1771851/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 305, in _get_iterator
    return _MultiProcessingDataLoaderIter(self)
  File "/afs/inf.ed.ac.uk/user/s17/s1771851/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 944, in __init__
    self._reset(loader, first_iter=True)
  File "/afs/inf.ed.ac.uk/user/s17/s1771851/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 975, in _reset
    self._try_put_index()
  File "/afs/inf.ed.ac.uk/user/s17/s1771851/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1209, in _try_put_index
    index = self._next_index()
  File "/afs/inf.ed.ac.uk/user/s17/s1771851/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 512, in _next_index
    return next(self._sampler_iter)  # may raise StopIteration
  File "/afs/inf.ed.ac.uk/user/s17/s1771851/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/sampler.py", line 229, in __iter__
    for idx in self.sampler:
  File "/afs/inf.ed.ac.uk/user/s17/s1771851/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/sampler.py", line 126, in __iter__
    yield from torch.randperm(n, generator=generator).tolist()
RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'

According to this issue, it seems to be caused by this line in your code: torch.set_default_tensor_type('torch.cuda.FloatTensor'), which I removed. After manually moving everything that needed it to the GPU with .cuda(), I get the performance mentioned above. Is this normal, or did I break something? Thanks!

Specs: PyTorch 1.10; tried on both Windows (RTX 3070) and Ubuntu 20 (GTX 2080).
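For reference, here is a minimal sketch of the explicit-placement pattern described above: the default tensor type stays on the CPU, so the sampler's `torch.randperm` and its CPU generator agree, and the model and each batch are moved to the GPU by hand. The tiny model and random dataset are hypothetical placeholders, not the MDEQ code, and the sketch falls back to the CPU when no GPU is present.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Choose the device once instead of calling
# torch.set_default_tensor_type('torch.cuda.FloatTensor'), so tensors created
# inside the DataLoader/sampler (e.g. by torch.randperm) stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-ins for the real dataset and model.
dataset = TensorDataset(torch.randn(64, 3, 8, 8), torch.randint(0, 10, (64,)))
loader = DataLoader(dataset, batch_size=16, shuffle=True,
                    pin_memory=torch.cuda.is_available())
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10)).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()


def train_one_epoch() -> float:
    """One pass over the loader; returns the mean training loss."""
    model.train()
    total = 0.0
    for inputs, targets in loader:
        # Move each batch explicitly; non_blocking helps when pin_memory=True.
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()
        total += loss.item()
    return total / len(loader)


avg_loss = train_one_epoch()
print(f"device={device}, avg loss={avg_loss:.3f}")
```

Inside a DEQ-style forward pass, any fresh tensor (such as an all-zeros initial guess for the solver) should be created with `torch.zeros_like(x)` or `device=x.device` so that it follows the input rather than the global default. Explicit placement like this is the standard PyTorch idiom, so by itself it is unlikely to explain a large slowdown, although a tensor accidentally left on the CPU would.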

polo5 · Feb 18 '22 22:02

Hi @polo5,

Thanks for your interest in our paper! Yes, MDEQ-Large does take a few hours to finish all epochs, so your estimate is correct. Note, however, that most of that time (about 80%?) is spent pushing accuracy from 90% to ~93.5%. If you only need 90% accuracy, MDEQ-Large should get there within 1.5 to 2 hours.

If you use MDEQ-Tiny with AUGMENT set to True in the config file (and slightly increase the number of epochs), you should reach close to 90% accuracy even faster, though of course with a smaller model.

jerrybai1995 · Feb 18 '22 22:02

Also, regarding the error: I haven't encountered it with this repo before, but I'll definitely check it on PyTorch 1.10!

jerrybai1995 · Feb 18 '22 22:02

I got this error as well with PyTorch 1.10. After switching to PyTorch 1.8.1, everything works fine. You can take a look at this issue; it seems to be related to a bug in PyTorch 1.10.

liu-jc · Feb 20 '22 06:02

Thanks @liu-jc.

This is an important issue then, since PyTorch 1.10 is the recommended version for this repo. The other problem is that earlier PyTorch versions (<1.10) work on Ubuntu but somehow not on Windows for this repo (the run stops abruptly without any error message, which looks like a segmentation fault). Oh well, as long as setting devices manually hasn't slowed down the code, I'm happy with that solution.

polo5 · Feb 22 '22 14:02

The issue with PyTorch <1.10 is that the hook-based implementation we currently use to achieve O(1) memory (see, e.g., https://github.com/locuslab/deq/blob/master/DEQ-Sequence/models/deq_transformer.py#L380) did not work correctly before; this was a PyTorch bug that was only fixed in 1.10 (see this issue). I will look into this again soon and update this thread.

I've never tried it in a Windows environment, but I suspect it has something to do with WSL?
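To make the hook-based O(1)-memory idea concrete, here is a minimal sketch of the general pattern, modeled on the public DEQ/implicit-layers tutorial rather than copied from this repo. The fixed point is found under `torch.no_grad()`, so no solver iterations are stored; one extra call to `f` re-attaches the autograd graph; and a hook on the output replaces the incoming gradient with the solution of the adjoint fixed-point equation g = J_f(z*)^T g + dL/dz. The layer `f(z, x) = tanh(z W^T + x)` and the naive forward-iteration solver are illustrative assumptions (the repo uses its own layers and solver), and the sketch does not reproduce the pre-1.10 PyTorch bug mentioned above.

```python
import torch
from torch import autograd


def forward_iteration(g, z0, max_iter=500, tol=1e-10):
    """Naive fixed-point solver: repeat z <- g(z) until the relative update is tiny."""
    z = z0
    for _ in range(max_iter):
        z_next = g(z)
        if (z_next - z).norm() <= tol * (1.0 + z.norm()):
            return z_next
        z = z_next
    return z


def deq_layer(f, x):
    """Return z* = f(z*, x) with an implicit, hook-based backward pass."""
    # 1) Find the fixed point with autograd off: memory does not grow with iterations.
    with torch.no_grad():
        z_star = forward_iteration(lambda z: f(z, x), torch.zeros_like(x))
    # 2) One differentiable call re-attaches z* to x and to f's parameters.
    z = f(z_star, x)
    # 3) A detached copy of f at z* is used to form vector-Jacobian products y^T J_f.
    z0 = z_star.clone().detach().requires_grad_()
    f0 = f(z0, x)

    def backward_hook(grad):
        # Solve g = J_f(z*)^T g + grad by fixed-point iteration and pass g on.
        return forward_iteration(
            lambda y: autograd.grad(f0, z0, y, retain_graph=True)[0] + grad, grad)

    z.register_hook(backward_hook)
    return z


# Illustrative layer; W is scaled down so that f is a contraction in z.
torch.manual_seed(0)
n = 6
W = 0.2 * torch.randn(n, n, dtype=torch.float64) / n ** 0.5


def f(z, x):
    return torch.tanh(z @ W.T + x)


x = torch.randn(4, n, dtype=torch.float64, requires_grad=True)
z = deq_layer(f, x)
z.sum().backward()  # fires backward_hook once

# Reference gradient: backprop through many unrolled iterations (memory grows with depth).
x_ref = x.detach().clone().requires_grad_()
z_ref = torch.zeros_like(x_ref)
for _ in range(100):
    z_ref = f(z_ref, x_ref)
z_ref.sum().backward()
print(torch.allclose(x.grad, x_ref.grad, atol=1e-6))
```

The comparison against the unrolled reference is only meant as a check that the hook computes the same gradient while storing a single application of `f` instead of the whole solver trajectory.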

jerrybai1995 · Feb 23 '22 04:02