SA-Text: Simple but Accurate Detector for Text of Arbitrary Shapes


  • Python 2.7
  • PyTorch v0.4.1+
  • pyclipper
  • Polygon2
  • opencv-python 3.4
  • TVM-0.7dev(Optional)


Regress a Gaussian map to detect text accurately.
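As a rough illustration of what a Gaussian regression target can look like, the sketch below builds a map that peaks at the center of an axis-aligned text box and decays toward its edges. The function name `gaussian_box_map`, the `sigma_ratio` parameter, and the axis-aligned box are assumptions made for this example; the repository may build its targets differently (for instance from arbitrary polygons).

```python
import numpy as np


def gaussian_box_map(height, width, box, sigma_ratio=0.25):
    """Return a (height, width) float32 map with a 2D Gaussian inside `box`.

    box is (x_min, y_min, x_max, y_max) in pixels. Pixels outside the box are 0.
    sigma_ratio is a hypothetical knob: sigma = box size * sigma_ratio.
    """
    x_min, y_min, x_max, y_max = box
    cx = (x_min + x_max) / 2.0
    cy = (y_min + y_max) / 2.0
    sigma_x = max((x_max - x_min) * sigma_ratio, 1e-6)
    sigma_y = max((y_max - y_min) * sigma_ratio, 1e-6)

    ys, xs = np.mgrid[0:height, 0:width]
    gauss = np.exp(-((xs - cx) ** 2 / (2.0 * sigma_x ** 2)
                     + (ys - cy) ** 2 / (2.0 * sigma_y ** 2)))
    inside = (xs >= x_min) & (xs <= x_max) & (ys >= y_min) & (ys <= y_max)
    return (gauss * inside).astype(np.float32)


# One text box in a 32x64 image; the peak value 1.0 sits at the box center.
target = gaussian_box_map(32, 64, box=(10, 8, 50, 24))
```

A network trained against such a target learns high responses near text centers, which is what makes thresholding the map into center regions possible.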


python --arch resnet50 --batch_size 4 --root_dir $data_root_dir  


python --root_dir $data_root_dir  --resume checkpoints/ic15_resnet50_bs_4_ep_xxx/checkpoint.pth.tar  --gpus 1


Training_data: MTWI dataset


TVM Optimized Graph

An example of a TVM-optimized graph is in tvm_optimize. TVM is used to optimize the ONNX-format model.

Generate onnx model

python --root_dir $data_root_dir  --resume checkpoints/ic15_resnet50_bs_4_ep_xxx/checkpoint.pth.tar  --gpus 1 --onnx 1

Optimize with onnx-format model

The exported model uses ONNX opset 9, and TVM's ONNX frontend raises some errors when importing it. You can fix them as follows in file python/tvm/relay/frontend/

For Upsample, scales is not in attr, so a default value for scales is set in the except branch:

class Upsample(OnnxOpConverter):
    """ Operator converter for Upsample (nearest mode).
    """

    @classmethod
    def _impl_v9(cls, inputs, attr, params):
        scales = attr.get('scales')
        if not scales:
            # Here we are going to higher OPSET version.
            assert len(inputs) == 2, "Upsample op take 2 inputs, {} given".format(len(inputs))
            try:
                scales = params[inputs[1].name_hint].asnumpy()
            except Exception:
                # scales is not available as a constant parameter:
                # fall back to a fixed 2x upsampling of H and W.
                scales = [1., 1., 2., 2.]
            inputs = inputs[:1]
        assert len(scales) == 4 and scales[0] == 1.0 and scales[1] == 1.0
        mode = attr.get('mode')
        if mode == b'nearest':
            method = "nearest_neighbor"
        elif mode == b'linear':
            method = "bilinear"
        else:
            raise tvm.error.OpAttributeInvalid(
                'Value {} in attribute "mode" of operator Upsample is not valid.'.format(mode))
        attr = {'scale_h': scales[-2], 'scale_w': scales[-1], 'method': method,
                'layout': 'NCHW', 'align_corners': True}
        return AttrCvt('upsampling')(inputs, attr)

For Slice, starts and ends may be missing from inputs and present only in attr, so the except branches read them from attr:

class Slice(OnnxOpConverter):
    """ Operator converter for Slice.
    """

    @classmethod
    def _common(cls, starts, ends, axes):
        # Expand starts/ends so that every axis up to max(axes) has an entry.
        new_axes = []
        new_starts = []
        new_ends = []
        pop_index = 0
        for i in range(max(axes) + 1):
            if i in axes:
                new_axes.append(i)
                new_starts.append(starts[pop_index])
                new_ends.append(ends[pop_index])
                pop_index += 1
            else:
                # Axis not mentioned by the node: keep its full range.
                new_axes.append(i)
                new_starts.append(0)
                new_ends.append(np.iinfo(np.int32).max)
        return new_starts, new_ends, new_axes

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        if isinstance(attr['starts'], int):
            attr['starts'] = (attr['starts'],)
            attr['ends'] = (attr['ends'],)

        try:
            # Update the starts and ends according to axes if required.
            if isinstance(attr['axes'], int):
                attr['axes'] = (attr['axes'],)
            if (max(attr['axes']) + 1) != len(attr['axes']):
                new_starts, new_ends, new_axes = cls._common(
                    attr['starts'], attr['ends'], attr['axes'])
                attr['axes'] = new_axes
                attr['starts'] = new_starts
                attr['ends'] = new_ends
        except KeyError:
            # No 'axes' attribute: starts and ends already cover every axis.
            pass

        return AttrCvt('strided_slice',
                       transforms={'starts': 'begin',
                                   'ends': 'end'},
                       ignores=['axes'])(inputs, attr)

    @classmethod
    def _impl_v10(cls, inputs, attr, params):
        try:
            starts = params[get_name(inputs[1])].asnumpy()
            ends = params[get_name(inputs[2])].asnumpy()
        except Exception:
            # starts and ends are not given as inputs: read them from attr.
            starts = attr['starts']
            ends = attr['ends']
        # Update the starts and ends according to axes if required.
        if len(inputs) >= 4:
            try:
                axes = params[get_name(inputs[3])].asnumpy()
            except Exception:
                axes = attr['axes']
            # max(axes) + 1 also works when axes is a tuple or list taken from attr.
            if max(axes) + 1 != len(axes):
                new_starts, new_ends, _ = cls._common(
                    starts, ends, axes)
                starts = new_starts
                ends = new_ends
        return _op.strided_slice(inputs[0], begin=starts, end=ends)
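To see what the `_common` helper of the Slice converter is meant to do, here is a standalone sketch of the same axis-expansion logic outside TVM. The function name `expand_slice_axes` is hypothetical, and the sketch assumes `axes` is sorted in ascending order, since starts and ends are consumed in that order.

```python
INT32_MAX = 2 ** 31 - 1  # same value as np.iinfo(np.int32).max


def expand_slice_axes(starts, ends, axes):
    """Give every axis in [0, max(axes)] a (start, end) pair.

    Axes that the ONNX Slice node does not mention get the full range
    [0, INT32_MAX), so strided_slice leaves them untouched.
    """
    new_starts, new_ends, new_axes = [], [], []
    pop_index = 0
    for i in range(max(axes) + 1):
        new_axes.append(i)
        if i in axes:
            new_starts.append(starts[pop_index])
            new_ends.append(ends[pop_index])
            pop_index += 1
        else:
            new_starts.append(0)
            new_ends.append(INT32_MAX)
    return new_starts, new_ends, new_axes


# Slice only axis 2 (for example the height axis of an NCHW tensor) from 1 to 3.
starts, ends, axes = expand_slice_axes([1], [3], [2])
# starts == [0, 0, 1]
# ends   == [2147483647, 2147483647, 3]
# axes   == [0, 1, 2]
```

The expansion is needed because strided_slice takes begin and end values positionally, one per leading axis, while ONNX Slice may name only the axes it actually cuts.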


Before optimizing the graph, start the RPC tracker and RPC server for AutoTVM with the following commands:

 python -m tvm.exec.rpc_tracker --host= --port=9190 
 # start new terminal window
 CUDA_VISIBLE_DEVICES=1 python -m tvm.exec.rpc_server --tracker= --key=p100

For more details about AutoTVM, see the TVM tutorial "Auto-tuning a convolutional network for NVIDIA GPU".

# you can set different target for graph
python tvm_optimize/
# Inference testing
python tvm_optimize/


Optimizing the graph with TensorRT takes three steps:

  • onnx-simplifier: Simplify onnx model
  • Generate engine
  • TensorRT(7.0) Python-API for inference

Simplify onnx model:

pip3 install onnx-simplifier
python -m onnxsim textdetection_satext.onnx textdetection_satext_sim.onnx --input-shape 1,3,1024,1024

Generate engine for onnx model:

python textdetection_satext_sim.onnx textdetection_satext.plan

Inference with engine:

python icpr_dataset/  textdetection_satext.plan

Inference time

Network inference time only, with input shape 512x512:

                     PyTorch   TVM     TensorRT Python API
network inference    20 ms     10 ms   13 ms
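The README does not say how these times were measured. One common way to time only the network forward pass is to discard a few warm-up runs and average the rest, as in the sketch below. The helper name `mean_latency_ms` and its `warmup` and `repeats` defaults are assumptions for illustration, not the repository's benchmark code.

```python
from timeit import default_timer


def mean_latency_ms(run_once, warmup=10, repeats=50):
    """Average wall-clock time of run_once() in milliseconds.

    run_once should execute exactly one forward pass. On a GPU it should also
    wait for the device to finish (for PyTorch, torch.cuda.synchronize()),
    otherwise only the kernel launch is timed.
    """
    for _ in range(warmup):
        run_once()  # warm-up runs are not timed
    start = default_timer()
    for _ in range(repeats):
        run_once()
    return (default_timer() - start) * 1000.0 / repeats


# Stand-in workload so the sketch runs without a model or a GPU.
calls = []
latency = mean_latency_ms(lambda: calls.append(sum(range(1000))),
                          warmup=2, repeats=5)
```

Pre-processing and post-processing are left outside `run_once`, so the result is comparable to a network-only figure.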

Differences from original paper

There are two differences from the paper: the post-processing algorithm and the outputs of the network.

The network has two outputs: a border map and a Gaussian map. The border map is used to separate adjacent text instances, and the Gaussian map is used to generate the text center regions. Because the center regions of two neighboring text instances may touch, the border map is first used to remove the pixels that lie on the border between two instances. The remaining text center regions are then used to generate the text instances, and finally each instance is expanded by dilation in OpenCV.
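The post-processing described above can be sketched with NumPy alone. This is a minimal illustration under assumed thresholds, not the repository's implementation: the function names, `center_thresh`, `border_thresh`, and `grow` are hypothetical, and the hand-written labeling and dilation stand in for OpenCV calls such as cv2.connectedComponents and cv2.dilate.

```python
import numpy as np


def label_components(mask):
    """Label 4-connected regions of a boolean mask with 1, 2, ..."""
    labels = np.zeros(mask.shape, dtype=np.int32)
    h, w = mask.shape
    count = 0
    for y, x in zip(*np.nonzero(mask)):
        if labels[y, x]:
            continue
        count += 1
        labels[y, x] = count
        stack = [(y, x)]
        while stack:
            cy, cx = stack.pop()
            for ny, nx in ((cy - 1, cx), (cy + 1, cx), (cy, cx - 1), (cy, cx + 1)):
                if 0 <= ny < h and 0 <= nx < w and mask[ny, nx] and not labels[ny, nx]:
                    labels[ny, nx] = count
                    stack.append((ny, nx))
    return labels, count


def dilate_labels(labels, iterations=1):
    """Grow every labeled region by one pixel per iteration into background."""
    out = labels.copy()
    h, w = out.shape
    for _ in range(iterations):
        padded = np.pad(out, 1, mode='constant')
        grown = out.copy()
        for dy, dx in ((0, 1), (2, 1), (1, 0), (1, 2)):  # up, down, left, right
            neighbour = padded[dy:dy + h, dx:dx + w]
            fill = (grown == 0) & (neighbour > 0)
            grown[fill] = neighbour[fill]
        out = grown
    return out


def split_text_instances(gaussian_map, border_map,
                         center_thresh=0.5, border_thresh=0.5, grow=1):
    # 1. Center region = confident Gaussian pixels that are not on a border.
    center = (gaussian_map > center_thresh) & (border_map <= border_thresh)
    # 2. Each connected center region becomes one text instance.
    labels, count = label_components(center)
    # 3. Expand the instances back toward their full extent.
    return dilate_labels(labels, grow), count


# Two text lines whose center regions touch; the border map marks column 4.
gaussian = np.zeros((5, 9), dtype=np.float32)
gaussian[1:4, 1:8] = 0.9
border = np.zeros((5, 9), dtype=np.float32)
border[:, 4] = 0.9

instances, n_instances = split_text_instances(gaussian, border)
_, n_without_border = split_text_instances(gaussian, np.zeros_like(border))
# n_instances == 2, n_without_border == 1
```

Without the border map the two touching center regions merge into a single component, which is the failure the border output is meant to prevent.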