When I train the model with batch_size=2, I get the error below.
Please help me, thank you very much!
Traceback (most recent call last):
File "main.py", line 251, in
main(args_)
File "main.py", line 222, in main
args.clip_max_norm, amp)
File "/mnt/cfs/algorithm/xianda.guo/code/stereo-transformer/utilities/train.py", line 32, in train_one_epoch
_, losses, sampled_disp = forward_pass(model, data, device, criterion, train_stats)
File "/mnt/cfs/algorithm/xianda.guo/code/stereo-transformer/utilities/foward_pass.py", line 56, in forward_pass
outputs = model(inputs)
File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/mnt/cfs/algorithm/xianda.guo/code/stereo-transformer/module/sttr.py", line 97, in forward
attn_weight = self.transformer(feat_left, feat_right, pos_enc)
File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/mnt/cfs/algorithm/xianda.guo/code/stereo-transformer/module/transformer.py", line 110, in forward
attn_weight = self._alternating_attn(feat, pos_enc, pos_indexes, hn)
File "/mnt/cfs/algorithm/xianda.guo/code/stereo-transformer/module/transformer.py", line 79, in _alternating_attn
pos_indexes)
File "/opt/conda/lib/python3.7/site-packages/torch/utils/checkpoint.py", line 211, in checkpoint
return CheckpointFunction.apply(function, preserve, *args)
File "/opt/conda/lib/python3.7/site-packages/torch/utils/checkpoint.py", line 90, in forward
outputs = run_function(*args)
File "/mnt/cfs/algorithm/xianda.guo/code/stereo-transformer/module/transformer.py", line 74, in custom_cross_attn
return module(*inputs, False)
File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/mnt/cfs/algorithm/xianda.guo/code/stereo-transformer/module/transformer.py", line 185, in forward
pos_indexes=pos_indexes)[0]
File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/mnt/cfs/algorithm/xianda.guo/code/stereo-transformer/module/attention.py", line 106, in forward
attn_pos_feat = torch.einsum('vnec,wvec->newv', k, q_r) # NxExWxW'
File "/opt/conda/lib/python3.7/site-packages/torch/functional.py", line 299, in einsum
return _VF.einsum(equation, operands) # type: ignore[attr-defined]
RuntimeError: einsum(): operands do not broadcast with remapped shapes [original->remapped]: [71, 360, 8, 16]->[360, 8, 1, 71, 16] [213, 213, 8, 16]->[1, 8, 213, 213, 16]
Hi @GREW-Benchmark
The code currently doesn't support a batch size larger than 1 because of the random cropping. You can comment it out to use a larger batch size.
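For illustration, here is a minimal sketch of replacing the random crop with a fixed-size crop so every sample in a batch has the same spatial dimensions. The function and argument names are assumptions for the example, not the repository's actual code:

```python
def fixed_crop(img_left, img_right, disp, crop_h=360, crop_w=640):
    """Crop the left/right views and the disparity map to one fixed window
    (top-left corner instead of a random offset) so that every sample in a
    batch has the same shape and the default collate can stack them."""
    img_left = img_left[..., :crop_h, :crop_w]
    img_right = img_right[..., :crop_h, :crop_w]
    disp = disp[..., :crop_h, :crop_w]
    return img_left, img_right, disp
```

With every sample cropped to the same (crop_h, crop_w), a DataLoader with batch_size=2 can stack the tensors, and the shape mismatch in the einsum above should not occur.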
Thank you for your reply. I have solved this problem and made the code support a larger batch size. I would like to know why you used random cropping. Is it for data augmentation?
Hi @GREW-Benchmark
Thanks for the update. There were two reasons the code doesn't support a larger batch size:
- I only have one GPU and cannot run larger experiments
- Compared to fixed-size cropping, random cropping avoids overfitting caused by the positional encoding (otherwise the model sees the same patterns every time); see the short sketch after this reply
I hope this helps!
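As a rough illustration of the second point (a sketch, not the repository's code): with a random crop offset, a given pixel is paired with different positional-encoding indices over epochs, whereas a fixed crop window always lines the encoding up with the same image content.

```python
import random

def random_crop_offsets(img_h, img_w, crop_h, crop_w):
    """Pick a random top-left corner for the crop window. Because the offset
    changes every epoch, the same pixel falls on different absolute positions
    (and hence different positional-encoding entries), which acts as
    augmentation."""
    top = random.randint(0, img_h - crop_h)
    left = random.randint(0, img_w - crop_w)
    return top, left

# With a fixed crop (top=0, left=0) the same pixels always map to the same
# positional-encoding entries, so the network can memorize those patterns.
```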
> Hi @GREW-Benchmark
> The code currently doesn't support a batch size larger than 1 because of the random cropping. You can comment it out to use a larger batch size.
Hi!! I am very new to this.
Can you please let me know how to remove random cropping and make this work for larger batch sizes?
Thanks in advance!!
stack expects each tensor to be equal size, but got [531, 674, 3] at entry 0 and [380, 885, 3] at entry 1
This is the error I am getting. Does the original poster have any suggestions?
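For context, this error means the default PyTorch collate tried to stack two images of different sizes into one batch. A minimal sketch of one possible workaround, assuming each sample has already been converted to CHW tensors; the collate function name is hypothetical and not part of this repository:

```python
import torch

def crop_to_common_size_collate(batch):
    """Hypothetical collate sketch: assumes each sample is a tuple of CHW
    tensors (e.g. left image, right image, disparity). Crops every tensor to
    the smallest height/width present in the batch, then stacks along a new
    batch dimension."""
    min_h = min(sample[0].shape[-2] for sample in batch)
    min_w = min(sample[0].shape[-1] for sample in batch)
    cropped = [tuple(t[..., :min_h, :min_w] for t in sample) for sample in batch]
    return tuple(torch.stack(ts) for ts in zip(*cropped))
```

Usage would look like `DataLoader(dataset, batch_size=2, collate_fn=crop_to_common_size_collate)`; cropping all samples to one fixed size in the dataset itself (as discussed above) achieves the same effect.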