sketch2model
TypeError: forward() missing 1 required positional argument: 'image'.
When I run inference, I hit this problem:
File "infer.py", line 25, in <module>
model.inference(current_epoch, dataset_infer, save_dir=out_dir)
File "/home/test/sketch2model/models/view_disentangle_model.py", line 572, in inference
self.forward_inference()
File "/home/test/sketch2model/models/view_disentangle_model.py", line 464, in forward_inference
out = self.netFull(self.data_image, view=self.data_view)
File "/home/test/anaconda3/envs/sketch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/test/anaconda3/envs/sketch/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 161, in forward
outputs = self.parallel_apply(replicas, inputs, kwargs)
File "/home/test/anaconda3/envs/sketch/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 171, in parallel_apply
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
File "/home/test/anaconda3/envs/sketch/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
output.reraise()
File "/home/test/anaconda3/envs/sketch/lib/python3.6/site-packages/torch/_utils.py", line 428, in reraise
raise self.exc_type(msg)
TypeError: Caught TypeError in replica 1 on device 1.
Original Traceback (most recent call last):
File "/home/test/anaconda3/envs/sketch/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
output = module(*input, **kwargs)
File "/home/test/anaconda3/envs/sketch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
TypeError: forward() missing 1 required positional argument: 'image'.
These changes will fix it; apologies for not having an open PR. I assume this change will break training, though, and it certainly breaks parallelism.
Have a nice day.
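For anyone else hitting this: the exact diff isn't shown in this thread, but a change with the trade-off described above (working inference, no multi-GPU parallelism) might look roughly like the hypothetical helper below. It bypasses nn.DataParallel on the inference path so the inputs are never scattered. Only netFull, data_image, and data_view come from the traceback; the helper name and everything else is an assumption, not the poster's actual patch.

```python
import torch
import torch.nn as nn

def run_inference_single_device(net, image, view):
    """Hypothetical workaround: call the wrapped module directly so that
    nn.DataParallel never scatters `image` and `view` across GPUs."""
    if isinstance(net, nn.DataParallel):
        net = net.module  # skip scatter / replicate / gather entirely
    with torch.no_grad():
        return net(image, view=view)

# Inside forward_inference() this would replace the failing call, roughly:
#   out = run_inference_single_device(self.netFull, self.data_image, self.data_view)
```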
@duzhenjiang113 Sorry for the late reply. It looks like a DataParallel problem. Which version of PyTorch are you using?
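In case it helps with debugging: one common way DataParallel produces exactly this error is when the positional tensor scatters into fewer chunks than there are devices (for example a batch of 1 on 2 GPUs) while `view` is a plain Python value that gets copied to every device, so the second replica is called with only the keyword argument. This is a guess at the cause, not a confirmed diagnosis of this repo; the sketch below is a minimal, self-contained reproduction of that mechanism with a toy module.

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    def forward(self, image, view):
        # Same signature shape as the failing forward(): one positional tensor
        # plus a `view` argument passed as a keyword.
        return image.sum() + view

if torch.cuda.device_count() >= 2:
    net = nn.DataParallel(Toy().cuda(), device_ids=[0, 1])
    image = torch.randn(1, 3, 64, 64, device="cuda")  # batch of 1, two GPUs
    # The batch-1 tensor scatters into a single chunk, but the plain int `view`
    # is replicated to both devices, so replica 1 is invoked without `image`:
    # TypeError: forward() missing 1 required positional argument: 'image'
    out = net(image, view=0)
```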