
Error in the crf module when converting a pth model to a pt (TorchScript) model

wanilyer opened this issue 4 years ago • 7 comments

This looks like a bug triggered by the crf module. Could someone please take a look at how to solve it? Thanks!

Describe the bug

  • 1. Trained a bert+crf model and saved it as a pth file:

torch.save(model, model_path)

  • 2. Attempting to convert that model to a pt (TorchScript) model fails in the crf module:
model = torch.load("{}/{}".format(model_path, model_file))
traced_script_module = torch.jit.trace(model, args)
traced_script_module.save("{}/test.pt".format(model_path))

Error message:

Traceback (most recent call last):
  File "D:/deeplearning-NLP/flat_lattice_transformer/V1/exporter.py", line 154, in <module>
    export_torchscript(model_path=model_path, model_file=model_file)
  File "D:/deeplearning-NLP/flat_lattice_transformer/V1/exporter.py", line 125, in export_torchscript
    traced_script_module = torch.jit.trace(model, args)
  File "C:\ProgramData\Anaconda3\envs\flat_lattice_transformer\lib\site-packages\torch\jit\__init__.py", line 875, in trace
    check_tolerance, _force_outplace, _module_class)
  File "C:\ProgramData\Anaconda3\envs\flat_lattice_transformer\lib\site-packages\torch\jit\__init__.py", line 1037, in trace_module
    check_tolerance, _force_outplace, True, _module_class)
  File "C:\ProgramData\Anaconda3\envs\flat_lattice_transformer\lib\site-packages\torch\autograd\grad_mode.py", line 15, in decorate_context
    return func(*args, **kwargs)
  File "C:\ProgramData\Anaconda3\envs\flat_lattice_transformer\lib\site-packages\torch\jit\__init__.py", line 675, in _check_trace
    raise TracingCheckError(*diag_info)
torch.jit.TracingCheckError: Tracing failed sanity checks!
ERROR: Tensor-valued Constant nodes differed in value across invocations. This often indicates that the tracer has encountered untraceable code.
	Node:
		%393 : Tensor = prim::Constant[value=<Tensor>]() # C:\ProgramData\Anaconda3\envs\flat_lattice_transformer\lib\site-packages\fastNLP\modules\decoder\crf.py:295:0
	Source Location:
		C:\ProgramData\Anaconda3\envs\flat_lattice_transformer\lib\site-packages\fastNLP\modules\decoder\crf.py(295): viterbi_decode
		D:\deeplearning-NLP\flat_lattice_transformer\V1\models.py(524): forward
		C:\ProgramData\Anaconda3\envs\flat_lattice_transformer\lib\site-packages\torch\nn\modules\module.py(534): _slow_forward
		C:\ProgramData\Anaconda3\envs\flat_lattice_transformer\lib\site-packages\torch\nn\modules\module.py(548): __call__
		C:\ProgramData\Anaconda3\envs\flat_lattice_transformer\lib\site-packages\torch\jit\__init__.py(1027): trace_module
		C:\ProgramData\Anaconda3\envs\flat_lattice_transformer\lib\site-packages\torch\jit\__init__.py(875): trace
		D:/deeplearning-NLP/flat_lattice_transformer/V1/exporter.py(125): export_torchscript
		D:/deeplearning-NLP/flat_lattice_transformer/V1/exporter.py(154): <module>
	Comparison exception: 	Subtraction, the `-` operator, with a bool tensor is not supported. If you are trying to invert a mask, use the `~` or `logical_not()` operator instead.

wanilyer avatar Jul 13 '21 09:07 wanilyer
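For reference, the trace-and-save flow described above can be sketched in isolation. The module below is a hypothetical stand-in for the reporter's bert+crf model (the real model, paths, and `args` are their own); it only shows the `torch.jit.trace` / `save` / `load` round trip that the report attempts.

```python
import os
import tempfile

import torch

# Hypothetical stand-in for the bert+crf model in the issue
class TinyTagger(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 3)

    def forward(self, x):
        return self.linear(x)

model = TinyTagger().eval()
example = torch.randn(1, 4)

# trace() records the ops executed for `example` into a static graph
traced = torch.jit.trace(model, example)

pt_path = os.path.join(tempfile.mkdtemp(), "test.pt")
traced.save(pt_path)

# The reloaded TorchScript module reproduces the eager-mode output
reloaded = torch.jit.load(pt_path)
assert torch.allclose(reloaded(example), model(example))
```

With this plain feed-forward module the trace sanity check passes; the issue arises only once the CRF decode path is pulled into the trace.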

Which version of fastNLP are you using? The one from GitHub, or installed directly via pip install fastNLP?

yhcc avatar Jul 13 '21 15:07 yhcc

Based on the error message, I suspect the problem is in this line: https://github.com/fastnlp/fastNLP/blob/b127963f213226dc796720193965d86dface07d5/fastNLP/modules/decoder/crf.py#L307 Changing it to flip_mask = torch.logical_not(mask) should fix it. The root cause looks similar to https://github.com/pytorch/pytorch/issues/33692: bool tensors don't support some operations, so when torchscript compares tensors before and after conversion, the results fail to match.

yhcc avatar Jul 13 '21 15:07 yhcc
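The bool-tensor limitation behind the "Comparison exception" can be reproduced without fastNLP at all. The snippet below (a toy mask, not fastNLP's actual tensors) shows that arithmetic subtraction on a bool tensor raises exactly the error from the traceback, while the `~` and `torch.logical_not` inversions suggested above are equivalent and bool-safe.

```python
import torch

mask = torch.tensor([[True, True, False]])  # toy padding mask, bool dtype

# Arithmetic on a bool tensor is what the trace checker trips over:
# "Subtraction, the `-` operator, with a bool tensor is not supported."
raised = False
try:
    flip_mask = 1 - mask
except RuntimeError:
    raised = True
assert raised

# Bool-safe ways to invert a mask; both produce the same result
assert torch.equal(torch.logical_not(mask), ~mask)
assert torch.equal((~mask).long(), torch.tensor([[0, 0, 1]]))
```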

Thanks for your reply.

I'm using version 0.5.0, installed directly via pip install fastNLP.

The failing line should be this one: https://github.com/fastnlp/fastNLP/blob/ba0269f23e72e446ddfdbda32edccfb694b76b4f/fastNLP/modules/decoder/crf.py#L295

wanilyer avatar Jul 14 '21 01:07 wanilyer

I see. Then changing that line to mask = mask.transpose(0, 1).data.to(torch.bool) should work. The idea is to keep bool-typed data out of any comparisons or arithmetic.

yhcc avatar Jul 14 '21 02:07 yhcc

That doesn't seem to work; it still reports the same error.

There is also the following line, which raises an error:

https://github.com/fastnlp/fastNLP/blob/ba0269f23e72e446ddfdbda32edccfb694b76b4f/fastNLP/embeddings/embedding.py#L208

TypeError: tuple expected at most 1 arguments, got 2

I changed it to return torch.Size((self.num_embedding, self._embed_size)) and then it worked.

wanilyer avatar Jul 14 '21 02:07 wanilyer
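The `torch.Size` fix above follows from `torch.Size` being a tuple subclass: its constructor accepts a single iterable, not separate dimension arguments. A minimal demonstration (toy dimensions, not fastNLP's actual embedding sizes):

```python
import torch

# Passing two positional args reproduces the reported TypeError:
# "tuple expected at most 1 argument, got 2"
raised = False
try:
    torch.Size(10, 5)
except TypeError:
    raised = True
assert raised

# Wrapping the dimensions in a tuple is the fix applied above
shape = torch.Size((10, 5))
assert shape == (10, 5)
assert torch.zeros(shape).shape == shape
```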

After I commented out that line, I found two more places further down with the same problem; I didn't keep commenting past those:

https://github.com/fastnlp/fastNLP/blob/ba0269f23e72e446ddfdbda32edccfb694b76b4f/fastNLP/modules/decoder/crf.py#L314-L315

https://github.com/fastnlp/fastNLP/blob/ba0269f23e72e446ddfdbda32edccfb694b76b4f/fastNLP/modules/decoder/crf.py#L323

wanilyer avatar Jul 14 '21 03:07 wanilyer

Indeed. fastNLP's code is probably not well suited to jit conversion: there are too many logic branches and too many CPU-side operations. You will likely need to adapt the code yourself as you go. With jit, it seems things break whenever a constant appears, and fastNLP uses constants extensively to represent various numbers.

yhcc avatar Jul 15 '21 14:07 yhcc
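The "too many logic branches" problem mentioned above is inherent to `torch.jit.trace`: tracing records only the path taken by the example input, so data-dependent branches (like Viterbi decode loops) get frozen, which is exactly what the sanity check flags. A toy module (hypothetical, not fastNLP code) makes the difference from `torch.jit.script` visible:

```python
import torch

class Gate(torch.nn.Module):
    def forward(self, x):
        # Data-dependent branch: trace() freezes whichever path the
        # example input happens to take; script() keeps the branch.
        if x.sum() > 0:
            return x * 2
        return x + 10

m = Gate()
traced = torch.jit.trace(m, torch.ones(3))  # positive example bakes in the "*2" path
scripted = torch.jit.script(m)              # compiles the control flow itself

neg = -torch.ones(3)
assert torch.equal(traced(neg), neg * 2)    # wrong branch: control flow was lost
assert torch.equal(scripted(neg), neg + 10) # correct branch preserved
```

This suggests that code like the CRF decoder would need `torch.jit.script` (with scriptable rewrites) rather than tracing, which matches the advice that the fastNLP code would have to be adapted by hand.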