ASR Context Biasing for EncDecHybridRNNTCTCModel (parakeet tdt 0.6b v3)
Hi there, I was looking for a way to perform context biasing, based on the example notebook, with the Parakeet TDT 0.6b v3 model.
I tried it with the EncDecHybridRNNTCTCModel but got a no attribute "ctc_decoder" error.
This discussion mentioned that this type of model does not have the ctc_decoder.
Any pointers on whether context biasing (or a different word-boosting technique that does not require fine-tuning) is possible?
Thanks a lot.
@sandorkonya With Parakeet-TDT-0.6b-v3 you can use the new phrase boosting by setting:
python examples/asr/transcribe_speech.py \
    <model params> \
    rnnt_decoding.strategy="greedy_batch" \
    rnnt_decoding.greedy.boosting_tree.key_phrases_file=${KEY_WORDS_LIST} \
    rnnt_decoding.greedy.boosting_tree.context_score=1.0 \
    rnnt_decoding.greedy.boosting_tree.depth_scaling=2.0 \
    rnnt_decoding.greedy.boosting_tree_alpha=${BT_ALPHA}
See details in https://github.com/NVIDIA-NeMo/NeMo/pull/14277
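For offline (non-streaming) use, the same options can also be applied in Python via change_decoding_strategy, which can be easier to experiment with. A minimal sketch, assuming keywords.txt is a plain-text file with one phrase per line and sample.wav is a local test recording (both names are placeholders):

import torch
from omegaconf import open_dict
from nemo.collections.asr.models import ASRModel

device = "cuda" if torch.cuda.is_available() else "cpu"
asr_model = ASRModel.from_pretrained("nvidia/parakeet-tdt-0.6b-v3").to(device).eval()

# Enable the boosting tree in the greedy RNNT/TDT decoding config
decoding_cfg = asr_model.cfg.decoding
with open_dict(decoding_cfg):
    decoding_cfg.strategy = "greedy_batch"
    if "boosting_tree" not in decoding_cfg.greedy:
        decoding_cfg.greedy.boosting_tree = {}
    decoding_cfg.greedy.boosting_tree.key_phrases_file = "keywords.txt"
    decoding_cfg.greedy.boosting_tree.context_score = 1.0
    decoding_cfg.greedy.boosting_tree.depth_scaling = 2.0
    decoding_cfg.greedy.boosting_tree_alpha = 1.0

# Hybrid RNNT+CTC models need decoder_type="rnnt"; pure RNNT/TDT models do not
if hasattr(asr_model, "cur_decoder"):
    asr_model.change_decoding_strategy(decoding_cfg, decoder_type="rnnt")
else:
    asr_model.change_decoding_strategy(decoding_cfg)

# Boosting is applied inside greedy decoding during transcription
print(asr_model.transcribe(["sample.wav"]))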
@artbataev does this technique work on a CPU device?
@abentabib Yes, it works on both CPUs and GPUs. A pure PyTorch implementation is used when Triton/CUDA are unavailable (on GPU we can use a more efficient Triton kernel).
I am trying to create a WebSocket server that uses Parakeet TDT 0.6b v3 (using the example from https://github.com/NVIDIA-NeMo/NeMo/pull/14759/commits/a42415de9cda7a8882f73d0f5387d8e5c4822a11):
import asyncio
import websockets
import numpy as np
import torch
from omegaconf import open_dict
from nemo.collections.asr.models import ASRModel, EncDecRNNTModel
from nemo.collections.asr.parts.utils.streaming_utils import StreamingBatchedAudioBuffer, ContextSize
from nemo.collections.asr.parts.utils.rnnt_utils import batched_hyps_to_hypotheses
# -----------------------------
# --- MODEL & DEVICE CONFIG ---
# -----------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
compute_dtype = torch.float32
# Load and configure the model only once
asr_model_name = "nvidia/parakeet-tdt-0.6b-v3"
asr_model: ASRModel = ASRModel.from_pretrained(asr_model_name).to(device).eval()
asr_model = asr_model.to(compute_dtype)
# Streaming decoding config
decoding_cfg = asr_model.cfg.decoding
with open_dict(decoding_cfg):
    decoding_cfg.strategy = "greedy_batch"
    decoding_cfg.greedy.loop_labels = True
    decoding_cfg.greedy.preserve_alignments = False
    decoding_cfg.fused_batch_size = -1
    decoding_cfg.beam.return_best_hypothesis = True
    # --- Hot words / boosting tree
    # Point this to your text file with the hot words
    if "boosting_tree" not in decoding_cfg.greedy:
        decoding_cfg.greedy.boosting_tree = {}
    decoding_cfg.greedy.boosting_tree.key_phrases_file = "keywords.txt"
    decoding_cfg.greedy.boosting_tree.context_score = 1.0
    decoding_cfg.greedy.boosting_tree.depth_scaling = 2.0
    decoding_cfg.greedy.boosting_tree_alpha = 0.5
# Apply the config
if hasattr(asr_model, "cur_decoder"):
    asr_model.change_decoding_strategy(decoding_cfg, decoder_type="rnnt")
elif isinstance(asr_model, EncDecRNNTModel):
    asr_model.change_decoding_strategy(decoding_cfg)
else:
    raise ValueError(f"Unsupported model type: {type(asr_model)}")
# Streaming preprocessing
asr_model.preprocessor.featurizer.dither = 0.0
asr_model.preprocessor.featurizer.pad_to = 0
# -----------------------------
# --- STREAMING PARAMS ---
# -----------------------------
chunk_secs = 2
left_context_secs = 10
right_context_secs = 2
sample_rate = asr_model.cfg.preprocessor.sample_rate
feature_stride_sec = asr_model.cfg.preprocessor.window_stride
features_per_sec = 1.0 / feature_stride_sec
encoder_subsampling = asr_model.encoder.subsampling_factor
frame2samples = int(sample_rate * feature_stride_sec)
frame2samples = (frame2samples // encoder_subsampling) * encoder_subsampling
encoder_frame2audio_samples = frame2samples * encoder_subsampling
context_encoder_frames = ContextSize(
    left=int(left_context_secs * features_per_sec / encoder_subsampling),
    chunk=int(chunk_secs * features_per_sec / encoder_subsampling),
    right=int(right_context_secs * features_per_sec / encoder_subsampling),
)
context_samples = ContextSize(
    left=context_encoder_frames.left * encoder_subsampling * frame2samples,
    chunk=context_encoder_frames.chunk * encoder_subsampling * frame2samples,
    right=context_encoder_frames.right * encoder_subsampling * frame2samples,
)
# -----------------------------
# --- CLIENT SESSION ---
# -----------------------------
class StreamingSessionClient:
    def __init__(self):
        self.audio_frames = np.zeros([0], dtype=np.float32)
        self.batched_audio_buffer = StreamingBatchedAudioBuffer(
            batch_size=1,
            context_samples=context_samples,
            dtype=torch.float32,
            device=device,
        )
        self.first_chunk_processed = False
        self.state = None
        self.hyp = None
        self.fixed_transcription = ""
        self.temporary_transcription = ""

    @property
    def transcription(self):
        if self.temporary_transcription:
            return f"{self.fixed_transcription} [{self.temporary_transcription}]"
        return self.fixed_transcription

    @torch.inference_mode()
    def process_audio_chunk(self, audio_chunk: np.ndarray, is_last=False):
        self.audio_frames = np.concatenate((self.audio_frames, audio_chunk))
        first_chunk_samples = context_samples.chunk + context_samples.right
        need_samples = context_samples.chunk if self.first_chunk_processed else first_chunk_samples
        while (self.audio_frames.shape[0] >= need_samples) or (is_last and self.audio_frames.shape[0] > 0):
            cur_chunk = self.audio_frames[:need_samples]
            self._process_next_chunk(cur_chunk, is_last=is_last and self.audio_frames.shape[0] <= need_samples)
            self.audio_frames = self.audio_frames[need_samples:].copy()
            need_samples = context_samples.chunk
        if not self.first_chunk_processed and self.audio_frames.shape[0] > 0:
            self._process_first_temporary_chunk(self.audio_frames)

    @torch.inference_mode()
    def _process_first_temporary_chunk(self, audio_chunk: np.ndarray):
        audio_tensor = torch.from_numpy(audio_chunk).float().unsqueeze(0).to(device)
        length = torch.tensor([len(audio_chunk)], device=device)
        enc_out, enc_len = asr_model(input_signal=audio_tensor, input_signal_length=length)
        enc_out = enc_out.transpose(1, 2)
        hyps, _, _ = asr_model.decoding.decoding.decoding_computer(
            x=enc_out, out_len=enc_len, prev_batched_state=None
        )
        hyp = batched_hyps_to_hypotheses(hyps, batch_size=1)[0]
        self.temporary_transcription = asr_model.tokenizer.ids_to_text(hyp.y_sequence.tolist())

    @torch.inference_mode()
    def _process_next_chunk(self, audio_chunk: np.ndarray, is_last: bool):
        audio_tensor = torch.from_numpy(audio_chunk).float().unsqueeze(0).to(device)
        length = torch.tensor([len(audio_chunk)], device=device)
        self.batched_audio_buffer.add_audio_batch_(
            audio_tensor,
            audio_lengths=length,
            is_last_chunk=is_last,
            is_last_chunk_batch=torch.tensor([is_last], device=device),
        )
        enc_out, enc_len = asr_model(
            input_signal=self.batched_audio_buffer.samples,
            input_signal_length=self.batched_audio_buffer.context_size_batch.total(),
        )
        enc_out = enc_out.transpose(1, 2)
        encoder_context = self.batched_audio_buffer.context_size.subsample(factor=encoder_frame2audio_samples)
        encoder_context_batch = self.batched_audio_buffer.context_size_batch.subsample(factor=encoder_frame2audio_samples)
        enc_out = enc_out[:, encoder_context.left:]
        if encoder_context.chunk > 0:
            hyps, _, self.state = asr_model.decoding.decoding.decoding_computer(
                x=enc_out, out_len=encoder_context_batch.chunk, prev_batched_state=self.state
            )
            hyp = batched_hyps_to_hypotheses(hyps, batch_size=1)[0]
            if self.hyp is None:
                self.hyp = hyp
            else:
                self.hyp.merge_(hyp)
            self.hyp.text = asr_model.tokenizer.ids_to_text(self.hyp.y_sequence.tolist())
            self.fixed_transcription = self.hyp.text
        if encoder_context.right > 0:
            enc_out_right = enc_out[:, encoder_context.chunk:]
            hyps_right, _, _ = asr_model.decoding.decoding.decoding_computer(
                x=enc_out_right, out_len=encoder_context_batch.right, prev_batched_state=self.state
            )
            tmp_hyp = batched_hyps_to_hypotheses(hyps_right, batch_size=1)[0]
            self.temporary_transcription = asr_model.tokenizer.ids_to_text(tmp_hyp.y_sequence.tolist())
        else:
            self.temporary_transcription = ""
        self.first_chunk_processed = True
        if is_last:
            self.audio_frames = np.zeros([0], dtype=np.float32)
# -----------------------------
# --- WEBSOCKET SERVER ---
# -----------------------------
async def handler(ws):
    session = StreamingSessionClient()
    print("Client connected")
    try:
        async for message in ws:
            audio = np.frombuffer(message, dtype=np.int16).astype(np.float32) / 32768.0
            session.process_audio_chunk(audio)
            await ws.send(session.transcription)
    except websockets.ConnectionClosed:
        print("Client disconnected")

async def main():
    server = await websockets.serve(handler, "0.0.0.0", 8765)
    print("WebSocket server started at ws://0.0.0.0:8765")
    await server.wait_closed()

if __name__ == "__main__":
    asyncio.run(main())
But when I say "Isabelle Ray Coquard", the TDT model does not seem to use my keywords.txt to correct the transcription: I get "Isabbelle Récocard" even though I explicitly put the correct version in keywords.txt.
Is there something I am doing wrong?
At first glance, everything should work.
Several notes:
- nvidia/parakeet-tdt-0.6b-v3 is a case-sensitive model, so keywords.txt should contain the words in the desired case (maybe with multiple spellings; see the example file below)
- the boosting_tree_alpha = 0.5 value looks very small, especially for greedy decoding. We usually observe optimal values closer to 1 and even higher for greedy decoding (for beam search it can be a bit smaller, but beam search is currently unavailable in streaming)
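For the first point: since the model is case-sensitive, keywords.txt could list several case/spelling variants of the phrase. A small illustration, assuming the key-phrases file is plain text with one phrase per line (the exact variants are just an example):

Isabelle Ray Coquard
Isabelle Ray-Coquard
isabelle ray coquard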
My suggestion for a sanity check: try using exactly one phrase (one that you are saying and that is recognized incorrectly without boosting) in keywords.txt, with a very high value for boosting_tree_alpha, e.g., 10 or even higher (to ensure it is force-boosted and nothing else can be recognized).
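In the streaming script above, that sanity check only requires changing the boosting options before change_decoding_strategy is applied. A minimal sketch that reuses decoding_cfg and asr_model from the script (the 10.0 value is deliberately extreme and for debugging only):

with open_dict(decoding_cfg):
    # keywords.txt now contains exactly one test phrase, e.g. "Isabelle Ray Coquard"
    decoding_cfg.greedy.boosting_tree.key_phrases_file = "keywords.txt"
    decoding_cfg.greedy.boosting_tree_alpha = 10.0  # force-boost for debugging; lower it afterwards

# parakeet-tdt-0.6b-v3 is a pure TDT model, so no decoder_type argument is needed
asr_model.change_decoding_strategy(decoding_cfg)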