"RuntimeError: start (4) + length (1) exceeds dimension size (4)." when running cache aware streaming inference
Hello,
I'm developing a WebSocket-based client-server application for live transcription. I've written a client that reads an audio file from disk and sends audio chunks over the WebSocket, effectively simulating a live stream. On the server side, I receive each chunk and try to transcribe it using the logic from this example [1].
For transcription I'm using the "stt_en_fastconformer_hybrid_large_streaming_80ms" model. The test audio file is 16-bit signed-integer PCM, mono, sampled at 16 kHz. My NeMo version is v1.23.0.
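For context, here is the chunk sizing I'm working from on both the client and the server (a quick sketch of the arithmetic; the trailing -1 sample mirrors the example I followed):
# Chunk sizing assumed throughout (sketch only, not part of the app itself)
SAMPLE_RATE = 16000            # Hz, matches my test file
ENCODER_STEP_LENGTH = 80       # ms, encoder step of the FastConformer models
lookahead_size = 80            # ms, for the *_streaming_80ms variant
chunk_size_ms = lookahead_size + ENCODER_STEP_LENGTH              # 160 ms per chunk
chunk_size_samples = int(SAMPLE_RATE * chunk_size_ms / 1000) - 1  # 2559 samples
chunk_size_bytes = chunk_size_samples * 2                         # 5118 bytes of 16-bit PCM per websocket message
print(chunk_size_ms, chunk_size_samples, chunk_size_bytes)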
However, I get the following error on the server side (full traceback below): RuntimeError: start (4) + length (1) exceeds dimension size (4).
Here is my server-side code:
import asyncio
import websockets
import json
import copy
import numpy as np
import torch
import torch.nn.functional as F
# Load your model and any other dependencies here
from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE
from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis  # needed by extract_transcriptions below
import nemo.collections.asr as nemo_asr
from omegaconf import OmegaConf, open_dict
from copy import deepcopy
# Load your ASR model
asr_model = None # Initialize to None
preprocessor = None
# Define global variables for caching
cache_last_channel = None
cache_last_time = None
cache_last_channel_len = None
previous_hypotheses = None
pred_out_stream = None
step_num = 0
pre_encode_cache_size = 0
cache_pre_encode = None
def extract_transcriptions(hyps):
    """
    The transcribed_texts returned by CTC and RNNT models are different.
    This method would extract and return the text section of the hypothesis.
    """
    if isinstance(hyps[0], Hypothesis):
        transcriptions = []
        for hyp in hyps:
            transcriptions.append(hyp.text)
    else:
        transcriptions = hyps
    return transcriptions
def init_preprocessor(asr_model):
    cfg = copy.deepcopy(asr_model._cfg)
    OmegaConf.set_struct(cfg.preprocessor, False)
    # some changes for streaming scenario
    cfg.preprocessor.dither = 0.0
    cfg.preprocessor.pad_to = 0
    cfg.preprocessor.normalize = "None"
    preprocessor = EncDecCTCModelBPE.from_config_dict(cfg.preprocessor)
    preprocessor.to(asr_model.device)
    return preprocessor
def preprocess_audio(audio, asr_model):
    global preprocessor
    device = asr_model.device
    # doing audio preprocessing
    audio_signal = torch.from_numpy(audio).unsqueeze_(0).to(device)
    audio_signal_len = torch.Tensor([audio.shape[0]]).to(device)
    processed_signal, processed_signal_length = preprocessor(
        input_signal=audio_signal, length=audio_signal_len
    )
    # Print shape of processed_signal
    print("Shape of processed_signal:", processed_signal.shape)
    return processed_signal, processed_signal_length
def transcribe_chunk(new_chunk):
    global cache_last_channel, cache_last_time, cache_last_channel_len
    global previous_hypotheses, pred_out_stream, step_num
    global cache_pre_encode

    # new_chunk is provided as np.int16, so we convert it to np.float32
    # as that is what our ASR models expect
    audio_data = new_chunk.astype(np.float32)
    audio_data = audio_data / 32768.0

    # get mel-spectrogram signal & length
    processed_signal, processed_signal_length = preprocess_audio(audio_data, asr_model)

    # prepend with cache_pre_encode
    processed_signal = torch.cat([cache_pre_encode, processed_signal], dim=-1)
    processed_signal_length += cache_pre_encode.shape[1]

    # save cache for next time
    cache_pre_encode = processed_signal[:, :, -pre_encode_cache_size:]

    with torch.no_grad():
        (
            pred_out_stream,
            transcribed_texts,
            cache_last_channel,
            cache_last_time,
            cache_last_channel_len,
            previous_hypotheses,
        ) = asr_model.conformer_stream_step(
            processed_signal=processed_signal,
            processed_signal_length=processed_signal_length,
            cache_last_channel=cache_last_channel,
            cache_last_time=cache_last_time,
            cache_last_channel_len=cache_last_channel_len,
            keep_all_outputs=False,
            previous_hypotheses=previous_hypotheses,
            previous_pred_out=pred_out_stream,
            drop_extra_pre_encoded=None,
            return_transcription=True,
        )

    final_streaming_tran = extract_transcriptions(transcribed_texts)
    step_num += 1

    # Print shape of x before narrow operation
    print("Shape of x before narrow:", processed_signal.shape)

    return final_streaming_tran[0]
async def audio_consumer(websocket, path):
    try:
        while True:
            audio_chunk_str = await websocket.recv()
            audio_chunk = np.frombuffer(audio_chunk_str, dtype=np.int16)
            transcription = transcribe_chunk(audio_chunk)
            await websocket.send(json.dumps({"transcription": transcription}))
    except websockets.exceptions.ConnectionClosed:
        pass
async def start_server():
    global asr_model

    # Load your ASR model
    # Replace 'path_to_your_model' with the actual path to your model
    model_path = "models/stt_en_fastconformer_hybrid_large_streaming_80ms.nemo"
    asr_model = nemo_asr.models.EncDecRNNTBPEModel.restore_from(model_path)

    lookahead_size = 80
    decoder_type = "ctc"

    # specify ENCODER_STEP_LENGTH (which is 80 ms for FastConformer models)
    ENCODER_STEP_LENGTH = 80  # ms

    # update att_context_size
    left_context_size = asr_model.encoder.att_context_size[0]
    asr_model.encoder.set_default_att_context_size(
        [left_context_size, int(lookahead_size / ENCODER_STEP_LENGTH)]
    )
    asr_model.encoder.setup_streaming_params()

    # make sure we use the specified decoder_type
    asr_model.change_decoding_strategy(decoder_type=decoder_type)

    # make sure the model's decoding strategy is optimal
    decoding_cfg = asr_model.cfg.decoding
    with open_dict(decoding_cfg):
        # save time by doing greedy decoding and not trying to record the alignments
        decoding_cfg.strategy = "greedy"
        decoding_cfg.preserve_alignments = False
        if hasattr(asr_model, 'joint'):  # if an RNNT model
            # restrict max_symbols to make sure not stuck in infinite loop
            decoding_cfg.greedy.max_symbols = 10
            # sensible default parameter, but not necessary since batch size is 1
            decoding_cfg.fused_batch_size = -1
        asr_model.change_decoding_strategy(decoding_cfg)

    # set model to eval mode
    asr_model.eval()

    # get parameters to use as the initial cache state
    cache_last_channel, cache_last_time, cache_last_channel_len = asr_model.encoder.get_initial_cache_state(
        batch_size=1
    )

    global preprocessor
    preprocessor = init_preprocessor(asr_model)

    # Initialize global variables for caching
    global pre_encode_cache_size, cache_pre_encode
    pre_encode_cache_size = asr_model.encoder.streaming_cfg.pre_encode_cache_size[1]
    cache_pre_encode = torch.zeros(
        (1, asr_model.cfg.preprocessor.features, pre_encode_cache_size), device=asr_model.device
    )

    async with websockets.serve(audio_consumer, "localhost", 8765):
        await asyncio.Future()  # Run server forever

asyncio.run(start_server())
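As a side note, to double-check what the encoder actually expects per step, I'm planning to print the streaming config right after setup_streaming_params() (a small debugging sketch; pre_encode_cache_size is already read above, and I'm assuming chunk_size and shift_size are exposed on the same streaming_cfg object):
# Debugging sketch (assumption): inspect the cache-aware streaming config after setup
streaming_cfg = asr_model.encoder.streaming_cfg
print("chunk_size:", streaming_cfg.chunk_size)                        # assumed attribute: expected encoder chunk length
print("shift_size:", streaming_cfg.shift_size)                        # assumed attribute: window shift per step
print("pre_encode_cache_size:", streaming_cfg.pre_encode_cache_size)  # already used above for cache_pre_encode
print("att_context_size:", asr_model.encoder.att_context_size)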
And here is my client-side code:
import asyncio
import websockets
import json
import numpy as np
import pyaudio
import struct

# Set the sample rate and chunk size
SAMPLE_RATE = 16000
lookahead_size = 80
ENCODER_STEP_LENGTH = 80
chunk_size_ms = lookahead_size + ENCODER_STEP_LENGTH
chunk_size_samples = int(SAMPLE_RATE * chunk_size_ms / 1000) - 1

async def send_audio_stream(file_path, websocket):
    with open(file_path, 'rb') as audio_file:
        while True:
            audio_chunk = audio_file.read(chunk_size_samples * 2)  # 2 bytes per sample for 16-bit audio
            if not audio_chunk:
                break
            await websocket.send(audio_chunk)
            response = await websocket.recv()
            transcription = json.loads(response)
            print("Transcription:", transcription['transcription'])

async def main():
    # WebSocket server address
    uri = "ws://localhost:8765"
    # File path of the audio to stream
    audio_file_path = "test-audio/my_audio_file.wav"
    async with websockets.connect(uri) as websocket:
        await send_audio_stream(audio_file_path, websocket)

# Run the main function
asyncio.run(main())
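One detail I'm unsure about: open(file_path, 'rb') also streams the WAV header bytes as if they were PCM samples. A variant of the reader I'm considering, using the standard wave module to skip the header (a sketch, not the code that produced the error above), would be:
import wave

async def send_audio_stream_wave(file_path, websocket):
    # Sketch (hypothetical variant): read PCM frames via the wave module so the
    # RIFF/WAV header is not sent as audio. Assumes 16-bit mono 16 kHz PCM.
    with wave.open(file_path, 'rb') as wf:
        assert wf.getsampwidth() == 2 and wf.getnchannels() == 1 and wf.getframerate() == SAMPLE_RATE
        while True:
            audio_chunk = wf.readframes(chunk_size_samples)  # raw bytes, 2 bytes per sample
            if not audio_chunk:
                break
            await websocket.send(audio_chunk)
            response = await websocket.recv()
            print("Transcription:", json.loads(response)['transcription'])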
Here is the error I get on the server side:
connection handler failed
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/site-packages/websockets/legacy/server.py", line 236, in handler
    await self.ws_handler(self)
  File "/usr/local/lib/python3.10/site-packages/websockets/legacy/server.py", line 1175, in _ws_handler
    return await cast(
  File "/home/apps/ASR/server_streaming.py", line 142, in audio_consumer
    transcription = transcribe_chunk(audio_chunk)
  File "/home/apps/ASR/server_streaming.py", line 112, in transcribe_chunk
    ) = asr_model.conformer_stream_step(
  File "/home/apps/ASR/nemo/collections/asr/parts/mixins/mixins.py", line 676, in conformer_stream_step
    best_hyp, all_hyp_or_transcribed_texts = self.decoding.rnnt_decoder_predictions_tensor(
  File "/home/apps/ASR/nemo/collections/asr/parts/submodules/rnnt_decoding.py", line 455, in rnnt_decoder_predictions_tensor
    hypotheses_list = self.decoding(
  File "/home/apps/ASR/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py", line 180, in __call__
    return self.forward(*args, **kwargs)
  File "/home/apps/ASR/nemo/core/classes/common.py", line 1098, in __call__
    outputs = wrapped(*args, **kwargs)
  File "/home/apps/ASR/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py", line 388, in forward
    hypothesis = self._greedy_decode(inseq, logitlen, partial_hypotheses=partial_hypothesis)
  File "/usr/local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/apps/ASR/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py", line 431, in _greedy_decode
    f = x.narrow(dim=0, start=time_idx, length=1)
RuntimeError: start (4) + length (1) exceeds dimension size (4).
Any ideas please? Thank you!