
"RuntimeError: start (4) + length (1) exceeds dimension size (4)." when running cache aware streaming inference

Open · lucgeo opened this issue 9 months ago · 0 comments

Hello,

I'm currently working on developing a websocket-based client-server application for live transcription. I've successfully created a client that reads an audio file from disk and sends audio chunks through the websocket, effectively simulating a live stream. On the server side, I receive each chunk and attempt to transcribe it using the logic provided in this example [1].

For transcription, I'm using the "stt_en_fastconformer_hybrid_large_streaming_80ms" model. The test audio file is mono, 16-bit signed integer PCM, sampled at 16 kHz. My NeMo version is v1.23.0.
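
For reference, the file format can be double-checked with the standard wave module (the path below is the same test file used in the client code further down):

import wave

# Sanity check of the test file: expecting mono, 16-bit PCM (2 bytes/sample), 16 kHz
with wave.open("test-audio/my_audio_file.wav", "rb") as wf:
    print("channels:", wf.getnchannels())      # expected: 1
    print("sample width:", wf.getsampwidth())  # expected: 2 (16-bit)
    print("sample rate:", wf.getframerate())   # expected: 16000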

However, I get this server-side error (detailed below): RuntimeError: start (4) + length (1) exceeds dimension size (4).

Here is my server-side code:

import asyncio
import websockets
import json
import copy
import numpy as np
import torch
import torch.nn.functional as F

# Load your model and any other dependencies here
from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE
import nemo.collections.asr as nemo_asr
from omegaconf import OmegaConf, open_dict
from copy import deepcopy
# Hypothesis is needed by extract_transcriptions() below
from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis

# Load your ASR model
asr_model = None  # Initialize to None
preprocessor = None

# Define global variables for caching
cache_last_channel = None
cache_last_time = None
cache_last_channel_len = None
previous_hypotheses = None
pred_out_stream = None
step_num = 0
pre_encode_cache_size = 0
cache_pre_encode = None



def extract_transcriptions(hyps):
    """
        The transcribed_texts returned by CTC and RNNT models differ in type.
        This helper extracts and returns the text portion of each hypothesis.
    """
    if isinstance(hyps[0], Hypothesis):
        transcriptions = []
        for hyp in hyps:
            transcriptions.append(hyp.text)
    else:
        transcriptions = hyps
    return transcriptions



def init_preprocessor(asr_model):
    cfg = copy.deepcopy(asr_model._cfg)
    OmegaConf.set_struct(cfg.preprocessor, False)

    # some changes for streaming scenario
    cfg.preprocessor.dither = 0.0
    cfg.preprocessor.pad_to = 0
    cfg.preprocessor.normalize = "None"

    preprocessor = EncDecCTCModelBPE.from_config_dict(cfg.preprocessor)
    preprocessor.to(asr_model.device)

    return preprocessor




def preprocess_audio(audio, asr_model):

    global preprocessor
    device = asr_model.device

    # doing audio preprocessing
    audio_signal = torch.from_numpy(audio).unsqueeze_(0).to(device)
    audio_signal_len = torch.Tensor([audio.shape[0]]).to(device)
    processed_signal, processed_signal_length = preprocessor(
        input_signal=audio_signal, length=audio_signal_len
    )

    # Print shape of processed_signal
    print("Shape of processed_signal:", processed_signal.shape)
    return processed_signal, processed_signal_length





def transcribe_chunk(new_chunk):
    
    global cache_last_channel, cache_last_time, cache_last_channel_len
    global previous_hypotheses, pred_out_stream, step_num
    global cache_pre_encode
    
    # new_chunk is provided as np.int16, so we convert it to np.float32
    # as that is what our ASR models expect
    audio_data = new_chunk.astype(np.float32)
    audio_data = audio_data / 32768.0

    # get mel-spectrogram signal & length
    processed_signal, processed_signal_length = preprocess_audio(audio_data, asr_model)
     
    # prepend with cache_pre_encode
    processed_signal = torch.cat([cache_pre_encode, processed_signal], dim=-1)
    processed_signal_length += cache_pre_encode.shape[1]
    
    # save cache for next time
    cache_pre_encode = processed_signal[:, :, -pre_encode_cache_size:]
    
    with torch.no_grad():
        (
            pred_out_stream,
            transcribed_texts,
            cache_last_channel,
            cache_last_time,
            cache_last_channel_len,
            previous_hypotheses,
        ) = asr_model.conformer_stream_step(
            processed_signal=processed_signal,
            processed_signal_length=processed_signal_length,
            cache_last_channel=cache_last_channel,
            cache_last_time=cache_last_time,
            cache_last_channel_len=cache_last_channel_len,
            keep_all_outputs=False,
            previous_hypotheses=previous_hypotheses,
            previous_pred_out=pred_out_stream,
            drop_extra_pre_encoded=None,
            return_transcription=True,
        )
    
    final_streaming_tran = extract_transcriptions(transcribed_texts)
    step_num += 1

    # Print shape of x before narrow operation
    print("Shape of x before narrow:", processed_signal.shape)
    
    return final_streaming_tran[0]





async def audio_consumer(websocket, path):
    try:
        while True:
            audio_chunk_str = await websocket.recv()
            audio_chunk = np.frombuffer(audio_chunk_str, dtype=np.int16)
            transcription = transcribe_chunk(audio_chunk)
            await websocket.send(json.dumps({"transcription": transcription}))
    except websockets.exceptions.ConnectionClosed:
        pass



async def start_server():
    global asr_model
    # Load the ASR model from a local .nemo checkpoint
    
    model_path = "models/stt_en_fastconformer_hybrid_large_streaming_80ms.nemo"
    asr_model = nemo_asr.models.EncDecRNNTBPEModel.restore_from(model_path)
    
    
    lookahead_size = 80
    decoder_type = "ctc"

    # specify ENCODER_STEP_LENGTH (which is 80 ms for FastConformer models)
    ENCODER_STEP_LENGTH = 80 # ms


    # update att_context_size
    left_context_size = asr_model.encoder.att_context_size[0]
    asr_model.encoder.set_default_att_context_size([left_context_size, int(lookahead_size / ENCODER_STEP_LENGTH)])

    asr_model.encoder.setup_streaming_params()

    # make sure we use the specified decoder_type
    asr_model.change_decoding_strategy(decoder_type=decoder_type)

    # make sure the model's decoding strategy is optimal
    decoding_cfg = asr_model.cfg.decoding
    with open_dict(decoding_cfg):
        # save time by doing greedy decoding and not trying to record the alignments
        decoding_cfg.strategy = "greedy"
        decoding_cfg.preserve_alignments = False
        if hasattr(asr_model, 'joint'):  # if an RNNT model
            # restrict max_symbols to make sure not stuck in infinite loop
            decoding_cfg.greedy.max_symbols = 10
            # sensible default parameter, but not necessary since batch size is 1
            decoding_cfg.fused_batch_size = -1
        asr_model.change_decoding_strategy(decoding_cfg)


    # set model to eval mode
    asr_model.eval()


    # get parameters to use as the initial cache state
    cache_last_channel, cache_last_time, cache_last_channel_len = asr_model.encoder.get_initial_cache_state(
        batch_size=1
    )


    global preprocessor
    preprocessor = init_preprocessor(asr_model)

    # Initialize global variables for caching
    global pre_encode_cache_size, cache_pre_encode
    pre_encode_cache_size = asr_model.encoder.streaming_cfg.pre_encode_cache_size[1]
    cache_pre_encode = torch.zeros((1, asr_model.cfg.preprocessor.features, pre_encode_cache_size), device=asr_model.device)

    async with websockets.serve(audio_consumer, "localhost", 8765):
        await asyncio.Future()  # Run server forever

asyncio.run(start_server())
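
If useful, the encoder's cache-aware streaming configuration can be dumped right after the model is set up, to compare its expected chunk/shift sizes against the chunks the client sends. This is just a debugging sketch; the individual field names are assumed from the cache-aware streaming setup and may vary across NeMo versions (hence the getattr fallback):

# Debugging sketch (assumed field names): print the streaming config so the
# expected chunk sizes can be compared with the audio chunks arriving over
# the websocket.
print("streaming_cfg:", asr_model.encoder.streaming_cfg)
for name in ("chunk_size", "shift_size", "pre_encode_cache_size", "drop_extra_pre_encoded"):
    print(name, "=", getattr(asr_model.encoder.streaming_cfg, name, "<not present>"))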

And here is my client-side code:


import asyncio
import websockets
import json
import numpy as np
import pyaudio
import struct

# Set the sample rate and chunk size
SAMPLE_RATE = 16000
lookahead_size = 80
ENCODER_STEP_LENGTH = 80
chunk_size_ms = lookahead_size + ENCODER_STEP_LENGTH
chunk_size_samples = int(SAMPLE_RATE * chunk_size_ms / 1000) - 1
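# With these values: chunk_size_ms = 80 + 80 = 160 ms and
# chunk_size_samples = int(16000 * 160 / 1000) - 1 = 2559 samples,
# i.e. 5118 bytes per websocket message at 2 bytes per 16-bit sample
# (one sample short of a full 160 ms).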

async def send_audio_stream(file_path, websocket):
    with open(file_path, 'rb') as audio_file:
        while True:
            audio_chunk = audio_file.read(chunk_size_samples * 2)  # 2 bytes per sample for 16-bit audio
            if not audio_chunk:
                break

            await websocket.send(audio_chunk)
            response = await websocket.recv()
            transcription = json.loads(response)
            print("Transcription:", transcription['transcription'])



async def main():
    # WebSocket server address
    uri = "ws://localhost:8765"

    # File path of the audio to stream
    audio_file_path = "test-audio/my_audio_file.wav"

    async with websockets.connect(uri) as websocket:
        await send_audio_stream(audio_file_path, websocket)

# Run the main function
asyncio.run(main())

Here is the error I get on the server side:

connection handler failed
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/site-packages/websockets/legacy/server.py", line 236, in handler
    await self.ws_handler(self)
  File "/usr/local/lib/python3.10/site-packages/websockets/legacy/server.py", line 1175, in _ws_handler
    return await cast(
  File "/home/apps/ASR/server_streaming.py", line 142, in audio_consumer
    transcription = transcribe_chunk(audio_chunk)
  File "/home/apps/ASR/server_streaming.py", line 112, in transcribe_chunk
    ) = asr_model.conformer_stream_step(
  File "/home/apps/ASR/nemo/collections/asr/parts/mixins/mixins.py", line 676, in conformer_stream_step
    best_hyp, all_hyp_or_transcribed_texts = self.decoding.rnnt_decoder_predictions_tensor(
  File "/home/apps/ASR/nemo/collections/asr/parts/submodules/rnnt_decoding.py", line 455, in rnnt_decoder_predictions_tensor
    hypotheses_list = self.decoding(
  File "/home/apps/ASR/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py", line 180, in __call__
    return self.forward(*args, **kwargs)
  File "/home/apps/ASR/nemo/core/classes/common.py", line 1098, in __call__
    outputs = wrapped(*args, **kwargs)
  File "/home/apps/ASR/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py", line 388, in forward
    hypothesis = self._greedy_decode(inseq, logitlen, partial_hypotheses=partial_hypothesis)
  File "/usr/local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/apps/ASR/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py", line 431, in _greedy_decode
    f = x.narrow(dim=0, start=time_idx, length=1)
RuntimeError: start (4) + length (1) exceeds dimension size (4).

Any ideas please? Thank you!

lucgeo · May 14 '24 07:05