icefall
Early Stopping of Token Generation in Streaming Model Training
Hi Next-gen Kaldi team,
Thank you once again for your continuous support and patience with our Japanese ASR recipe and model developments.
We're currently training a streaming model based on our existing ReazonSpeech recipe. We have tried both the regular zipformer and zipformer-L on datasets of different sizes (100h, 1000h, and 5000h). In every case we see the same issue: the model outputs only the first few tokens of each utterance and then stops.
Current environment:
$ nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Tue_Feb__7_19:32:13_PST_2023
Cuda compilation tools, release 12.1, V12.1.66
Build cuda_12.1.r12.1/compiler.32415258_0
$ python3 -c "import torch; print(torch.__version__)"
2.3.1+cu121
$ python3 -c "import torchaudio; print(torchaudio.__version__)"
2.3.1+cu121
$ python3 -m k2.version
Collecting environment information...
k2 version: 1.24.4
Build type: Release
Git SHA1: 8f976a1e1407e330e2a233d68f81b1eb5269fdaa
Git date: Thu Jun 6 02:13:08 2024
Cuda used to build k2: 12.1
cuDNN used to build k2:
Python version used to build k2: 3.10
OS used to build k2: CentOS Linux release 7.9.2009 (Core)
CMake version: 3.29.3
GCC version: 9.3.1
CMAKE_CUDA_FLAGS: -Wno-deprecated-gpu-targets -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_50,code=sm_50 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_60,code=sm_60 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_61,code=sm_61 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_70,code=sm_70 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_75,code=sm_75 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_80,code=sm_80 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_86,code=sm_86 -DONNX_NAMESPACE=onnx_c2 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -gencode arch=compute_89,code=sm_89 -gencode arch=compute_90,code=sm_90 -Xcudafe --diag_suppress=cc_clobber_ignored,--diag_suppress=field_without_dll_interface,--diag_suppress=base_class_has_different_dll_interface,--diag_suppress=dll_interface_conflict_none_assumed,--diag_suppress=dll_interface_conflict_dllexport_assumed,--diag_suppress=bad_friend_decl --expt-relaxed-constexpr --expt-extended-lambda -D_GLIBCXX_USE_CXX11_ABI=0 --compiler-options -Wall --compiler-options -Wno-strict-overflow --compiler-options -Wno-unknown-pragmas
CMAKE_CXX_FLAGS: -D_GLIBCXX_USE_CXX11_ABI=0 -Wno-unused-variable -Wno-strict-overflow
PyTorch version used to build k2: 2.3.1+cu121
PyTorch is using Cuda: 12.1
NVTX enabled: True
With CUDA: True
Disable debug: True
Sync kernels : False
Disable checks: False
Max cpu memory allocate: 214748364800 bytes (or 200.0 GB)
k2 abort: False
__file__: /usr/local/lib/python3.10/dist-packages/k2/version/version.py
_k2.__file__: /usr/local/lib/python3.10/dist-packages/_k2.cpython-310-x86_64-linux-gnu.so
$ python3 -c "import lhotse; print(lhotse.__version__)"
1.26.0.dev+git.bd12d5d.clean
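The Python-side version checks above can also be collected in one pass with the standard library's `importlib.metadata`. This is a minimal sketch that assumes the pip distribution names are `torch`, `torchaudio`, `k2` and `lhotse`. The reported strings may differ slightly from the `__version__` attributes printed above. It does not replace `python3 -m k2.version`, which also reports k2's build details.

```python
from importlib.metadata import PackageNotFoundError, version

# Assumption: these are the pip distribution names of the packages checked above.
PACKAGES = ["torch", "torchaudio", "k2", "lhotse"]


def collect_versions(packages=PACKAGES) -> dict:
    """Map each distribution name to its installed version, or None if it is absent."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in collect_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```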
Our commands and results:
Training command (regular zipformer):
./zipformer/train.py \
--world-size 8 \
--num-epochs 30 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp-causal \
--causal 1 \
--lang data/lang_char \
--max-duration 1600
Decoding command:
./zipformer/streaming_decode.py \
--epoch 30 \
--avg 15 \
--causal 1 \
--chunk-size 32 \
--left-context-frames 128 \
--exp-dir zipformer/exp-causal \
--lang data/lang_char
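The following sketch converts the streaming settings above to seconds and checks whether they appear among the chunk configurations used in training. It rests on two assumptions. First, zipformer counts `--chunk-size` and `--left-context-frames` at a 50 Hz frame rate. Second, our training used what we believe are the recipe defaults (`16,32,64,-1` and `64,128,256,-1`), because our training command does not set these options. Both assumptions should be verified against `zipformer/train.py`.

```python
# Assumption: --chunk-size and --left-context-frames are counted at 50 Hz (20 ms per frame).
FRAME_RATE_HZ = 50

# Assumption: recipe defaults, since our training command did not set these options.
# The -1 entries are assumed to mean "no limit".
TRAIN_CHUNK_SIZES = [16, 32, 64, -1]
TRAIN_LEFT_CONTEXTS = [64, 128, 256, -1]


def frames_to_seconds(frames: int) -> float:
    """Convert a frame count at FRAME_RATE_HZ into seconds."""
    return frames / FRAME_RATE_HZ


def check_streaming_config(chunk_size: int, left_context: int) -> dict:
    """Summarize a decoding config and whether each value was among the training choices."""
    return {
        "chunk_seconds": frames_to_seconds(chunk_size),
        "left_context_seconds": frames_to_seconds(left_context),
        "chunk_seen_in_training": chunk_size in TRAIN_CHUNK_SIZES,
        "left_context_seen_in_training": left_context in TRAIN_LEFT_CONTEXTS,
    }


if __name__ == "__main__":
    # Values from the streaming_decode.py command above.
    print(check_streaming_config(32, 128))
```

Under these assumptions, `--chunk-size 32` corresponds to 0.64 s of audio per chunk and `--left-context-frames 128` to 2.56 s of left context.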
Some results from errs-test-greedy_search-epoch-30-avg-15-chunk-32-left-context-128-use-averaged-model.txt:
1000-0: (ライブ映像です菅総理のコメントがこれから発表されます->そ)
1001-1: (日経平均株価の午前の終値二万八千八十一円五十五銭と七十四円六十六銭->日)
1002-2: (来年の大統領選挙を控える中で四件目の起訴を受けたわけですが今回も相変わらず選挙妨害だなどと無実を主張しています->ラ)
1003-3: (膿の除去や歯周病の原因となる歯石の除去などのケアを続けたのです->こ)
1004-4: (まずは東京都心のお天気の変化から見てみましょう->ま)
1005-5: (ご準備お願いいたします->でも)
1006-6: (だって上いったら筋見えるよ->だって)
1007-7: (ロシアの潜水艦が日本海でミサイル発射の演習を行いました->ここか)
1008-8: (まあまあまあでもさこれもほらあのトカゲが急に敵におそわれたときしっぽちょん切ってにげるみてえな感じだから->ま)
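One rough way to quantify how early the output stops is the ratio of hypothesis length to reference length, in characters. This sketch assumes each per-utterance line has the shape shown in the excerpt above (`<utt-id>: (<ref>-><hyp>)`). Real errs-*.txt files also contain summary lines and may lay out aligned utterances differently, so lines that do not match are skipped rather than parsed.

```python
import re

# Assumed line shape, as in the excerpt above: "<utt-id>: (<reference>-><hypothesis>)".
ERR_LINE = re.compile(r"^(?P<utt>\S+):\s*\((?P<ref>.*)->(?P<hyp>.*)\)\s*$")


def truncation_ratio(line: str):
    """Return (utt_id, len(hyp) / len(ref)) in characters, or None if the line does not match."""
    m = ERR_LINE.match(line.strip())
    if m is None:
        return None
    ref, hyp = m.group("ref"), m.group("hyp")
    return m.group("utt"), len(hyp) / max(len(ref), 1)


def mean_ratio(lines) -> float:
    """Average hyp/ref length ratio over matching lines; all other lines are skipped."""
    ratios = [r[1] for r in map(truncation_ratio, lines) if r is not None]
    return sum(ratios) / len(ratios) if ratios else float("nan")


if __name__ == "__main__":
    sample = [
        "1004-4: (まずは東京都心のお天気の変化から見てみましょう->ま)",
        "1006-6: (だって上いったら筋見えるよ->だって)",
    ]
    for line in sample:
        utt, ratio = truncation_ratio(line)
        print(f"{utt}: hyp/ref = {ratio:.2f}")
    print(f"mean: {mean_ratio(sample):.2f}")
```

A ratio close to 1.0 means the hypothesis covers roughly the whole utterance. Values near 0, as in these examples, correspond to the early stopping we describe.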
We also exported this model and tested it with sherpa-onnx.
Export command:
./zipformer/export-onnx-streaming.py \
--tokens data/lang_char/tokens.txt \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir zipformer/exp-causal \
--causal True \
--chunk-size 16 \
--left-context-frames 128 \
--fp16 True
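For bookkeeping, the sketch below lists the options that differ between the streaming_decode.py run and this export. The values are copied from the two commands, with one exception: `use-averaged-model` for decoding is inferred from the errs-*.txt file name. The comparison only records the differences. It makes no claim that any of them is related to the truncation, which both setups show.

```python
# Settings copied from the streaming_decode.py and export-onnx-streaming.py commands.
# Assumption: use-averaged-model=1 for decoding, inferred from the errs-*.txt file name.
DECODE = {
    "epoch": 30,
    "avg": 15,
    "use-averaged-model": 1,
    "chunk-size": 32,
    "left-context-frames": 128,
}
EXPORT = {
    "epoch": 99,
    "avg": 1,
    "use-averaged-model": 0,
    "chunk-size": 16,
    "left-context-frames": 128,
}


def diff_settings(a: dict, b: dict) -> dict:
    """Return {option: (a_value, b_value)} for every option whose values differ."""
    keys = sorted(set(a) | set(b))
    return {k: (a.get(k), b.get(k)) for k in keys if a.get(k) != b.get(k)}


if __name__ == "__main__":
    for option, (dec, exp) in diff_settings(DECODE, EXPORT).items():
        print(f"--{option}: decode={dec} export={exp}")
```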
Decoding with the sherpa-onnx Python API example script:
./python-api-examples/online-decode-files.py \
--tokens=./pretrained-models/k2-streaming/tokens.txt \
--num-threads=4 \
--encoder=./pretrained-models/k2-streaming/1000h/encoder-epoch-99-avg-1-chunk-16-left-128.onnx \
--decoder=./pretrained-models/k2-streaming/1000h/decoder-epoch-99-avg-1-chunk-16-left-128.onnx \
--joiner=./pretrained-models/k2-streaming/1000h/joiner-epoch-99-avg-1-chunk-16-left-128.onnx \
./pretrained-models/k2-streaming/test_wavs/0.wav \
./pretrained-models/k2-streaming/test_wavs/1.wav
Started!
Done!
./pretrained-models/k2-streaming/test_wavs/0.wav
ら
----------
./pretrained-models/k2-streaming/test_wavs/1.wav
屯
----------
num_threads: 4
decoding_method: greedy_search
Wave duration: 23.340 s
Elapsed time: 1.159 s
Real time factor (RTF): 1.159/23.340 = 0.050
Both streaming_decode.py and the sherpa-onnx deployment produce outputs that are truncated early in the speech. The resulting transcriptions are severely shortened or incomplete, often only one to a few characters long.
We would greatly appreciate any insights or suggestions on how to address these early stopping issues in token generation. We will also open-source this streaming model as soon as we resolve these challenges.
Thank you!