
ASR Context Biasing for EncDecHybridRNNTCTCModel (Parakeet TDT 0.6b v3)

Open sandorkonya opened this issue 3 months ago • 5 comments

Hi there, I was looking for a way to perform context biasing on the Parakeet TDT 0.6b v3 model, based on the example notebook.

I added the EncDecHybridRNNTCTCModel but got a "no attribute 'ctc_decoder'" error.

This discussion mentioned that this type of model does not have the ctc_decoder.

Any pointers on whether context biasing (or a different word-boosting technique that doesn't require fine-tuning) is possible?

Thanks a lot.

sandorkonya avatar Sep 21 '25 22:09 sandorkonya

@sandorkonya With Parakeet-TDT-0.6b-v3 you can use the new phrase boosting by setting:

python examples/asr/transcribe_speech.py \
    <model params> \
    rnnt_decoding.strategy="greedy_batch" \
    rnnt_decoding.greedy.boosting_tree.key_phrases_file=${KEY_WORDS_LIST} \
    rnnt_decoding.greedy.boosting_tree.context_score=1.0 \
    rnnt_decoding.greedy.boosting_tree.depth_scaling=2.0 \
    rnnt_decoding.greedy.boosting_tree_alpha=${BT_ALPHA}

See details in https://github.com/NVIDIA-NeMo/NeMo/pull/14277
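
For reference, the same options can be set from Python instead of transcribe_speech.py. A minimal sketch, assuming a recent NeMo build that includes the boosting tree; keywords.txt and audio.wav are placeholder paths:

import copy
from omegaconf import open_dict
from nemo.collections.asr.models import ASRModel

asr_model = ASRModel.from_pretrained("nvidia/parakeet-tdt-0.6b-v3")

# Clone the model's decoding config and enable the phrase-boosting tree
decoding_cfg = copy.deepcopy(asr_model.cfg.decoding)
with open_dict(decoding_cfg):
    decoding_cfg.strategy = "greedy_batch"
    if "boosting_tree" not in decoding_cfg.greedy:
        decoding_cfg.greedy.boosting_tree = {}
    decoding_cfg.greedy.boosting_tree.key_phrases_file = "keywords.txt"
    decoding_cfg.greedy.boosting_tree.context_score = 1.0
    decoding_cfg.greedy.boosting_tree.depth_scaling = 2.0
    decoding_cfg.greedy.boosting_tree_alpha = 1.0

asr_model.change_decoding_strategy(decoding_cfg)
# Return type depends on the NeMo version (plain strings or Hypothesis objects)
print(asr_model.transcribe(["audio.wav"]))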

artbataev avatar Sep 23 '25 16:09 artbataev

@artbataev does this technique work on a CPU device?

abentabib avatar Dec 09 '25 15:12 abentabib

@abentabib Yes, it works on both CPUs and GPUs. A pure PyTorch implementation is used when Triton/CUDA is unavailable (on GPU, a more efficient Triton kernel can be used).
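
For example, a minimal sketch pinning the same setup to CPU (decoding_cfg is the boosting config from the sketch above; audio.wav is a placeholder):

import torch
from nemo.collections.asr.models import ASRModel

# Load the checkpoint on CPU; without CUDA/Triton the pure PyTorch decoding path is used
asr_model = ASRModel.from_pretrained("nvidia/parakeet-tdt-0.6b-v3", map_location=torch.device("cpu"))
asr_model.change_decoding_strategy(decoding_cfg)
print(asr_model.transcribe(["audio.wav"]))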

artbataev avatar Dec 09 '25 15:12 artbataev

I am trying to create a WebSocket server that uses Parakeet TDT 0.6b v3... (using the example from: https://github.com/NVIDIA-NeMo/NeMo/pull/14759/commits/a42415de9cda7a8882f73d0f5387d8e5c4822a11)


import asyncio
import websockets
import numpy as np
import torch
from omegaconf import open_dict
from nemo.collections.asr.models import ASRModel, EncDecRNNTModel
from nemo.collections.asr.parts.utils.streaming_utils import StreamingBatchedAudioBuffer, ContextSize
from nemo.collections.asr.parts.utils.rnnt_utils import batched_hyps_to_hypotheses

# -----------------------------
# --- MODEL & DEVICE CONFIG ---
# -----------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
compute_dtype = torch.float32

# Load and configure the model once
asr_model_name = "nvidia/parakeet-tdt-0.6b-v3"
asr_model: ASRModel = ASRModel.from_pretrained(asr_model_name).to(device).eval()
asr_model = asr_model.to(compute_dtype)

# Streaming decoding config
decoding_cfg = asr_model.cfg.decoding
with open_dict(decoding_cfg):
    decoding_cfg.strategy = "greedy_batch"
    decoding_cfg.greedy.loop_labels = True
    decoding_cfg.greedy.preserve_alignments = False
    decoding_cfg.fused_batch_size = -1
    decoding_cfg.beam.return_best_hypothesis = True

    # --- Hot words / boosting tree
    # Point this at your text file with the hot words
    if "boosting_tree" not in decoding_cfg.greedy:
        decoding_cfg.greedy.boosting_tree = {}
    decoding_cfg.greedy.boosting_tree.key_phrases_file = "keywords.txt"
    decoding_cfg.greedy.boosting_tree.context_score = 1.0
    decoding_cfg.greedy.boosting_tree.depth_scaling = 2.0
    decoding_cfg.greedy.boosting_tree_alpha = 0.5

# Apply the decoding config
if hasattr(asr_model, "cur_decoder"):
    asr_model.change_decoding_strategy(decoding_cfg, decoder_type="rnnt")
elif isinstance(asr_model, EncDecRNNTModel):
    asr_model.change_decoding_strategy(decoding_cfg)
else:
    raise ValueError(f"Unsupported model type: {type(asr_model)}")

# Preprocessor settings for streaming
asr_model.preprocessor.featurizer.dither = 0.0
asr_model.preprocessor.featurizer.pad_to = 0

# -----------------------------
# --- STREAMING PARAMS ---
# -----------------------------
chunk_secs = 2
left_context_secs = 10
right_context_secs = 2

sample_rate = asr_model.cfg.preprocessor.sample_rate
feature_stride_sec = asr_model.cfg.preprocessor.window_stride
features_per_sec = 1.0 / feature_stride_sec
encoder_subsampling = asr_model.encoder.subsampling_factor

frame2samples = int(sample_rate * feature_stride_sec)
frame2samples = (frame2samples // encoder_subsampling) * encoder_subsampling
encoder_frame2audio_samples = frame2samples * encoder_subsampling

context_encoder_frames = ContextSize(
    left=int(left_context_secs * features_per_sec / encoder_subsampling),
    chunk=int(chunk_secs * features_per_sec / encoder_subsampling),
    right=int(right_context_secs * features_per_sec / encoder_subsampling),
)
context_samples = ContextSize(
    left=context_encoder_frames.left * encoder_subsampling * frame2samples,
    chunk=context_encoder_frames.chunk * encoder_subsampling * frame2samples,
    right=context_encoder_frames.right * encoder_subsampling * frame2samples,
)

# -----------------------------
# --- CLIENT SESSION ---
# -----------------------------
class StreamingSessionClient:
    def __init__(self):
        self.audio_frames = np.zeros([0], dtype=np.float32)
        self.batched_audio_buffer = StreamingBatchedAudioBuffer(
            batch_size=1,
            context_samples=context_samples,
            dtype=torch.float32,
            device=device,
        )
        self.first_chunk_processed = False
        self.state = None
        self.hyp = None
        self.fixed_transcription = ""
        self.temporary_transcription = ""

    @property
    def transcription(self):
        if self.temporary_transcription:
            return f"{self.fixed_transcription} [{self.temporary_transcription}]"
        return self.fixed_transcription

    @torch.inference_mode()
    def process_audio_chunk(self, audio_chunk: np.ndarray, is_last=False):
        self.audio_frames = np.concatenate((self.audio_frames, audio_chunk))
        first_chunk_samples = context_samples.chunk + context_samples.right
        need_samples = context_samples.chunk if self.first_chunk_processed else first_chunk_samples

        while (self.audio_frames.shape[0] >= need_samples) or (is_last and self.audio_frames.shape[0] > 0):
            cur_chunk = self.audio_frames[:need_samples]
            self._process_next_chunk(cur_chunk, is_last=is_last and self.audio_frames.shape[0] <= need_samples)
            self.audio_frames = self.audio_frames[need_samples:].copy()
            need_samples = context_samples.chunk

        if not self.first_chunk_processed and self.audio_frames.shape[0] > 0:
            self._process_first_temporary_chunk(self.audio_frames)

    @torch.inference_mode()
    def _process_first_temporary_chunk(self, audio_chunk: np.ndarray):
        audio_tensor = torch.from_numpy(audio_chunk).float().unsqueeze(0).to(device)
        length = torch.tensor([len(audio_chunk)], device=device)
        enc_out, enc_len = asr_model(input_signal=audio_tensor, input_signal_length=length)
        enc_out = enc_out.transpose(1, 2)
        hyps, _, _ = asr_model.decoding.decoding.decoding_computer(
            x=enc_out, out_len=enc_len, prev_batched_state=None
        )
        hyp = batched_hyps_to_hypotheses(hyps, batch_size=1)[0]
        self.temporary_transcription = asr_model.tokenizer.ids_to_text(hyp.y_sequence.tolist())

    @torch.inference_mode()
    def _process_next_chunk(self, audio_chunk: np.ndarray, is_last: bool):
        audio_tensor = torch.from_numpy(audio_chunk).float().unsqueeze(0).to(device)
        length = torch.tensor([len(audio_chunk)], device=device)
        self.batched_audio_buffer.add_audio_batch_(
            audio_tensor,
            audio_lengths=length,
            is_last_chunk=is_last,
            is_last_chunk_batch=torch.tensor([is_last], device=device),
        )

        enc_out, enc_len = asr_model(
            input_signal=self.batched_audio_buffer.samples,
            input_signal_length=self.batched_audio_buffer.context_size_batch.total(),
        )
        enc_out = enc_out.transpose(1, 2)

        encoder_context = self.batched_audio_buffer.context_size.subsample(factor=encoder_frame2audio_samples)
        encoder_context_batch = self.batched_audio_buffer.context_size_batch.subsample(factor=encoder_frame2audio_samples)
        enc_out = enc_out[:, encoder_context.left:]

        if encoder_context.chunk > 0:
            hyps, _, self.state = asr_model.decoding.decoding.decoding_computer(
                x=enc_out, out_len=encoder_context_batch.chunk, prev_batched_state=self.state
            )
            hyp = batched_hyps_to_hypotheses(hyps, batch_size=1)[0]
            if self.hyp is None:
                self.hyp = hyp
            else:
                self.hyp.merge_(hyp)
            self.hyp.text = asr_model.tokenizer.ids_to_text(self.hyp.y_sequence.tolist())
            self.fixed_transcription = self.hyp.text

        if encoder_context.right > 0:
            enc_out_right = enc_out[:, encoder_context.chunk:]
            hyps_right, _, _ = asr_model.decoding.decoding.decoding_computer(
                x=enc_out_right, out_len=encoder_context_batch.right, prev_batched_state=self.state
            )
            tmp_hyp = batched_hyps_to_hypotheses(hyps_right, batch_size=1)[0]
            self.temporary_transcription = asr_model.tokenizer.ids_to_text(tmp_hyp.y_sequence.tolist())
        else:
            self.temporary_transcription = ""

        self.first_chunk_processed = True
        if is_last:
            self.audio_frames = np.zeros([0], dtype=np.float32)

# -----------------------------
# --- WEBSOCKET SERVER ---
# -----------------------------
async def handler(ws):
    session = StreamingSessionClient()
    print("Client connecté")
    try:
        async for message in ws:
            audio = np.frombuffer(message, dtype=np.int16).astype(np.float32) / 32768.0
            session.process_audio_chunk(audio)
            await ws.send(session.transcription)
    except websockets.ConnectionClosed:
        print("Client déconnecté")


async def main():
    server = await websockets.serve(handler, "0.0.0.0", 8765)
    print("Serveur WebSocket démarré sur ws://0.0.0.0:8765")
    await server.wait_closed()


if __name__ == "__main__":
    asyncio.run(main())

But when I say "Isabelle Ray Coquard", the TDT model does not seem to use my keywords.txt to correct the transcription: I get "Isabbelle Récocard" even though I explicitly put the correct version in keywords.txt...

Is there something I am doing wrong?

abentabib avatar Dec 09 '25 18:12 abentabib

At first glance, everything should work.

Several notes:

  • nvidia/parakeet-tdt-0.6b-v3 is a case-sensitive model, so keywords.txt should contain the words in the desired case, possibly with multiple spellings (see the example file after this list)
  • a boosting_tree_alpha value of 0.5 looks very small, especially for greedy decoding. We usually observe optimal values closer to 1, or even higher, for greedy decoding (for beam search it can be a bit smaller, but beam search is currently unavailable in streaming)
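
For example, assuming the key-phrases file is plain text with one phrase per line, keywords.txt could list several spellings of the name (illustrative only):

Isabelle Ray Coquard
Isabelle Ray-Coquard
isabelle ray coquard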

My suggestion for a sanity check: try using exactly one phrase (one that you are saying, and which is recognized incorrectly without boosting) in keywords.txt, with a very high value for boosting_tree_alpha, e.g., 10 or even higher (to ensure it is force-boosted and nothing else can be recognized).
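
In the streaming code above, that sanity check only means changing two values that are already set there (a sketch with deliberately extreme settings):

decoding_cfg.greedy.boosting_tree.key_phrases_file = "keywords.txt"  # file containing only the single test phrase
decoding_cfg.greedy.boosting_tree_alpha = 10.0  # intentionally very high, just to verify boosting kicks in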

artbataev avatar Dec 09 '25 19:12 artbataev