PyTorch-Encoding
RuntimeError in PyTorch 1.6
I am using PyTorch 1.6.0, CUDA 10.2, and the PyTorch-Encoding master branch.
Traceback (most recent call last):
File "train_SSL.py", line 612, in <module>
main()
File "train_SSL.py", line 438, in main
pred = F.interpolate((model(images)), size=input_shape, mode='bilinear', align_corners=True)
File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 722, in _call_impl
result = self.forward(*input, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/torch/nn/parallel/data_parallel.py", line 155, in forward
outputs = self.parallel_apply(replicas, inputs, kwargs)
File "/usr/local/lib/python3.6/dist-packages/torch/nn/parallel/data_parallel.py", line 165, in parallel_apply
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
File "/usr/local/lib/python3.6/dist-packages/torch/nn/parallel/parallel_apply.py", line 85, in parallel_apply
output.reraise()
File "/usr/local/lib/python3.6/dist-packages/torch/_utils.py", line 395, in reraise
raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
File "/usr/local/lib/python3.6/dist-packages/torch/nn/parallel/parallel_apply.py", line 60, in _worker
output = module(*input, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 722, in _call_impl
result = self.forward(*input, **kwargs)
File "/mnt/10T_1/project/model/deeplabv2.py", line 207, in forward
x = self.bn1(x)
File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 722, in _call_impl
result = self.forward(*input, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/torch_encoding-1.2.2b20200908-py3.6-linux-x86_64.egg/encoding/nn/syncbn.py", line 202, in forward
self.activation, self.slope).view(input_shape)
RuntimeError: Some elements marked as dirty during the forward method were not returned as output. The inputs that are modified inplace must all be outputs of the Function.
This seems to be a bug caused by recent PyTorch updates. I haven't tried the code recently. If you want a quick workaround, you can downgrade PyTorch to 1.4.0.
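For context, recent PyTorch releases (roughly 1.5 and later) enforce that any input a custom autograd Function modifies in place and marks dirty must also be returned from forward. The syncbn Function updates running_mean and running_var in place without returning them, which trips this check. A minimal toy Function (hypothetical names, not from this repo) that reproduces the rule:

import torch
from torch.autograd import Function

class BadInplace(Function):
    @staticmethod
    def forward(ctx, x, buf):
        buf.add_(1.0)        # buf is modified in place ...
        ctx.mark_dirty(buf)  # ... and marked dirty ...
        return x * 2         # ... but not returned -> RuntimeError

    @staticmethod
    def backward(ctx, dy):
        return dy * 2, None

x = torch.randn(3, requires_grad=True)
buf = torch.zeros(3)
y = BadInplace.apply(x, buf)  # raises the "marked as dirty" RuntimeError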
I found that this issue also appeared in inplace_abn issue #166. They solved it in a commit. I have tried the same fix in syncbn, and it works.
Awesome, thanks for pointing out the solution @zhangbin0917
Can you share the changes? I couldn't get it to work. Thanks @zhangbin0917
Hello, I've hit the same problem. Can you tell me how to fix it? @zhanghang1989 @zhangbin0917
I think I have fixed the problem, thanks to @zhangbin0917's advice. My torch==1.7.0.
1. Fix the code in ....\site-packages\encoding\nn\syncbn.py at about line 200: change

   return syncbatchnorm(........).view(input_shape)

   to

   x, _, _ = syncbatchnorm(........)
   x = x.view(input_shape)
   return x

2. Fix the code in ....\site-packages\encoding\functions\syncbn.py at about line 102: change

   ctx.save_for_backward(x, _ex, _exs, gamma, beta)
   return y

   to

   ctx.save_for_backward(x, _ex, _exs, gamma, beta)
   ctx.mark_non_differentiable(running_mean, running_var)
   return y, running_mean, running_var

3. Fix the code in ....\site-packages\encoding\functions\syncbn.py at about line 109: change the signature

   def backward(ctx, dz)

   to

   def backward(ctx, dz, _drunning_mean, _drunning_var)
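To see why steps 2 and 3 go together: each extra tensor returned from forward adds one gradient slot to backward, and marking the in-place-updated buffers non-differentiable makes those slots arrive as None. A self-contained toy Function (hypothetical, just to illustrate the pattern) that applies the same fix:

import torch
from torch.autograd import Function

class ScaleWithStats(Function):
    @staticmethod
    def forward(ctx, x, gamma, running_mean):
        running_mean.mul_(0.9).add_(0.1 * x.mean())  # in-place buffer update
        ctx.save_for_backward(x, gamma)
        ctx.mark_dirty(running_mean)                  # modified in place
        ctx.mark_non_differentiable(running_mean)     # step-2 analogue
        return x * gamma, running_mean                # dirty tensor is returned

    @staticmethod
    def backward(ctx, dy, _drunning_mean):            # step-3 analogue
        x, gamma = ctx.saved_tensors
        return dy * gamma, (dy * x).sum(), None

x = torch.randn(4, requires_grad=True)
gamma = torch.tensor(2.0, requires_grad=True)
running_mean = torch.zeros(())
y, running_mean = ScaleWithStats.apply(x, gamma, running_mean)
y.sum().backward()  # runs without the "marked as dirty" error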
Hi @Zhaoguanhua @zhangbin0917, thanks for sharing.
However, have you tried fixing the same bug in DistSyncBatchNorm, which fails for the same reason? I tried to modify the code following your discussion, like this: https://github.com/zhangbin0917/PyTorch-Encoding/compare/master...BrandonHanx:master
But it seems this change does not work in evaluation mode; here is the traceback:
File "~/miniconda3/envs/DFF/lib/python3.8/site-packages/encoding/nn/syncbn.py", line 96, in forward
y, _, _ = dist_syncbatchnorm(x, self.weight, self.bias, self.running_mean, self.running_var,
ValueError: too many values to unpack (expected 3)
Any idea about this? Need help :)
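A possible cause: in evaluation mode the patched Function may still return just y (the running stats are only updated in place during training), and unpacking a single tensor iterates over its first dimension, which would produce exactly this "too many values to unpack" error. If so, a conditional unpack in DistSyncBatchNorm.forward might work; a sketch, untested, with the argument list elided as in the original call:

out = dist_syncbatchnorm(x, self.weight, self.bias,
                         self.running_mean, self.running_var, ...)
if self.training:
    y, _, _ = out   # training: (y, running_mean, running_var)
else:
    y = out         # eval: just y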
Same question here. Can anyone help?