doctr
feat: Improved training scripts for classification and obj_detection
This PR introduces the following modifications:
- transforms: updated the input and output signature of RandomRotate
- character classification: expanded data augmentations for PyTorch
- obj detection: switched StepLR & SGD to OneCycleLR & Adam, added data augmentations
- recognition: expanded data augmentation for PyTorch & TensorFlow
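For readers unfamiliar with the optimizer and scheduler pairing mentioned for object detection, here is a minimal sketch of how Adam and OneCycleLR are typically wired together in PyTorch. The tiny model, the random data, the step counts, and the learning rates are hypothetical placeholders, not the values used in the reference scripts.

```python
import torch
from torch import nn

# Hypothetical stand-ins: a tiny model and small step counts, not the PR's actual settings
model = nn.Linear(8, 2)
epochs, steps_per_epoch = 2, 5
total_steps = epochs * steps_per_epoch

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# OneCycleLR warms the learning rate up to max_lr, then anneals it below its starting value
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3, total_steps=total_steps)

lrs = []
for _ in range(total_steps):
    inputs, targets = torch.randn(4, 8), torch.randn(4, 2)
    loss = nn.functional.mse_loss(model(inputs), targets)
    optimizer.zero_grad()
    loss.backward()
    lrs.append(optimizer.param_groups[0]["lr"])  # learning rate used for this batch
    optimizer.step()
    scheduler.step()  # stepped once per batch, unlike StepLR which is usually stepped per epoch
```

The main practical difference from StepLR is that the scheduler needs the total number of batches up front and is advanced after every optimizer step.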
Any feedback is welcome!
@charlesmindee can you check if this PR is still good please? If so, can we move forward and merge it? If not, is anyone going to work on fixing it, or should we close it?
@frgfm Do you want to continue on this if you have time? :) I think it would be great if we can fix both (SAR and MASTER) in the near future :D I can also take one of them, let me know what you think :smiley:
@felixdittrich92 yes, at least, I'm happy to help! Before then, because that might take some time, perhaps I can suggest:
- checking whether there is an implementation in one of the frameworks that works
- if none works, for now, we remove the model from the "main" branch
- if one of them works, for now, we only remove the ones that aren't working
I don't think it's a good idea that we make another release with non-working models :/ What do you think @charlesmindee ?
@frgfm I totally agree that we shouldn't do another release with broken models! The MASTER implementation in particular is really troublesome, I think, in both TF and PT (take a look at the ONNX draft PR)
I think before we iterate on any new feature like ONNX, we have to decide what to keep / remove for the moment. Maybe both SAR and MASTER (TF and PT)?
#893 SAR and MASTER don't work well on TF either
@frgfm Have you also faced the following problem within the SarDecoder AttentionModule:
/opt/conda/conda-bld/pytorch_1646755903507/work/aten/src/ATen/native/cuda/Indexing.cu:703: indexSelectLargeIndex: block: [68,0,0], thread: [31,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
Traceback (most recent call last):
File "/home/felix/Desktop/doctr/references/recognition/train_pytorch.py", line 447, in <module>
main(args)
File "/home/felix/Desktop/doctr/references/recognition/train_pytorch.py", line 367, in main
fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, mb, amp=args.amp)
File "/home/felix/Desktop/doctr/references/recognition/train_pytorch.py", line 117, in fit_one_epoch
train_loss = model(images, targets)['loss']
File "/home/felix/.conda/envs/doctr-dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/home/felix/Desktop/doctr/doctr/models/recognition/sar/pytorch.py", line 214, in forward
decoded_features = self.decoder(features, encoded, gt=None if target is None else gt)
File "/home/felix/.conda/envs/doctr-dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/home/felix/Desktop/doctr/doctr/models/recognition/sar/pytorch.py", line 131, in forward
logits = self.attention_module(_symbol, features, hidden_state, cell_state)
File "/home/felix/.conda/envs/doctr-dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/home/felix/Desktop/doctr/doctr/models/recognition/sar/pytorch.py", line 66, in forward
attn_query = self.state_conv(tile_hidden_state) # bsz * attn_size * 1 * 1
File "/home/felix/.conda/envs/doctr-dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/home/felix/.conda/envs/doctr-dev/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 447, in forward
return self._conv_forward(input, self.weight, self.bias)
File "/home/felix/.conda/envs/doctr-dev/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 443, in _conv_forward
return F.conv2d(input, weight, bias, self.stride,
RuntimeError: cuDNN error: CUDNN_STATUS_INTERNAL_ERROR
You can try to repro this exception using the following code snippet. If that doesn't trigger the error, please include your original repro script when reporting this issue.
? :/ So I think your replacement of the Conv with linear layers is not correct. But most of the changes I have tried end in this *** message :sweat_smile:
EDIT: I have some ugly (but working) test code here: https://github.com/felixdittrich92/doctr/tree/fix This one trains well (decoding without gt is still not fixed), but it also runs into the same problems we saw on the TF side (on a whole document it seems to work, but with the RecognitionPredictor the results are still not good)
I have also taken some time to check whether the backbone might be the problem, but so far everything seems to be fine there
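A note on the trace above: the `srcIndex < srcSelectDimSize` assertion is raised by an index lookup (such as an embedding) when an index falls outside the table, and because CUDA reports errors asynchronously, a later call like `conv2d` can surface the failure. One thing that might be worth ruling out is an encoded target index that exceeds the embedding size. A minimal sketch of such a check in plain Python, where `find_out_of_range`, the vocabulary size, and the EOS/padding layout are all hypothetical and not taken from the doctr code:

```python
from typing import List, Sequence


def find_out_of_range(encoded: Sequence[Sequence[int]], num_embeddings: int) -> List[int]:
    """Return the positions of samples holding an index outside [0, num_embeddings)."""
    return [
        i
        for i, seq in enumerate(encoded)
        if any(idx < 0 or idx >= num_embeddings for idx in seq)
    ]


# Hypothetical setup: 5 characters, plus one extra index each for EOS and padding
vocab_size = 5
eos, pad = vocab_size, vocab_size + 1
batch = [[0, 3, eos, pad], [4, 1, 2, eos]]

# A lookup table sized vocab_size + 1 cannot hold the padding index
bad_small = find_out_of_range(batch, vocab_size + 1)
# Sized vocab_size + 2, every index fits
bad_large = find_out_of_range(batch, vocab_size + 2)
```

Running the same batch on CPU, or with the environment variable `CUDA_LAUNCH_BLOCKING=1`, usually makes the stack trace point at the operation that actually failed.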
That does ring a bell but not specifically during SAR training, I already had that TF error in the past. But it's been a while 🤷
@frgfm maybe rename to improve training scripts ? :)
Yup! Let me update this :+1:
@frgfm If you find some spare time, could we finish this? 😄 It seems that @jonathanMindee is currently not interested in providing checkpoints for the models that don't have pretrained ones. So I started to create a "hopefully" qualitatively equivalent dataset. (It will probably take a while since my weeks are pretty full right now.) But to get to the point, this PR would be super helpful before I start training. 👍🏼 If everything goes well, I would also make the dataset freely available.
I think we can close this now; otherwise, feel free to reopen :)