stereo-transformer

TRAINING ERROR

Open • GREW-Benchmark opened this issue 3 years ago • 5 comments

When I train the model with batch_size=2, I get the error below. Please help me, thank you very much!

Traceback (most recent call last):
  File "main.py", line 251, in <module>
    main(args_)
  File "main.py", line 222, in main
    args.clip_max_norm, amp)
  File "/mnt/cfs/algorithm/xianda.guo/code/stereo-transformer/utilities/train.py", line 32, in train_one_epoch
    _, losses, sampled_disp = forward_pass(model, data, device, criterion, train_stats)
  File "/mnt/cfs/algorithm/xianda.guo/code/stereo-transformer/utilities/foward_pass.py", line 56, in forward_pass
    outputs = model(inputs)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/cfs/algorithm/xianda.guo/code/stereo-transformer/module/sttr.py", line 97, in forward
    attn_weight = self.transformer(feat_left, feat_right, pos_enc)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/cfs/algorithm/xianda.guo/code/stereo-transformer/module/transformer.py", line 110, in forward
    attn_weight = self._alternating_attn(feat, pos_enc, pos_indexes, hn)
  File "/mnt/cfs/algorithm/xianda.guo/code/stereo-transformer/module/transformer.py", line 79, in _alternating_attn
    pos_indexes)
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/checkpoint.py", line 211, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/checkpoint.py", line 90, in forward
    outputs = run_function(*args)
  File "/mnt/cfs/algorithm/xianda.guo/code/stereo-transformer/module/transformer.py", line 74, in custom_cross_attn
    return module(*inputs, False)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/cfs/algorithm/xianda.guo/code/stereo-transformer/module/transformer.py", line 185, in forward
    pos_indexes=pos_indexes)[0]
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/cfs/algorithm/xianda.guo/code/stereo-transformer/module/attention.py", line 106, in forward
    attn_pos_feat = torch.einsum('vnec,wvec->newv', k, q_r)  # NxExWxW'
  File "/opt/conda/lib/python3.7/site-packages/torch/functional.py", line 299, in einsum
    return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
RuntimeError: einsum(): operands do not broadcast with remapped shapes [original->remapped]: [71, 360, 8, 16]->[360, 8, 1, 71, 16] [213, 213, 8, 16]->[1, 8, 213, 213, 16]

GREW-Benchmark, Nov 23 '21 09:11
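The shapes in the RuntimeError can be matched to the einsum equation by hand: in 'vnec,wvec->newv' the subscript v labels dim 0 of k (size 71) and dim 1 of q_r (size 213), so the two operands disagree on v. Going by the `# NxExWxW'` comment, this appears to mean the key features and the relative positional term were built for different widths; the thread does not establish why that happens with batch_size=2. The helper below is a hypothetical, pure-Python sketch (not part of stereo-transformer) that reports such conflicts; it ignores '...' and size-1 broadcasting, so it flags any size disagreement on a labeled dimension.

```python
def einsum_size_conflicts(equation, shapes):
    """Map each einsum subscript that has inconsistent sizes to the sizes seen.

    Simplified sketch: no support for '...' and no size-1 broadcasting rules,
    so any disagreement in a labeled dimension is reported.
    """
    input_specs = equation.split("->")[0].split(",")
    if len(input_specs) != len(shapes):
        raise ValueError("number of operands does not match the equation")
    sizes = {}
    for spec, shape in zip(input_specs, shapes):
        if len(spec) != len(shape):
            raise ValueError(f"{spec!r} labels {len(spec)} dims but shape has {len(shape)}")
        # Record every size observed for each subscript label.
        for label, size in zip(spec, shape):
            sizes.setdefault(label, set()).add(size)
    return {label: sorted(seen) for label, seen in sizes.items() if len(seen) > 1}


# Operand shapes copied from the RuntimeError: k is [71, 360, 8, 16], q_r is [213, 213, 8, 16].
print(einsum_size_conflicts("vnec,wvec->newv", [(71, 360, 8, 16), (213, 213, 8, 16)]))
# {'v': [71, 213]}
```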

Hi @GREW-Benchmark

The code currently doesn't support batch sizes larger than 1 because of the random cropping. You can comment it out to train with a larger batch size.

mli0603, Nov 23 '21 16:11
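Why random cropping blocks batch sizes above 1: PyTorch's default collate function stacks the samples of a batch, which requires them to have identical shapes, and a crop whose size changes from sample to sample (presumably what the random cropping here does) breaks that. Replacing it with a crop of fixed size at a random position is one way to follow the suggestion above. A minimal NumPy sketch, assuming each sample is a dict of (H, W, ...) arrays with hypothetical keys "left", "right" and "disp"; this is not the repository's own preprocessing code, and the crop size 360x640 is arbitrary.

```python
import numpy as np


def fixed_size_random_crop(sample, height, width, rng):
    """Crop every array in a stereo sample to (height, width) at one random position.

    `sample` is a hypothetical dict of NumPy arrays whose first two dims are
    (H, W), e.g. "left"/"right" images (H, W, 3) and a "disp" map (H, W).
    """
    h, w = sample["left"].shape[:2]
    if h < height or w < width:
        raise ValueError(f"image {h}x{w} is smaller than crop {height}x{width}")
    top = int(rng.integers(0, h - height + 1))
    left = int(rng.integers(0, w - width + 1))
    # One window for all keys keeps the left/right views and the disparity map
    # aligned: shifting both views by the same offset leaves disparities unchanged.
    return {k: v[top:top + height, left:left + width] for k, v in sample.items()}


rng = np.random.default_rng(0)
samples = [
    {"left": np.zeros((531, 674, 3)), "right": np.zeros((531, 674, 3)), "disp": np.zeros((531, 674))},
    {"left": np.zeros((380, 885, 3)), "right": np.zeros((380, 885, 3)), "disp": np.zeros((380, 885))},
]
cropped = [fixed_size_random_crop(s, 360, 640, rng) for s in samples]
print(np.stack([c["left"] for c in cropped]).shape)  # (2, 360, 640, 3)
```

The fixed size must not exceed the smallest image in the dataset; otherwise the function raises instead of silently producing mismatched shapes.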

Thank you for your reply. I have solved this problem and made the code support larger batch sizes. I'd like to know why you used random cropping. Is it for data augmentation?

GREW-Benchmark, Nov 24 '21 03:11

Hi @GREW-Benchmark

Thanks for the update. There were two reasons the code doesn't support larger batch sizes:

  • I only have one GPU and cannot run larger experiments.
  • Compared to cropping to a fixed size, random cropping avoids overfitting to the positional encoding (otherwise the network keeps seeing the same positional patterns).

I hope this helps!

mli0603, Nov 29 '21 16:11
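One way to reconcile the second reason with batch sizes above 1, offered here as an assumption rather than anything proposed in this thread, is to draw one random crop size per batch: the size still varies from batch to batch, but all samples within a batch share it and can be stacked. A NumPy sketch, assuming samples are dicts of (H, W, ...) arrays with hypothetical keys such as "left", "right" and "disp"; a function like this could be passed as DataLoader's collate_fn, with the arrays converted to tensors (e.g. via torch.from_numpy) afterwards.

```python
import numpy as np


def crop_batch_common_random_size(samples, min_height, min_width, rng):
    """Crop a list of stereo samples to one random (height, width) shared by the batch."""
    # The crop can be at most as large as the smallest image in the batch.
    max_crop_h = min(s["left"].shape[0] for s in samples)
    max_crop_w = min(s["left"].shape[1] for s in samples)
    if max_crop_h < min_height or max_crop_w < min_width:
        raise ValueError("smallest image in the batch is smaller than the minimum crop")
    # One random size per batch: varies between batches, identical within a batch.
    height = int(rng.integers(min_height, max_crop_h + 1))
    width = int(rng.integers(min_width, max_crop_w + 1))
    cropped = []
    for s in samples:
        h, w = s["left"].shape[:2]
        top = int(rng.integers(0, h - height + 1))
        left = int(rng.integers(0, w - width + 1))
        # Same window for every key keeps the stereo pair and disparity aligned.
        cropped.append({k: v[top:top + height, left:left + width] for k, v in s.items()})
    # Every sample now has the same (height, width), so each key can be stacked.
    return {k: np.stack([c[k] for c in cropped]) for k in cropped[0]}


rng = np.random.default_rng(1)
batch = crop_batch_common_random_size(
    [
        {"left": np.zeros((531, 674, 3)), "right": np.zeros((531, 674, 3)), "disp": np.zeros((531, 674))},
        {"left": np.zeros((380, 885, 3)), "right": np.zeros((380, 885, 3)), "disp": np.zeros((380, 885))},
    ],
    256, 512, rng,
)
print(batch["left"].shape)
```

Whether varying the size per batch instead of per sample is enough to avoid the positional-encoding overfitting described above is untested here.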

Hi! I am very new to this. Could you please let me know how to remove the random cropping and make this work with larger batch sizes? Thanks in advance!

Vaishnavi-1712, Dec 06 '21 17:12

stack expects each tensor to be equal size, but got [531, 674, 3] at entry 0 and [380, 885, 3] at entry 1. This is my problem; does the original poster have any solution?

dididichufale, Apr 06 '22 09:04
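The "stack expects each tensor to be equal size" error is raised when the batch's images ([531, 674, 3] and [380, 885, 3]) are stacked without first being brought to a common size. Besides cropping, a common workaround is to pad every sample to the largest size in the batch and keep a mask of valid pixels so padded areas can be excluded from the loss. A NumPy sketch under the same hypothetical dict-of-arrays assumption (keys "left", "right", "disp", plus an added "valid_mask"); whether STTR's attention and loss behave correctly on padded regions is not established in this thread.

```python
import numpy as np


def pad_collate(samples, pad_value=0.0):
    """Pad each array to the batch's max (H, W), stack, and add a validity mask."""
    max_h = max(s["left"].shape[0] for s in samples)
    max_w = max(s["left"].shape[1] for s in samples)
    batch = {}
    for key in samples[0]:
        padded = []
        for s in samples:
            arr = s[key]
            h, w = arr.shape[:2]
            # Pad only at the bottom and right so pixel coordinates (and disparities) are unchanged.
            pad = [(0, max_h - h), (0, max_w - w)] + [(0, 0)] * (arr.ndim - 2)
            padded.append(np.pad(arr, pad, mode="constant", constant_values=pad_value))
        batch[key] = np.stack(padded)
    # True where a pixel comes from the original image, False where it is padding.
    valid = np.zeros((len(samples), max_h, max_w), dtype=bool)
    for i, s in enumerate(samples):
        h, w = s["left"].shape[:2]
        valid[i, :h, :w] = True
    batch["valid_mask"] = valid
    return batch


out = pad_collate([
    {"left": np.zeros((531, 674, 3)), "right": np.zeros((531, 674, 3)), "disp": np.zeros((531, 674))},
    {"left": np.zeros((380, 885, 3)), "right": np.zeros((380, 885, 3)), "disp": np.zeros((380, 885))},
])
print(out["left"].shape)  # (2, 531, 885, 3)
```

Padding keeps every pixel but spends memory and compute on the padded area, whereas cropping discards pixels; which trade-off suits this model is a judgment call.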