GPLinker_pytorch
GPLinker_pytorch copied to clipboard
GPLinker_pytorch
GPLinker_pytorch
GPLinker_pytorch
介绍
这是pytorch版本的GPLinker代码以及TPLinker_Plus代码。
GPLinker主要参考了苏神博客和他的keras版本代码TPLinker_Plus主要参考了原版代码- 其中
TPLinker_Plus代码在模型部分可能有点区别。
更新
- 2022/03/03 添加
tplinker_plus+bert-base-chinese权重在duie_v1上的结果。添加duee_v1任务的训练代码,请查看duee_v1目录。 - 2022/03/01 添加
tplinker_plus+hfl/chinese-roberta-wwm-ext权重在duie_v1上的结果。 - 2022/02/25 现已在Dev分支更新最新的huggingface全家桶版本的代码,main分支是之前旧的代码(执行效率慢)
结果
Tips: 在RTX3090,20epoch的条件下,gplinker需要训练5-6h,tplinker_plus则需要训练16-17h。
| dataset | method | pretrained_model_name_or_path | f1 | precision | recall |
|---|---|---|---|---|---|
| duie_v1 | gplinker | hfl/chinese-roberta-wwm-ext | 0.8214065255731926 | 0.8250077498782166 | 0.8178366038895478 |
| duie_v1 | gplinker | bert-base-chinese | 0.8198087178424598 | 0.8146470447994109 | 0.8250362175688137 |
| duie_v1 | tplinker_plus | hfl/chinese-roberta-wwm-ext | 0.8256425523469291 | 0.8295114656031908 | 0.8218095614381671 |
| duie_v1 | tplinker_plus | bert-base-chinese | 0.8216261688290682 | 0.8076458240569943 | 0.8360990385881737 |
Tensorboard日志
gplinker训练日志
tplinker_plus训练日志
依赖
所需的依赖如下:
- fastcore==1.3.29
- datasets==1.18.3
- transformers>=4.16.2
- accelerate==0.5.1
- chinesebert==0.2.1
安装依赖requirements.txt
pip install -r requirements.txt
准备数据
从 http://ai.baidu.com/broad/download?dataset=sked 下载数据。
将train_data.json和dev_data.json压缩成spo.zip文件,并且放入data文件夹。
当前data/spo.zip文件是本人提供精简后的数据集,其中train_data.json只有2000条数据,dev_data.json只有200条数据。
运行
accelerate launch train.py \
--model_type bert \
--pretrained_model_name_or_path bert-base-chinese \
--method gplinker \
--logging_steps 200 \
--num_train_epochs 20 \
--learning_rate 3e-5 \
--num_warmup_steps_or_radios 0.1 \
--gradient_accumulation_steps 1 \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 32 \
--seed 42 \
--save_steps 10804 \
--output_dir ./outputs \
--max_length 128 \
--topk 1 \
--num_workers 6
其中使用到参数介绍如下:
model_type: 表示模型架构类型,像bert-base-chinese、hfl/chinese-roberta-wwm-ext模型都是基于bert架构,junnyu/roformer_chinese_char_base是基于roformer架构,可选择["bert", "roformer", "chinesebert"]。pretrained_model_name_or_path: 表示加载的预训练模型权重,可以是本地目录,也可以是huggingface.co的路径。method: 表示使用的方法, 可选择["gplinker", "tplinker_plus"]logging_steps: 日志打印的间隔,默认为200。num_train_epochs: 训练轮数,默认为20。learning_rate: 学习率,默认为3e-5。num_warmup_steps_or_radios:warmup步数或者比率,当为浮点类型时候表示的是radio,当为整型时候表示的是step,默认为0.1。gradient_accumulation_steps: 梯度累计的步数,默认为1。per_device_train_batch_size: 训练的batch_size,默认为16。per_device_eval_batch_size: 评估的batch_size,默认为32。seed: 随机种子,以便于复现,默认为42。save_steps: 保存步数,每隔多少步保存模型。output_dir: 模型输出路径。max_length: 句子的最大长度,当大于这个长度时候,tokenizer会进行截断处理。topk: 保存topk个数模型,默认为1。num_workers:dataloader的num_workers参数,linux系统下发现GPU使用率不高的时候可以尝试设置这个参数大于0,而windows下最好设置为0,不然会报错。use_efficient: 是否使用EfficientGlobalPointer,默认为False。
Reference
- 苏剑林. (Jan. 30, 2022). 《GPLinker:基于GlobalPointer的实体关系联合抽取 》[Blog post]. Retrieved from https://kexue.fm/archives/8888
- https://github.com/bojone/GPLinker
- https://github.com/bojone/bert4keras/tree/master/examples/task_relation_extraction_gplinker.py
- https://github.com/131250208/TPlinker-joint-extraction/tree/master/tplinker_plus