"""
NVIDIA Parakeet TDT Transcriber

Real-time streaming ASR using NVIDIA's Parakeet TDT (Token-and-Duration
Transducer) model. Supports streaming transcription with word-level
timestamps for LLM pre-computation.

Model: nvidia/parakeet-tdt-0.6b-v3
- 600M parameters
- Real-time capable on GPU
- Word-level timestamps
- Streaming support via NVIDIA NeMo
"""

import numpy as np
import torch
from nemo.collections.asr.models import EncDecRNNTBPEModel
from typing import Any, Optional, List, Dict, Union
import logging
import asyncio
from concurrent.futures import ThreadPoolExecutor

logger = logging.getLogger('parakeet')


class ParakeetTranscriber:
    """
    NVIDIA Parakeet-based streaming transcription with word-level tokens.

    Uses NVIDIA NeMo for model loading and inference. Blocking inference
    calls are dispatched to a small thread pool so that the asyncio event
    loop is not stalled while the model runs.
    """

    def __init__(
        self,
        model_name: str = "nvidia/parakeet-tdt-0.6b-v3",
        device: str = "cuda",
        language: str = "en"
    ):
        """
        Initialize Parakeet transcriber.

        Args:
            model_name: HuggingFace model identifier
            device: Device to run on (cuda or cpu)
            language: Language code (stored on the instance; not otherwise
                used by this class)
        """
        self.model_name = model_name
        self.device = device
        self.language = language

        logger.info(f"Loading Parakeet model: {model_name} on {device}...")

        # Load model via NeMo from HuggingFace
        self.model = EncDecRNNTBPEModel.from_pretrained(
            model_name=model_name,
            map_location=device
        )
        self.model.eval()

        if device == "cuda":
            self.model = self.model.cuda()

        # Thread pool for blocking transcription calls
        self.executor = ThreadPoolExecutor(max_workers=2)

        logger.info(f"Parakeet model loaded on {device}")

    async def transcribe_async(
        self,
        audio: np.ndarray,
        sample_rate: int = 16000,
        return_timestamps: bool = False
    ) -> Union[str, Dict[str, Any]]:
        """
        Transcribe audio asynchronously (non-blocking).

        Args:
            audio: Audio data as numpy array (float, or integer PCM)
            sample_rate: Audio sample rate (Parakeet expects 16kHz)
            return_timestamps: Whether to return word-level timestamps

        Returns:
            Transcribed text, or a dict with "text" and "words" keys if
            return_timestamps=True
        """
        # We are inside a coroutine, so a loop is guaranteed to be running;
        # get_running_loop() is the non-deprecated way to obtain it.
        loop = asyncio.get_running_loop()

        # Run transcription in thread pool to avoid blocking the event loop
        return await loop.run_in_executor(
            self.executor,
            self._transcribe_blocking,
            audio,
            sample_rate,
            return_timestamps
        )

    @staticmethod
    def _prepare_audio(audio: np.ndarray, sample_rate: int) -> np.ndarray:
        """Return `audio` as float32 samples at 16kHz."""
        audio = np.asarray(audio)

        if np.issubdtype(audio.dtype, np.integer):
            # Integer PCM is treated as 16-bit and normalized to [-1, 1).
            audio = audio.astype(np.float32) / 32768.0
        elif audio.dtype != np.float32:
            # Other dtypes (e.g. float64) are already in their final range;
            # only the storage type changes. Dividing here would silence
            # the signal.
            audio = audio.astype(np.float32)

        # Ensure correct sample rate (Parakeet expects 16kHz)
        if sample_rate != 16000 and audio.size:
            logger.warning(f"Audio sample rate is {sample_rate}Hz, Parakeet expects 16kHz. Resampling...")
            import torchaudio
            audio_tensor = torch.from_numpy(audio).unsqueeze(0)
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            audio = resampler(audio_tensor).squeeze(0).numpy()

        return audio

    @staticmethod
    def _parse_hypothesis(hypothesis) -> tuple:
        """Extract `(text, words)` from a NeMo transcription result."""
        if hypothesis is None:
            return "", []

        # Hypothesis objects carry .text; fall back to str() for plain results
        if hasattr(hypothesis, 'text'):
            text = hypothesis.text.strip()
        else:
            text = str(hypothesis).strip()

        # timestamp is expected to be a dict whose 'word' key holds a list of
        # per-word dicts; anything else is treated as "no timestamps".
        timestamp = getattr(hypothesis, 'timestamp', None)
        if not isinstance(timestamp, dict):
            return text, []

        words = [
            {
                "word": word_info.get('word', ''),
                "start_time": word_info.get('start', 0.0),
                "end_time": word_info.get('end', 0.0)
            }
            for word_info in timestamp.get('word', [])
        ]
        return text, words

    def _transcribe_blocking(
        self,
        audio: np.ndarray,
        sample_rate: int,
        return_timestamps: bool
    ) -> Union[str, Dict[str, Any]]:
        """
        Blocking transcription call (runs in thread pool).

        Args:
            audio: Audio samples; integer dtypes are scaled as 16-bit PCM,
                other dtypes are cast to float32 unchanged
            sample_rate: Sample rate of `audio`; resampled to 16kHz if needed
            return_timestamps: Whether to return word-level timestamps

        Returns:
            Transcribed text, or {"text": ..., "words": [...]} if
            return_timestamps=True. Empty audio yields an empty result
            without invoking the model.
        """
        audio = self._prepare_audio(audio, sample_rate)

        if audio.size == 0:
            # Nothing to transcribe; skip the model call entirely.
            text, words = "", []
        else:
            with torch.no_grad():
                # Timestamps are always requested so that word-level data is
                # available; the model is handed the numpy samples directly.
                transcriptions = self.model.transcribe(
                    audio=[audio],
                    batch_size=1,
                    timestamps=True
                )
            hypothesis = transcriptions[0] if transcriptions else None
            text, words = self._parse_hypothesis(hypothesis)

        logger.debug(f"Transcribed: '{text}' with {len(words)} words")

        if return_timestamps:
            return {
                "text": text,
                "words": words
            }
        return text

    async def transcribe_streaming(
        self,
        audio_chunks: List[np.ndarray],
        sample_rate: int = 16000,
        chunk_size_ms: int = 500
    ) -> Dict[str, Any]:
        """
        Transcribe audio chunks with streaming support.

        The chunks are concatenated and transcribed in a single pass.

        Args:
            audio_chunks: List of audio chunks to process
            sample_rate: Audio sample rate
            chunk_size_ms: Size of each chunk in milliseconds (accepted for
                interface compatibility; not used by this implementation)

        Returns:
            Dict with "text" and word-level "words" results
        """
        if not audio_chunks:
            return {"text": "", "words": []}

        # Concatenate all chunks
        audio_data = np.concatenate(audio_chunks)

        # Transcribe with timestamps for streaming
        return await self.transcribe_async(
            audio_data,
            sample_rate,
            return_timestamps=True
        )

    def get_supported_languages(self) -> List[str]:
        """Get list of supported language codes."""
        # Parakeet TDT v3 primarily supports English
        return ["en"]

    def cleanup(self):
        """Cleanup resources by shutting down the transcription thread pool."""
        self.executor.shutdown(wait=True)
        logger.info("Parakeet transcriber cleaned up")