feat: Improved training scripts for classification and obj_detection

Open fg-mindee opened this issue 2 years ago • 8 comments

This PR introduces the following modifications:

  • transforms: updated the input and output signature of RandomRotate
  • character classification: expanded data augmentations for PyTorch
  • obj detection: switched StepLR & SGD to OneCycleLR & Adam, added data augmentations
  • recognition: expanded data augmentation for PyTorch & TensorFlow
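
For readers unfamiliar with the optimizer change listed above, here is a minimal sketch of pairing Adam with OneCycleLR in PyTorch. The stand-in model, learning rate, and step counts are assumptions for illustration only and are not taken from the actual training scripts. The main practical difference from StepLR is that the scheduler is stepped once per batch rather than once per epoch.

```python
import torch
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import OneCycleLR

# Stand-in model and sizes (hypothetical, not the doctr detection model)
model = nn.Linear(8, 2)
epochs, steps_per_epoch, max_lr = 2, 5, 1e-3
total_steps = epochs * steps_per_epoch

optimizer = Adam(model.parameters(), lr=max_lr)
# OneCycleLR needs the total number of optimizer steps up front
scheduler = OneCycleLR(optimizer, max_lr=max_lr, total_steps=total_steps)

lrs = []
for _ in range(total_steps):
    # Record the learning rate used for this batch
    lrs.append(optimizer.param_groups[0]["lr"])
    loss = model(torch.randn(4, 8)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()  # stepped every batch, unlike StepLR (every epoch)
```

The recorded `lrs` list starts well below `max_lr`, ramps up to it, then decays, which is the one-cycle shape.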

Any feedback is welcome!

fg-mindee avatar Mar 21 '22 16:03 fg-mindee

@charlesmindee can you check if this PR is still good please? If so, can we move forward and merge it? If not, is anyone going to work on fixing it, or should we close it?

fharper avatar Apr 08 '22 18:04 fharper

@frgfm Do you want to continue on this if you have time? :) I think it would be great if we could fix both (sar and master) in the near future :D I can also take one of these, let me know what you think :smiley:

felixdittrich92 avatar Apr 27 '22 20:04 felixdittrich92

@felixdittrich92 yes, at least, I'm happy to help! Before then, because that might take some time, perhaps I can suggest:

  • checking whether there is an implementation in one of the frameworks that works
  • if none works, for now, we remove the model from the "main" branch
  • if one of them works, for now, we only remove the ones that aren't working

I don't think it's a good idea that we make another release with non-working models :/ What do you think @charlesmindee ?

frgfm avatar May 07 '22 10:05 frgfm

@frgfm I totally agree not to do another release with broken models! The MASTER implementation especially is really troublesome, I think, in TF AND PT (take a look at the ONNX draft PR)

I think before we iterate on any new feature like ONNX, we have to decide what to hold back / remove for the moment. Maybe both sar and MASTER (TF and PT)?

#893 sar and master also don't work well on TF

felixdittrich92 avatar May 07 '22 10:05 felixdittrich92

@frgfm Have you also faced the following problem within the SarDecoder AttentionModule:

/opt/conda/conda-bld/pytorch_1646755903507/work/aten/src/ATen/native/cuda/Indexing.cu:703: indexSelectLargeIndex: block: [68,0,0], thread: [31,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
Traceback (most recent call last):
  File "/home/felix/Desktop/doctr/references/recognition/train_pytorch.py", line 447, in <module>
    main(args)
  File "/home/felix/Desktop/doctr/references/recognition/train_pytorch.py", line 367, in main
    fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, mb, amp=args.amp)
  File "/home/felix/Desktop/doctr/references/recognition/train_pytorch.py", line 117, in fit_one_epoch
    train_loss = model(images, targets)['loss']
  File "/home/felix/.conda/envs/doctr-dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/felix/Desktop/doctr/doctr/models/recognition/sar/pytorch.py", line 214, in forward
    decoded_features = self.decoder(features, encoded, gt=None if target is None else gt)
  File "/home/felix/.conda/envs/doctr-dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/felix/Desktop/doctr/doctr/models/recognition/sar/pytorch.py", line 131, in forward
    logits = self.attention_module(_symbol, features, hidden_state, cell_state)
  File "/home/felix/.conda/envs/doctr-dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/felix/Desktop/doctr/doctr/models/recognition/sar/pytorch.py", line 66, in forward
    attn_query = self.state_conv(tile_hidden_state)  # bsz * attn_size * 1 * 1
  File "/home/felix/.conda/envs/doctr-dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/felix/.conda/envs/doctr-dev/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 447, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/felix/.conda/envs/doctr-dev/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 443, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: cuDNN error: CUDNN_STATUS_INTERNAL_ERROR
You can try to repro this exception using the following code snippet. If that doesn't trigger the error, please include your original repro script when reporting this issue.

? :/ So I think your Conv replacement with linear layers is not correct. But most of the changes I have tried end in this *** message :sweat_smile:

EDIT: I have some ugly (but working) test code here: https://github.com/felixdittrich92/doctr/tree/fix This one trains well (decoding without gt is still not fixed), but it also points to the same problems we saw on the TF side (on a whole document it seems to work, but with the RecognitionPredictor the results are still not good)

I have also taken a bit of time to check whether the backbone might be the problem, but there everything seems to be fine
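
A note on the first line of the log: the `srcIndex < srcSelectDimSize` assertion comes from CUDA's index-select kernel and fires when an index is out of range for the dimension being indexed. In a decoder this often points to a token index that is not smaller than the embedding table size. Because CUDA errors are reported asynchronously, the later cuDNN error in `conv2d` may only be where the failure surfaced. The sketch below is one hedged way to check encoded targets on the CPU before the forward pass; the helper name, the batch, and the table size are hypothetical and not part of doctr.

```python
from typing import List, Sequence, Tuple


def find_out_of_range(
    encoded: Sequence[Sequence[int]], num_embeddings: int
) -> List[Tuple[int, int, int]]:
    """Return (sample, position, index) for every index outside [0, num_embeddings)."""
    return [
        (i, j, idx)
        for i, seq in enumerate(encoded)
        for j, idx in enumerate(seq)
        if not 0 <= idx < num_embeddings
    ]


# Hypothetical batch: the embedding table has 5 rows (valid indices 0..4),
# but one target uses index 5, e.g. an extra EOS or padding symbol
# that was not counted when the table was sized.
batch = [[0, 3, 4], [1, 5, 2]]
bad = find_out_of_range(batch, num_embeddings=5)
print(bad)  # [(1, 1, 5)]
```

If such a check comes back empty, the cause lies elsewhere; running the same step on CPU or with `CUDA_LAUNCH_BLOCKING=1` usually gives a more precise error location.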

felixdittrich92 avatar May 16 '22 20:05 felixdittrich92

That does ring a bell but not specifically during SAR training, I already had that TF error in the past. But it's been a while 🤷

frgfm avatar May 22 '22 20:05 frgfm

@frgfm maybe rename to improve training scripts ? :)

felixdittrich92 avatar May 31 '22 08:05 felixdittrich92

Yup! Let me update this :+1:

frgfm avatar Jun 23 '22 15:06 frgfm

@frgfm If you find some spare time, could we finish this? 😄 It seems that @jonathanMindee is currently not interested in providing checkpoints for the models which don't have pretrained ones. So I started to create a "hopefully" qualitatively equivalent dataset. (It will probably take a while since my weeks are pretty full right now.) But to get to the point, this PR would be super helpful before I start training. 👍🏼 If everything goes well, I would also make the dataset freely available.

felixT2K avatar Apr 29 '23 12:04 felixT2K

I think we can close this now otherwise feel free to open again :)

felixdittrich92 avatar Nov 17 '23 21:11 felixdittrich92