PaddleSpeech
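Skip loading for encoder.embed.out.0.weight. when testing with the conformer_wenetspeech model
Hello, when testing with the conformer_wenetspeech model I get the error "Skip loading for encoder.embed.out.0.weight." My guess is that the weight dimensions of some layer do not match the model, but after a long time investigating I still have not found the cause. Any pointers would be appreciated, thanks.
The error and the log are below.
Error: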
2022-09-09 16:17:06.267 | INFO | paddlespeech.s2t.exps.u2.model:setup_model:263 - Setup model!
D:\Program\miniconda3\envs\paddle\lib\site-packages\paddle\fluid\dygraph\layers.py:1492: UserWarning: Skip loading for encoder.embed.out.0.weight. encoder.embed.out.0.weight receives a shape [9728, 512], but the expected shape is [19968, 512].
warnings.warn(("Skip loading for {}. ".format(key) + str(err)))
2022-09-09 16:17:07.686 | INFO | paddlespeech.s2t.utils.checkpoint:load_parameters:117 - Rank 0: Restore model from configs/conformer_wenetspeech-zh-16k_1.0/wenetspeech.pdparams
2022-09-09 16:17:07.689 | INFO | paddlespeech.s2t.exps.u2.model:test:390 - Test Total Examples: 5
D:\Program\miniconda3\envs\paddle\lib\site-packages\paddle\fluid\dygraph\math_op_patch.py:278: UserWarning: The dtype of left and right variables are not the same, left dtype is paddle.int64, but right dtype is paddle.int32, the right dtype will convert to paddle.int64
format(lhs_dtype, rhs_dtype, lhs_dtype))
2022-09-09 16:17:08.058 | INFO | paddlespeech.s2t.training.timer:__exit__:44 - Test/Decode Done: 0:00:01.790112
Traceback (most recent call last):
File "d:/Code/PADDLE/PaddleSpeech-develop/demos/speech_recognition/eval.py", line 72, in <module>
exp.run_test()
File "D:\Program\miniconda3\envs\paddle\lib\site-packages\paddlespeech\s2t\training\trainer.py", line 365, in run_test
self.test()
File "D:\Program\miniconda3\envs\paddle\lib\site-packages\paddlespeech\s2t\utils\mp_tools.py", line 27, in wrapper
result = func(*args, **kwargs)
File "D:\Program\miniconda3\envs\paddle\lib\site-packages\decorator.py", line 232, in fun
return caller(func, *(extras + args), **kw)
File "D:\Program\miniconda3\envs\paddle\lib\site-packages\paddle\fluid\dygraph\base.py", line 354, in _decorate_function
return func(*args, **kwargs)
File "D:\Program\miniconda3\envs\paddle\lib\site-packages\paddlespeech\s2t\exps\u2\model.py", line 399, in test
metrics = self.compute_metrics(*batch, fout=fout)
File "D:\Program\miniconda3\envs\paddle\lib\site-packages\paddlespeech\s2t\exps\u2\model.py", line 353, in compute_metrics
simulate_streaming=decode_config.simulate_streaming)
File "D:\Program\miniconda3\envs\paddle\lib\site-packages\decorator.py", line 232, in fun
return caller(func, *(extras + args), **kw)
File "D:\Program\miniconda3\envs\paddle\lib\site-packages\paddle\fluid\dygraph\base.py", line 354, in _decorate_function
return func(*args, **kwargs)
File "D:\Program\miniconda3\envs\paddle\lib\site-packages\paddlespeech\s2t\models\u2\u2.py", line 736, in decode
simulate_streaming=simulate_streaming)
File "D:\Program\miniconda3\envs\paddle\lib\site-packages\paddlespeech\s2t\models\u2\u2.py", line 258, in recognize
simulate_streaming) # (B, maxlen, encoder_dim)
File "D:\Program\miniconda3\envs\paddle\lib\site-packages\paddlespeech\s2t\models\u2\u2.py", line 221, in _forward_encoder
num_decoding_left_chunks=num_decoding_left_chunks
File "D:\Program\miniconda3\envs\paddle\lib\site-packages\paddle\fluid\dygraph\layers.py", line 930, in __call__
return self._dygraph_call_func(*inputs, **kwargs)
File "D:\Program\miniconda3\envs\paddle\lib\site-packages\paddle\fluid\dygraph\layers.py", line 915, in _dygraph_call_func
outputs = self.forward(*inputs, **kwargs)
File "D:\Program\miniconda3\envs\paddle\lib\site-packages\paddlespeech\s2t\modules\encoder.py", line 168, in forward
xs, pos_emb, masks = self.embed(xs, masks.astype(xs.dtype), offset=0)
File "D:\Program\miniconda3\envs\paddle\lib\site-packages\paddle\fluid\dygraph\layers.py", line 930, in __call__
return self._dygraph_call_func(*inputs, **kwargs)
File "D:\Program\miniconda3\envs\paddle\lib\site-packages\paddle\fluid\dygraph\layers.py", line 915, in _dygraph_call_func
outputs = self.forward(*inputs, **kwargs)
File "D:\Program\miniconda3\envs\paddle\lib\site-packages\paddlespeech\s2t\modules\subsampling.py", line 143, in forward
x = self.out(x.transpose([0, 2, 1, 3]).reshape([b, t, c * f]))
File "D:\Program\miniconda3\envs\paddle\lib\site-packages\paddle\fluid\dygraph\layers.py", line 930, in __call__
return self._dygraph_call_func(*inputs, **kwargs)
File "D:\Program\miniconda3\envs\paddle\lib\site-packages\paddle\fluid\dygraph\layers.py", line 915, in _dygraph_call_func
__
return self._dygraph_call_func(*inputs, **kwargs)
File "D:\Program\miniconda3\envs\paddle\lib\site-packages\paddle\fluid\dygraph\layers.py", line 915, in _dygraph_call_func
outputs = self.forward(*inputs, **kwargs)
File "D:\Program\miniconda3\envs\paddle\lib\site-packages\paddle\nn\layer\common.py", line 172, in forward
x=input, weight=self.weight, bias=self.bias, name=self.name)
File "D:\Program\miniconda3\envs\paddle\lib\site-packages\paddle\nn\functional\common.py", line 1542, in linear
False)
ValueError: (InvalidArgument) Input(Y) has error dim.Y'dims[0] must be equal to 9728But received Y'dims[0] is 19968
[Hint: Expected y_dims[y_ndim - 2] == K, but received y_dims[y_ndim - 2]:19968 != K:9728.] (at C:\home\workspace\Paddle_release\paddle/phi/kernels/impl/matmul_kernel_impl.h:315)
[operator < matmul_v2 > error]
Log: py_eval.LAPTOP-CD0ILM5K.lichuan.2022-09-09_16-16-08_948182.log
Happy Mid-Autumn Festival, everyone! :blush:
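The shape of encoder.embed.out.0.weight in the model definition differs from the one in the pre-trained model. Please check that the model configuration file is consistent with the pre-trained model, and check the related parameters in paddlespeech/s2t/modules/subsampling.py.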
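The two sizes in the warning can be related to the input feature dimension with a little arithmetic. The sketch below assumes the conv2d input layer follows the WeNet-style 4x subsampling layout (two 3x3 convolutions with stride 2, then a linear layer whose input width is odim times the remaining frequency bins); that layout is an assumption here and should be verified against subsampling.py. The function names are made up for illustration.

```python
def conv2d4_linear_in(idim: int, odim: int) -> int:
    """Input width of the linear layer after two 3x3, stride-2 convolutions.

    Assumption: each convolution (kernel 3, stride 2, no padding) maps a
    frequency axis of size n to (n - 1) // 2, as in WeNet-style encoders.
    """
    freq_after_conv1 = (idim - 1) // 2
    freq_after_conv2 = (freq_after_conv1 - 1) // 2
    return odim * freq_after_conv2


def feat_dims_matching(rows: int, odim: int, max_idim: int = 400) -> list:
    """List every input feature dimension that would yield `rows` weight rows."""
    return [d for d in range(1, max_idim + 1)
            if conv2d4_linear_in(d, odim) == rows]


# Shapes taken from the warning above: [9728, 512] received, [19968, 512] expected.
print("received 9728 rows  ->", feat_dims_matching(9728, 512))
print("expected 19968 rows ->", feat_dims_matching(19968, 512))
```

Under that assumption, 9728 rows correspond to an input feature dimension of about 80 and 19968 rows to about 161. These dimensions are inferred from the shapes, not read from the log. If the inference holds, the model was built for roughly twice the feature dimension the checkpoint was trained with, so the feature dimension used when constructing the model is worth checking in addition to input_layer.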
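Thanks for the reply. The model file model.yaml and the weight file wenetspeech.pdparams were downloaded in the same archive from the same link, so they should match. model.yaml has an optional parameter, input_layer, which is eventually passed to paddlespeech/s2t/modules/subsampling.py; I tried conv2d, conv2d6 and conv2d8 and all of them give the same error.
Could it be an operating system issue? I am currently testing on Windows, where prediction on a single audio file runs fine. I will also check whether the error occurs on Ubuntu.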
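One way to check whether the config and the weights really match is to compare parameter shapes directly, rather than relying on the warning, which is emitted per skipped key. The helper below is a minimal sketch in plain Python over name-to-shape mappings; shape_mismatches is a hypothetical name, and the bias entry is invented to show a matching parameter. With Paddle, the two mappings could presumably be built from the .shape of each entry in paddle.load(...) on the .pdparams file and in model.state_dict().

```python
def shape_mismatches(checkpoint_shapes, model_shapes):
    """Return {name: (checkpoint_shape, model_shape)} where the shapes differ.

    Parameters missing from the checkpoint are ignored here; only names
    present in both mappings are compared.
    """
    diffs = {}
    for name, model_shape in model_shapes.items():
        ckpt_shape = checkpoint_shapes.get(name)
        if ckpt_shape is not None and tuple(ckpt_shape) != tuple(model_shape):
            diffs[name] = (tuple(ckpt_shape), tuple(model_shape))
    return diffs


# The weight shapes come from the warning in this report.
# The bias entry is hypothetical, added only to show a parameter that matches.
checkpoint_shapes = {
    "encoder.embed.out.0.weight": [9728, 512],
    "encoder.embed.out.0.bias": [512],
}
model_shapes = {
    "encoder.embed.out.0.weight": [19968, 512],
    "encoder.embed.out.0.bias": [512],
}

for name, (ckpt, model) in shape_mismatches(checkpoint_shapes, model_shapes).items():
    print(f"{name}: checkpoint {ckpt} vs model {model}")
```

If only the first layer after subsampling is reported, the rest of the network agrees with the checkpoint and the difference is confined to how the input features are sized.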
Please test on Linux; we use Ubuntu 16 as the reference environment.
OK, I will test and report back. Thanks for the reply.
I tested on Ubuntu and still get the same error.
The environment is as follows:
* Ubuntu 20.04
* python 3.7
* paddlepaddle 2.3.2
* CUDA 10.2
* cuDNN 7.6
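However, in the same environment, the deepspeech2 prediction and model export tests all run successfully.
Could you try the develop version and see whether the problem still occurs?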