chineseocr icon indicating copy to clipboard operation
chineseocr copied to clipboard

如何在你训练的基础上增加训练集?

Open Steverdeng opened this issue 7 years ago • 25 comments

请问一下,我测试发现ⅡⅢⅣⅤⅥⅠ这些字很难识别出来,请问我如何在你的基础上增加训练这些类型的字?

Steverdeng avatar Sep 10 '18 01:09 Steverdeng

可以进行fine-tune,或者进行迁移学习。或者根据应用场景修正,比如你的应该场景只用到“1234567890”,那么在crnn模型输出的概率矩阵环节,应该先过滤后,再进行后续处理。

wenlihaoyu avatar Sep 10 '18 03:09 wenlihaoyu

这几个字ⅡⅢⅣⅤⅥ 在 代码crnn/keys.py是找不到的。就是不知道如何添加训练

Steverdeng avatar Sep 10 '18 03:09 Steverdeng

你训练出来模型汉字是比较好的,就是英文以及一些字符不太好

Steverdeng avatar Sep 10 '18 03:09 Steverdeng

修改模型的输出层,重新训练即可。

wenlihaoyu avatar Sep 12 '18 15:09 wenlihaoyu

@wenlihaoyu 您好,想问一下fine-tune是将你生成的模型作为预训练模型,然后再训练我们自己的数据集吗?另外如果我想制作自己的label,要怎样进行啊?还有你说的crnn模型输出的概率矩阵环节,应该先过滤后,再进行后续处理,不是很理解这句话啊

ZJU-PLP avatar Sep 13 '18 07:09 ZJU-PLP

这个能识别证件照 啥的嘛

xxllp avatar Sep 14 '18 05:09 xxllp

模型定义

import torch.nn as nn
import torch.nn as nn
import torch.nn.parallel

def data_parallel(model, input, ngpu):
    if isinstance(input.data, torch.cuda.FloatTensor) and ngpu > 1:
        output = nn.parallel.data_parallel(model, input, range(ngpu))
    else:
        output = model(input)
    return output

class BidirectionalLSTM(nn.Module):

    def __init__(self, nIn, nHidden, nOut, ngpu):
        super(BidirectionalLSTM, self).__init__()
        self.ngpu = ngpu

        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        recurrent, _ = utils.data_parallel(
            self.rnn, input, self.ngpu)  # [T, b, h * 2]

        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)
        output = utils.data_parallel(
            self.embedding, t_rec, self.ngpu)  # [T * b, nOut]
        output = output.view(T, b, -1)

        return output


class CRNN(nn.Module):

    def __init__(self, imgH, nc, nclass, nh, ngpu, n_rnn=2, leakyRelu=False):
        super(CRNN, self).__init__()
        self.ngpu = ngpu
        assert imgH % 16 == 0, 'imgH has to be a multiple of 16'

        ks = [3, 3, 3, 3, 3, 3, 2]
        ps = [1, 1, 1, 1, 1, 1, 0]
        ss = [1, 1, 1, 1, 1, 1, 1]
        nm = [64, 128, 256, 256, 512, 512, 512]

        cnn = nn.Sequential()

        def convRelu(i, batchNormalization=False):
            nIn = nc if i == 0 else nm[i - 1]
            nOut = nm[i]
            cnn.add_module('conv{0}'.format(i),
                           nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
            if batchNormalization:
                cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
            if leakyRelu:
                cnn.add_module('relu{0}'.format(i),
                               nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module('relu{0}'.format(i), nn.ReLU(True))

        convRelu(0)
        cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))  # 64x16x64
        convRelu(1)
        cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))  # 128x8x32
        convRelu(2, True)
        convRelu(3)
        cnn.add_module('pooling{0}'.format(2), nn.MaxPool2d((2, 2),
                                                            (2, 1),
                                                            (0, 1)))  # 256x4x16
        convRelu(4, True)
        convRelu(5)
        cnn.add_module('pooling{0}'.format(3), nn.MaxPool2d((2, 2),
                                                            (2, 1),
                                                            (0, 1)))  # 512x2x16
        convRelu(6, True)  # 512x1x16

        self.cnn = cnn
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh, ngpu),
            BidirectionalLSTM(nh, nh, nclass, ngpu)
        )

    def forward(self, input):
        # conv features
        conv = data_parallel(self.cnn, input, self.ngpu)
        b, c, h, w = conv.size()
        assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2)
        conv = conv.permute(2, 0, 1)  # [w, b, c]

        # rnn features
        output = utils.data_parallel(self.rnn, conv, self.ngpu)

        return output
def pre_model(nclass,ocrModelPath):
    #@@parm nclass:字符总数
    #@@预训练模型文件
    
    if torch.cuda.is_available() and GPU:
       model = CRNN(32, 1, nclass+1, 256, 1).cuda()
    else:
        model = CRNN(32, 1, nclass+1, 256, 1).cpu()

    state_dict = torch.load(ocrModelPath,map_location=lambda storage, loc: storage)
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k.replace('module.','') # remove `module.`
        new_state_dict[name] = v

    model.load_state_dict(new_state_dict)
    model.eval()
    
    return model

def new_model(nclass,preModel):
    
    #定义你自己的模型
    
    if torch.cuda.is_available() and GPU:
       model = CRNN(32, 1, nclass+1, 256, 1).cuda()
    else:
        model = CRNN(32, 1, nclass+1, 256, 1).cpu()
        
    
    modelDict = model.state_dict()##
    preModelDict = preModel.state_dict()##
    preModelDict = {k: v for k, v in preModelDict.items() if  'rnn.1' not in k }
    modelDict.update(preModelDict)##更新权重
    model.load_state_dict(modelDict)##加载预训练模型权重
    return model

nclass =5530
ocrModelPath = 'ocr.pth'
model =pre_model(nclass,ocrModelPath)
##定义你自己的模型
nclass=10##字符集大小
newmodel = new_model(10,model) 

wenlihaoyu avatar Sep 15 '18 03:09 wenlihaoyu

非常感谢,难得作者这么耐心!

ZJU-PLP avatar Sep 15 '18 03:09 ZJU-PLP

如果你的模型字符集和本项目的不一致,@Aurora11111 @ZJU-PLP 根据上面的代码,调整模型进行训练即可。具体训练可以参考crnn.pytorch项目https://github.com/meijieru/crnn.pytorch.git

wenlihaoyu avatar Sep 15 '18 03:09 wenlihaoyu

是不是根据你的代码,我只需要训练我自己那部分字符就可以了?

Steverdeng avatar Sep 20 '18 01:09 Steverdeng

@Steverdeng 对

wenlihaoyu avatar Sep 22 '18 13:09 wenlihaoyu

这个模型手写体识别支持吗?能否增加手写体数据训练?

beimingmaster avatar Sep 28 '18 12:09 beimingmaster

@beimingmaster 部分手写体可以识别,我测试过了,不知道你的应用场景是怎样的,可以跑作者的模型试一下

ZJU-PLP avatar Sep 28 '18 14:09 ZJU-PLP

@beimingmaster 部分手写体可以识别,我测试过了,不知道你的应用场景是怎样的,可以跑作者的模型试一下

谢谢,手写体我看网上有中科院的数据,我打算加点训练数据试试。 看了下前面的issue,好像没有一个明确说明如何增加额外训练数据的描述文档。

beimingmaster avatar Sep 29 '18 06:09 beimingmaster

@Steverdeng @wenlihaoyu @ZJU-PLP 请问下是不是换了crnn.py成你上面的代码;然后keys.py是在你原有的字符追加那部分字符,还是只要写那部分字符;这样训练是不是会在你之前训练的模型上增加新的字符训练。谢谢

little2Rabbit avatar Sep 29 '18 11:09 little2Rabbit

请问:我把作者训练好的模型作为预训练模型,然后在自己的两百多万张数据集训练这样会导致过拟合吗,加了early stopping; 另外请教一个基础的问题,如果我先训练了5个epoch,得到一个模型,在用这个模型作为预训练模型再训练5个epoch,和直接训练10个epoch结果有差别很大吗?

xiaosi2017 avatar Oct 11 '18 02:10 xiaosi2017

自己应该怎么练自己的数据集,我已经拥有了LMDB的数据。

OKhyc avatar Dec 25 '18 11:12 OKhyc

@OKhyc reference:https://github.com/Aurora11111/crnn-train-pytorch

Aurora11111 avatar Dec 26 '18 02:12 Aurora11111

@Aurora11111 老兄 这个怎么训练,最后生成的是一个pth文件吗,是可以直接使用的吗

OKhyc avatar Dec 29 '18 08:12 OKhyc

@OKhyc 是的啊

Aurora11111 avatar Jan 09 '19 03:01 Aurora11111

如何配合这个使用呢! https://github.com/chineseocr/ocr-label

bournes avatar Jul 08 '19 10:07 bournes

@Aurora11111 你上面提供crnn非常好,我想问下,该如何训练文字检测呢?

zlszhonglongshen avatar Nov 01 '19 08:11 zlszhonglongshen

reference:https://github.com/Aurora11111/text-detection-train-yolov3

在 2019-11-01 16:16:16,"zlszhonglongshen" [email protected] 写道:

@Aurora11111 你上面提供crnn非常好,我想问下,该如何训练文字检测呢?

— You are receiving this because you were mentioned. Reply to this email directly, view it on GitHub, or unsubscribe.

Aurora11111 avatar Nov 21 '19 08:11 Aurora11111

@wenlihaoyu 我是想在您原有的pytorch ocr识别模型ocr-lstm.pth基础上,再增加自己的训练集,自己准备好了行图片与行文本,训练代码参考只有https://github.com/Aurora11111/crnn-train-pytorch吗?里面有不少坑,比如不用warp_ctc_pytorch,直接用新版本的torch.nn.CTCLoss 函数可不可以。有没有可参考的最新的训练代码,直接载入原来的模型,然后加上自己的数据集行图片与行文本,生成新的模型文件。

DominicTerry avatar Apr 28 '20 02:04 DominicTerry

@beimingmaster 部分手写体可以识别,我测试过了,不知道你的应用场景是怎样的,可以跑作者的模型试一下

谢谢,手写体我看网上有中科院的数据,我打算加点训练数据试试。 看了下前面的issue,好像没有一个明确说明如何增加额外训练数据的描述文档。

请问使用中科院的手写数据后,对手写支持的怎么样呢

Zhang-O avatar Jul 02 '20 09:07 Zhang-O