""" ASR Pipeline using onnx-asr library with Parakeet TDT 0.6B V3 model """ import numpy as np import onnx_asr from typing import Union, Optional import logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) class ASRPipeline: """ ASR Pipeline wrapper for onnx-asr Parakeet TDT model. Supports GPU acceleration via ONNX Runtime with CUDA/TensorRT. """ def __init__( self, model_name: str = "nemo-parakeet-tdt-0.6b-v3", model_path: Optional[str] = None, quantization: Optional[str] = None, providers: Optional[list] = None, use_vad: bool = False, ): """ Initialize ASR Pipeline. Args: model_name: Name of the model to load (default: "nemo-parakeet-tdt-0.6b-v3") model_path: Optional local path to model files (default uses models/parakeet) quantization: Optional quantization ("int8", "fp16", etc.) providers: Optional ONNX runtime providers list for GPU acceleration use_vad: Whether to use Voice Activity Detection """ self.model_name = model_name self.model_path = model_path or "models/parakeet" self.quantization = quantization self.use_vad = use_vad # Configure providers for GPU acceleration if providers is None: # Default: try CUDA, then CPU providers = [ ( "CUDAExecutionProvider", { "device_id": 0, "arena_extend_strategy": "kNextPowerOfTwo", "gpu_mem_limit": 6 * 1024 * 1024 * 1024, # 6GB "cudnn_conv_algo_search": "EXHAUSTIVE", "do_copy_in_default_stream": True, } ), "CPUExecutionProvider", ] self.providers = providers logger.info(f"Initializing ASR Pipeline with model: {model_name}") logger.info(f"Model path: {self.model_path}") logger.info(f"Quantization: {quantization}") logger.info(f"Providers: {providers}") # Load the model try: self.model = onnx_asr.load_model( model_name, self.model_path, quantization=quantization, providers=providers, ) logger.info("Model loaded successfully") # Optionally add VAD if use_vad: logger.info("Loading VAD model...") vad = onnx_asr.load_vad("silero", providers=providers) self.model = self.model.with_vad(vad) logger.info("VAD enabled") except Exception as e: logger.error(f"Failed to load model: {e}") raise def transcribe( self, audio: Union[str, np.ndarray], sample_rate: int = 16000, ) -> Union[str, list]: """ Transcribe audio to text. Args: audio: Audio data as numpy array (float32) or path to WAV file sample_rate: Sample rate of audio (default: 16000 Hz) Returns: Transcribed text string, or list of results if VAD is enabled """ try: if isinstance(audio, str): # Load from file result = self.model.recognize(audio) else: # Process numpy array if audio.dtype != np.float32: audio = audio.astype(np.float32) result = self.model.recognize(audio, sample_rate=sample_rate) # If VAD is enabled, result is a generator if self.use_vad: return list(result) return result except Exception as e: logger.error(f"Transcription failed: {e}") raise def transcribe_batch( self, audio_files: list, ) -> list: """ Transcribe multiple audio files in batch. Args: audio_files: List of paths to WAV files Returns: List of transcribed text strings """ try: results = self.model.recognize(audio_files) return results except Exception as e: logger.error(f"Batch transcription failed: {e}") raise def transcribe_stream( self, audio_chunk: np.ndarray, sample_rate: int = 16000, ) -> str: """ Transcribe streaming audio chunk. 
Args: audio_chunk: Audio chunk as numpy array (float32) sample_rate: Sample rate of audio Returns: Transcribed text for the chunk """ return self.transcribe(audio_chunk, sample_rate=sample_rate) # Convenience function for backward compatibility def load_pipeline(**kwargs) -> ASRPipeline: """Load and return ASR pipeline with given configuration.""" return ASRPipeline(**kwargs)
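

# Minimal usage sketch, not part of the pipeline API. It assumes the Parakeet
# ONNX files can be resolved from the default "models/parakeet" path above and
# that onnxruntime (with or without CUDA) is installed. "samples/test.wav" is a
# hypothetical placeholder path, so the file-based call is left commented out.
if __name__ == "__main__":
    pipeline = load_pipeline(use_vad=False)

    # Transcribe a WAV file from disk (placeholder path):
    # print(pipeline.transcribe("samples/test.wav"))

    # Transcribe a raw waveform: one second of 16 kHz float32 silence.
    silence = np.zeros(16000, dtype=np.float32)
    print(pipeline.transcribe(silence, sample_rate=16000))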