Error in the CRF module when converting a .pth model to a .pt (TorchScript) model
It looks like using the CRF module runs into a bug. Could someone take a look and suggest how to fix it? Thanks!
Describe the bug
- 1. Train a model with the bert+crf module and save it as a .pth file:
torch.save(model, model_path)
- 2. Attempt to convert the model to a .pt (TorchScript) file; the conversion fails in the CRF module:
model = torch.load("{}/{}".format(model_path, model_file))
traced_script_module = torch.jit.trace(model, args)
traced_script_module.save("{}/test.pt".format(model_path))
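Putting the two steps together, the export code is roughly the following sketch (model_path, model_file, and the example inputs in args are placeholders here; the real args must match the model's forward() signature):

import torch

# Hypothetical example inputs for tracing -- adjust the shapes/dtypes
# to whatever the model's forward() actually takes.
example_ids = torch.randint(0, 21128, (1, 128), dtype=torch.long)
example_mask = torch.ones(1, 128, dtype=torch.long)
args = (example_ids, example_mask)

model = torch.load("{}/{}".format(model_path, model_file))
model.eval()  # disable dropout etc. before tracing
traced_script_module = torch.jit.trace(model, args)
traced_script_module.save("{}/test.pt".format(model_path))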
Error message:
Traceback (most recent call last):
File "D:/deeplearning-NLP/flat_lattice_transformer/V1/exporter.py", line 154, in <module>
export_torchscript(model_path=model_path, model_file=model_file)
File "D:/deeplearning-NLP/flat_lattice_transformer/V1/exporter.py", line 125, in export_torchscript
traced_script_module = torch.jit.trace(model, args)
File "C:\ProgramData\Anaconda3\envs\flat_lattice_transformer\lib\site-packages\torch\jit\__init__.py", line 875, in trace
check_tolerance, _force_outplace, _module_class)
File "C:\ProgramData\Anaconda3\envs\flat_lattice_transformer\lib\site-packages\torch\jit\__init__.py", line 1037, in trace_module
check_tolerance, _force_outplace, True, _module_class)
File "C:\ProgramData\Anaconda3\envs\flat_lattice_transformer\lib\site-packages\torch\autograd\grad_mode.py", line 15, in decorate_context
return func(*args, **kwargs)
File "C:\ProgramData\Anaconda3\envs\flat_lattice_transformer\lib\site-packages\torch\jit\__init__.py", line 675, in _check_trace
raise TracingCheckError(*diag_info)
torch.jit.TracingCheckError: Tracing failed sanity checks!
ERROR: Tensor-valued Constant nodes differed in value across invocations. This often indicates that the tracer has encountered untraceable code.
Node:
%393 : Tensor = prim::Constant[value=<Tensor>]() # C:\ProgramData\Anaconda3\envs\flat_lattice_transformer\lib\site-packages\fastNLP\modules\decoder\crf.py:295:0
Source Location:
C:\ProgramData\Anaconda3\envs\flat_lattice_transformer\lib\site-packages\fastNLP\modules\decoder\crf.py(295): viterbi_decode
D:\deeplearning-NLP\flat_lattice_transformer\V1\models.py(524): forward
C:\ProgramData\Anaconda3\envs\flat_lattice_transformer\lib\site-packages\torch\nn\modules\module.py(534): _slow_forward
C:\ProgramData\Anaconda3\envs\flat_lattice_transformer\lib\site-packages\torch\nn\modules\module.py(548): __call__
C:\ProgramData\Anaconda3\envs\flat_lattice_transformer\lib\site-packages\torch\jit\__init__.py(1027): trace_module
C:\ProgramData\Anaconda3\envs\flat_lattice_transformer\lib\site-packages\torch\jit\__init__.py(875): trace
D:/deeplearning-NLP/flat_lattice_transformer/V1/exporter.py(125): export_torchscript
D:/deeplearning-NLP/flat_lattice_transformer/V1/exporter.py(154): <module>
Comparison exception: Subtraction, the `-` operator, with a bool tensor is not supported. If you are trying to invert a mask, use the `~` or `logical_not()` operator instead.
Which version of fastNLP are you using? Is it the one from GitHub, or did you install it directly with pip install fastNLP?
Judging from the error message, I suspect the problem is this line: https://github.com/fastnlp/fastNLP/blob/b127963f213226dc796720193965d86dface07d5/fastNLP/modules/decoder/crf.py#L307
Changing it to
flip_mask = torch.logical_not(mask)
should fix it. The root cause looks similar to https://github.com/pytorch/pytorch/issues/33692: bool tensors do not support certain operations, so when TorchScript checks the trace by comparing tensors before and after conversion, the results fail to match.
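The bool-tensor limitation is easy to reproduce on its own, independent of fastNLP (a minimal sketch):

import torch

mask = torch.tensor([True, False, True])

# `1 - mask` was a common idiom for inverting a 0/1 mask, but on a bool
# tensor recent PyTorch versions raise:
#   RuntimeError: Subtraction, the `-` operator, with a bool tensor is
#   not supported. If you are trying to invert a mask, use the `~` or
#   `logical_not()` operator instead.
try:
    flipped = 1 - mask
except RuntimeError as e:
    print(e)

# Either of these works on a bool tensor:
print(~mask)                    # tensor([False,  True, False])
print(torch.logical_not(mask))  # tensor([False,  True, False])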
Thanks for the reply.
I'm on version 0.5.0, installed directly with pip install fastNLP.
The failing line should be this one: https://github.com/fastnlp/fastNLP/blob/ba0269f23e72e446ddfdbda32edccfb694b76b4f/fastNLP/modules/decoder/crf.py#L295
I see. Then changing that line to mask = mask.transpose(0, 1).data.to(torch.bool) should do it: the point is to keep bool tensors out of any comparisons or arithmetic.
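To be concrete, assuming the original line builds the mask with a comparison such as .eq(True), both forms yield the same values; the difference is that the dtype cast avoids recording a comparison op in the traced graph (a small sketch with a made-up mask):

import torch

mask = torch.tensor([[1, 1, 0], [1, 0, 0]])

via_eq = mask.transpose(0, 1).data.eq(True)          # comparison op
via_cast = mask.transpose(0, 1).data.to(torch.bool)  # plain dtype cast
assert torch.equal(via_eq, via_cast)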
That doesn't seem to work; it still fails with the same error.
In addition, the following line also raises an error:
https://github.com/fastnlp/fastNLP/blob/ba0269f23e72e446ddfdbda32edccfb694b76b4f/fastNLP/embeddings/embedding.py#L208
TypeError: tuple expected at most 1 arguments, got 2
I changed it to
return torch.Size((self.num_embedding, self._embed_size))
and that fixed it.
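(For reference, torch.Size subclasses tuple, so its constructor takes a single iterable rather than separate dimension arguments; a quick sketch:)

import torch

print(torch.Size((10, 128)))  # OK: torch.Size([10, 128])
try:
    torch.Size(10, 128)       # passes two arguments to tuple(...)
except TypeError as e:
    print(e)                  # tuple expected at most 1 argument(s), got 2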
After commenting out that line, I found two more places further on with the same problem; I stopped commenting things out at that point:
https://github.com/fastnlp/fastNLP/blob/ba0269f23e72e446ddfdbda32edccfb694b76b4f/fastNLP/modules/decoder/crf.py#L314-L315
https://github.com/fastnlp/fastNLP/blob/ba0269f23e72e446ddfdbda32edccfb694b76b4f/fastNLP/modules/decoder/crf.py#L323
Indeed. fastNLP's code probably isn't well suited to JIT conversion: there is too much Python-level branching and too many CPU-side operations. You will likely need to adapt the code yourself, using these locations as a guide. With JIT tracing, constants seem to be the problem, yet fastNLP makes heavy use of constants to represent various numbers.
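As a small illustration of why tracing chokes on this kind of code (a generic sketch, not fastNLP itself): any Python-level branch or CPU-side value is frozen into the trace for the particular example input:

import torch

class Frozen(torch.nn.Module):
    def forward(self, x):
        # Python-level branch: the tracer only records the path taken
        # for the example input; the condition itself is not in the graph.
        if x.sum() > 0:
            return x + 1
        return x - 1

traced = torch.jit.trace(Frozen(), torch.ones(3))  # emits a TracerWarning
print(traced(-torch.ones(3)))  # tensor([0., 0., 0.]) -- still takes the "+1" branch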