DB-tf
A TensorFlow implementation of "Real-time Scene Text Detection with Differentiable Binarization"
DB: Real-time Scene Text Detection with Differentiable Binarization
Introduction
This is a TensorFlow implementation of "Real-time Scene Text Detection with Differentiable Binarization".
Part of the code is inherited from DB.
To-do list
- [x] Release trained models
- [x] Training code
- [x] Inference code
- [x] Multi-GPU training
- [x] TensorBoard support
- [x] Experiments with other training losses
- [ ] Evaluation code
- [x] Data augmentation (random cropping and random image augmentation)
- [x] More backbones
- [x] Dilated convolution (ASPP layer)
- [ ] Deformable Convolutional Networks
Install
```bash
pip install -r requirements.txt
```
Test
1. Download a trained model.
Model | Download link |
---|---|
ResNet-50 | BaiduYun, GoogleDrive |
ResNet-50-ASPP | BaiduYun, GoogleDrive |
2. Configure the network.
Edit db_config.py:
```python
cfg.BACKBONE = 'resnet_v1_50'
# Set to True only if the trained model's name contains "aspp"; otherwise keep False.
cfg.ASPP_LAYER = False
```
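Because `cfg.ASPP_LAYER` has to match how the checkpoint was trained, one option is to derive it from the checkpoint file name. The helper below is hypothetical (it is not part of this repo) and assumes, as the config comment implies, that ASPP checkpoints carry "aspp" in their file name.

```python
import os


def aspp_flag_from_ckpt(ckpt_path: str) -> bool:
    """Hypothetical helper: guess cfg.ASPP_LAYER from a checkpoint file name.

    Assumes ASPP checkpoints contain "aspp" (any case) in their base name.
    """
    name = os.path.basename(ckpt_path.rstrip("/\\"))
    return "aspp" in name.lower()


# Example paths are made up for illustration.
print(aspp_flag_from_ckpt("/models/db_resnet50_aspp.ckpt"))  # True
print(aspp_flag_from_ckpt("/models/db_resnet50.ckpt"))       # False
```

In db_config.py this would become something like `cfg.ASPP_LAYER = aspp_flag_from_ckpt(ckpt_path)`; setting the flag by hand works just as well.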
3. Run inference on an image.
```bash
python inference.py --gpuid='0' --ckptpath='path' --imgpath='img.jpg'
```
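To run the same command over a folder of images, a small Python wrapper can build the argument list and call the script once per file. This is a hypothetical sketch: `build_inference_cmd` and `run_on_folder` are illustrative names, it uses only the three flags shown in the command above, and it assumes it is run from the repository root where inference.py lives.

```python
import shlex
import subprocess
import sys
from pathlib import Path


def build_inference_cmd(gpuid: str, ckpt_path: str, img_path: str) -> list:
    """Argument list equivalent to the inference.py command above."""
    return [
        sys.executable, "inference.py",
        f"--gpuid={gpuid}",
        f"--ckptpath={ckpt_path}",
        f"--imgpath={img_path}",
    ]


def run_on_folder(gpuid: str, ckpt_path: str, img_dir: str, pattern: str = "*.jpg") -> None:
    """Hypothetical batch helper: run inference.py once per matching image.

    Assumes it is called from the repository root, where inference.py lives.
    """
    for img in sorted(Path(img_dir).glob(pattern)):
        cmd = build_inference_cmd(gpuid, ckpt_path, str(img))
        print("running:", shlex.join(cmd))
        subprocess.run(cmd, check=True)


# Example (paths are placeholders):
# run_on_folder("0", "path", "/data/test_images")
```

Launching one process per image reloads the model each time, which is simple but slow; for large image sets it is cheaper to adapt inference.py to loop over images internally.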
Sample results
(Images not available: original image (org show), polygon results (poly show), bounding-box results (bbox show), binarize_map, threshold_map, thresh_binary.)
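Judging by their names, the three map panels are the probability map P (binarize_map), the threshold map T (threshold_map) and the approximate binary map (thresh_binary). In the DB paper the approximate binary map is computed with a steep sigmoid, `B_hat = 1 / (1 + exp(-k * (P - T)))`, with the amplifying factor k set to 50. The NumPy sketch below illustrates that formula only; it is not taken from this repo's code, which may differ in details.

```python
import numpy as np


def approximate_binary_map(prob_map, thresh_map, k=50.0):
    """Differentiable binarization as written in the DB paper.

    B_hat = 1 / (1 + exp(-k * (P - T))), where P is the probability map,
    T is the threshold map, and k is the amplifying factor (50 in the paper).
    """
    p = np.asarray(prob_map, dtype=np.float64)
    t = np.asarray(thresh_map, dtype=np.float64)
    return 1.0 / (1.0 + np.exp(-k * (p - t)))


# Toy 2x2 maps, made up for illustration.
P = np.array([[0.90, 0.20], [0.55, 0.50]])
T = np.array([[0.30, 0.30], [0.50, 0.50]])
print(np.round(approximate_binary_map(P, T), 3))
```

Pixels well above their threshold map to about 1, pixels below it go toward 0, and a pixel exactly at its threshold maps to 0.5. Because the sigmoid is smooth, gradients reach both P and T during training; the paper notes that at inference either the probability map or the approximate binary map can be used to produce boxes.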
Dataset
The models in this repo are trained on the CTW1500 dataset. Download it from BaiduYun (key: yjiz) or OneDrive.
Training
1. Set the CTW1500 image and label paths.
Edit db_config.py:
```python
# Training data
cfg.TRAIN.IMG_DIR = '/path/ctw1500/train/text_image'
cfg.TRAIN.LABEL_DIR = '/path/ctw1500/train/text_label_curve'
# Validation / test data
cfg.EVAL.IMG_DIR = '/path/ctw1500/test/text_image'
cfg.EVAL.LABEL_DIR = '/path/ctw1500/test/text_label_circum'
```
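Assuming the training labels in text_label_curve follow the standard CTW1500 curved-text layout (32 comma-separated integers per text instance: a bounding box `xmin, ymin, xmax, ymax` followed by 14 `(dx, dy)` offsets relative to `(xmin, ymin)`), one label line can be turned into an absolute 14-point polygon as sketched below. The function name is hypothetical, and the text_label_circum test labels are not covered because their layout may differ.

```python
def parse_ctw1500_curve_line(line: str) -> list:
    """Hypothetical parser for one CTW1500 curved-text label line.

    Assumed layout: 32 comma-separated integers. Values 0-3 are the bounding
    box (xmin, ymin, xmax, ymax); values 4-31 are 14 (dx, dy) offsets that are
    relative to (xmin, ymin). Returns 14 absolute (x, y) polygon vertices.
    """
    values = [int(v) for v in line.strip().split(",") if v.strip()]
    if len(values) != 32:
        raise ValueError(f"expected 32 values, got {len(values)}")
    xmin, ymin = values[0], values[1]
    offsets = values[4:]
    return [(xmin + offsets[i], ymin + offsets[i + 1]) for i in range(0, 28, 2)]


# Made-up example line: box (10, 20, 100, 200) and offsets (k, 2k) for k = 0..13.
example = ",".join(str(v) for v in [10, 20, 100, 200] + [n for k in range(14) for n in (k, 2 * k)])
polygon = parse_ctw1500_curve_line(example)
print(len(polygon), polygon[0], polygon[-1])  # 14 (10, 20) (23, 46)
```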
2. Configure the network and multi-GPU training.
Edit db_config.py:
```python
# Supported backbones: 'resnet_v1_50' and 'resnet_v1_18'
cfg.BACKBONE = 'resnet_v1_50'
# Set to True to train the ASPP variant
cfg.ASPP_LAYER = False
cfg.TRAIN.VIS_GPU = '5,6'  # GPU ids to use; for a single GPU use '0'
```
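`cfg.TRAIN.VIS_GPU` is a comma-separated list of GPU ids. A common way to honor such a setting with TensorFlow is to set `CUDA_VISIBLE_DEVICES` before TensorFlow initializes CUDA and to build one model replica per visible GPU. The helper below is a hypothetical sketch of that idea, not this repo's train.py.

```python
import os


def setup_visible_gpus(vis_gpu: str) -> int:
    """Hypothetical helper: expose only the GPUs listed in VIS_GPU.

    Must run before TensorFlow initializes CUDA. Returns the number of visible
    GPUs, e.g. the number of model replicas (towers) a multi-GPU loop would build.
    """
    gpu_ids = [g.strip() for g in vis_gpu.split(",") if g.strip()]
    if not gpu_ids:
        raise ValueError("VIS_GPU must list at least one GPU id, e.g. '0'")
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(gpu_ids)
    return len(gpu_ids)


num_gpus = setup_visible_gpus("5,6")
print(num_gpus, os.environ["CUDA_VISIBLE_DEVICES"])  # 2 5,6
```

With '5,6', physical GPUs 5 and 6 appear to the process as `/gpu:0` and `/gpu:1`, so device strings inside the training graph stay zero-based whichever physical cards are selected.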
3. Set where training logs and checkpoints are saved.
Edit db_config.py:
```python
cfg.TRAIN.TRAIN_LOGS = '/path/tf_logs'
cfg.TRAIN.CHECKPOINTS_OUTPUT_DIR = '/path/ckpt'
```
4. Start from a pretrained model or restore a checkpoint.
To initialize from pretrained weights, edit db_config.py:
```python
cfg.TRAIN.RESTORE = False
cfg.TRAIN.PRETRAINED_MODEL_PATH = 'pretrain model path'
```
To resume training from a saved checkpoint, edit db_config.py:
```python
cfg.TRAIN.RESTORE = True
cfg.TRAIN.RESTORE_CKPT_PATH = 'checkpoint path'
```
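The two options are selected by `cfg.TRAIN.RESTORE`. The sketch below only illustrates which path each setting points to; `select_init_checkpoint` is a hypothetical name, `SimpleNamespace` stands in for `cfg.TRAIN`, and how train.py actually loads the weights is not shown.

```python
from types import SimpleNamespace


def select_init_checkpoint(train_cfg):
    """Hypothetical helper: report which weights training would start from.

    RESTORE=True  -> resume from RESTORE_CKPT_PATH
    RESTORE=False -> initialize from PRETRAINED_MODEL_PATH
    """
    if train_cfg.RESTORE:
        return "restore", train_cfg.RESTORE_CKPT_PATH
    return "pretrain", train_cfg.PRETRAINED_MODEL_PATH


# Stand-in for cfg.TRAIN; attribute names come from db_config.py as described above.
train_cfg = SimpleNamespace(
    RESTORE=False,
    PRETRAINED_MODEL_PATH="pretrain model path",
    RESTORE_CKPT_PATH="checkpoint path",
)
print(select_init_checkpoint(train_cfg))  # ('pretrain', 'pretrain model path')
```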
5. Start training.
```bash
python train.py
```
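6. Monitor training with TensorBoard.
```bash
cd 'tensorboard path'  # e.g. the directory set in cfg.TRAIN.TRAIN_LOGS
tensorboard --logdir=./
```
In the plots, red curves are training logs and blue curves are validation logs.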
Loss curves
(Images not available: binarize loss, threshold loss, threshold binary loss, model_loss, total_loss.)
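The loss curves track the terms of the DB training objective. In the DB paper the objective is a weighted sum of the probability-map loss L_s, the approximate-binary-map loss L_b and the threshold-map loss L_t, with α = 1.0 and β = 10. L_s and L_b are binary cross-entropy losses over a hard-negative-mined pixel set S_l, and L_t is an L1 loss over the pixels R_d inside the dilated text polygons, where x_i^* and y_i^* are the predicted and target threshold values. From the names alone, we assume that binarize loss, threshold binary loss and threshold loss correspond to L_s, L_b and L_t, that model_loss is their weighted sum, and that total_loss adds regularization on top; the weights used in this repo may differ from the paper's.

$$
L = L_s + \alpha L_b + \beta L_t, \qquad \alpha = 1.0,\; \beta = 10
$$

$$
L_{\mathrm{BCE}} = -\sum_{i \in S_l} \left[ y_i \log x_i + (1 - y_i) \log (1 - x_i) \right] \quad \text{(used for both } L_s \text{ and } L_b\text{)}
$$

$$
L_t = \sum_{i \in R_d} \left| y_i^{*} - x_i^{*} \right|
$$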
Accuracy curves
(Images not available: binarize acc, threshold binary acc.)
Experiments
Tested on an RTX 2080 Ti.
Backbone | ASPP | Input Size | Inference Time (ms) | Post-processing Time (ms) | FPS |
---|---|---|---|---|---|
ResNet-50 | × | 320 | 13.3 | 2.9 | 61.7 |
ResNet-50 | × | 512 | 19.2 | 4.5 | 42.2 |
ResNet-50 | × | 640 | 28.9 | 5.2 | 29.3 |
ResNet-50 | × | 736 | 33.2 | 5.7 | 25.7 |
ResNet-18 | × | 320 | 12.2 | 2.9 | 66.2 |
ResNet-18 | × | 512 | 16.9 | 4.5 | 46.7 |
ResNet-18 | × | 736 | 32.7 | 5.7 | 26 |
ResNet-50 | √ | 640 | 32.6 | --- | --- |
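The FPS column appears to be derived from the two timing columns as `FPS = 1000 / (inference time + post-processing time)`; for the first row, 1000 / (13.3 + 2.9) ≈ 61.7. The short script below recomputes FPS for every row that reports a post-processing time and prints it next to the reported value. The ASPP row has no post-processing time, so it is left out.

```python
# (backbone, input size, inference ms, post-processing ms, reported FPS), copied from the table.
ROWS = [
    ("ResNet-50", 320, 13.3, 2.9, 61.7),
    ("ResNet-50", 512, 19.2, 4.5, 42.2),
    ("ResNet-50", 640, 28.9, 5.2, 29.3),
    ("ResNet-50", 736, 33.2, 5.7, 25.7),
    ("ResNet-18", 320, 12.2, 2.9, 66.2),
    ("ResNet-18", 512, 16.9, 4.5, 46.7),
    ("ResNet-18", 736, 32.7, 5.7, 26.0),
]


def fps(inference_ms: float, postprocess_ms: float) -> float:
    """Frames per second when each frame costs inference + post-processing time."""
    return 1000.0 / (inference_ms + postprocess_ms)


for backbone, size, inf_ms, post_ms, reported in ROWS:
    print(f"{backbone} @ {size}: computed {fps(inf_ms, post_ms):.1f} FPS, reported {reported}")
```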