"""ASR Pipeline using onnx-asr library with Parakeet TDT 0.6B V3 model."""

import logging
from typing import Optional, Union

import numpy as np
import onnx_asr

# NOTE(review): basicConfig at import time mutates global logging state;
# kept as-is for backward compatibility with existing callers.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ASRPipeline:
    """
    ASR Pipeline wrapper for onnx-asr Parakeet TDT model.

    Supports GPU acceleration via ONNX Runtime with CUDA/TensorRT.
    """

    def __init__(
        self,
        model_name: str = "nemo-parakeet-tdt-0.6b-v3",
        model_path: Optional[str] = None,
        quantization: Optional[str] = None,
        providers: Optional[list] = None,
        use_vad: bool = False,
    ):
        """
        Initialize ASR Pipeline.

        Args:
            model_name: Name of the model to load (default: "nemo-parakeet-tdt-0.6b-v3")
            model_path: Optional local path to model files (default uses models/parakeet)
            quantization: Optional quantization ("int8", "fp16", etc.)
            providers: Optional ONNX runtime providers list for GPU acceleration
            use_vad: Whether to use Voice Activity Detection

        Raises:
            Exception: re-raised from onnx_asr if model (or VAD) loading fails.
        """
        self.model_name = model_name
        self.model_path = model_path or "models/parakeet"
        self.quantization = quantization
        self.use_vad = use_vad

        # Configure providers for GPU acceleration.
        if providers is None:
            # Default: try CUDA first, fall back to CPU.
            providers = [
                (
                    "CUDAExecutionProvider",
                    {
                        "device_id": 0,
                        "arena_extend_strategy": "kNextPowerOfTwo",
                        "gpu_mem_limit": 6 * 1024 * 1024 * 1024,  # 6GB
                        "cudnn_conv_algo_search": "EXHAUSTIVE",
                        "do_copy_in_default_stream": True,
                    },
                ),
                "CPUExecutionProvider",
            ]

        self.providers = providers
        # Lazy %-style args avoid formatting cost when INFO is disabled.
        logger.info("Initializing ASR Pipeline with model: %s", model_name)
        logger.info("Model path: %s", self.model_path)
        logger.info("Quantization: %s", quantization)
        logger.info("Providers: %s", providers)

        # Load the model; failures are logged with traceback and re-raised.
        try:
            self.model = onnx_asr.load_model(
                model_name,
                self.model_path,
                quantization=quantization,
                providers=providers,
            )
            logger.info("Model loaded successfully")

            # Optionally wrap the recognizer with Voice Activity Detection.
            if use_vad:
                logger.info("Loading VAD model...")
                vad = onnx_asr.load_vad("silero", providers=providers)
                self.model = self.model.with_vad(vad)
                logger.info("VAD enabled")

        except Exception:
            # logger.exception records the stack trace along with the message.
            logger.exception("Failed to load model")
            raise

    def transcribe(
        self,
        audio: Union[str, np.ndarray],
        sample_rate: int = 16000,
    ) -> Union[str, list]:
        """
        Transcribe audio to text.

        Args:
            audio: Audio data as numpy array (float32) or path to WAV file
            sample_rate: Sample rate of audio (default: 16000 Hz)

        Returns:
            Transcribed text string, or list of results if VAD is enabled

        Raises:
            Exception: re-raised from the underlying recognizer on failure.
        """
        try:
            if isinstance(audio, str):
                # Path input: the library loads the file itself.
                result = self.model.recognize(audio)
            else:
                # Array input: the recognizer expects float32 PCM samples.
                if audio.dtype != np.float32:
                    audio = audio.astype(np.float32)
                result = self.model.recognize(audio, sample_rate=sample_rate)

            # With VAD enabled the result is a generator of per-segment
            # results; materialize it so callers get a plain list.
            if self.use_vad:
                return list(result)

            return result

        except Exception:
            logger.exception("Transcription failed")
            raise

    def transcribe_batch(
        self,
        audio_files: list,
    ) -> list:
        """
        Transcribe multiple audio files in batch.

        Args:
            audio_files: List of paths to WAV files

        Returns:
            List of transcribed text strings

        Raises:
            Exception: re-raised from the underlying recognizer on failure.
        """
        try:
            # onnx-asr accepts a list of file paths for batched recognition.
            return self.model.recognize(audio_files)
        except Exception:
            logger.exception("Batch transcription failed")
            raise

    def transcribe_stream(
        self,
        audio_chunk: np.ndarray,
        sample_rate: int = 16000,
    ) -> str:
        """
        Transcribe streaming audio chunk.

        Args:
            audio_chunk: Audio chunk as numpy array (float32)
            sample_rate: Sample rate of audio

        Returns:
            Transcribed text for the chunk
        """
        # NOTE(review): each chunk is recognized independently — no
        # cross-chunk state is carried, so words split across chunk
        # boundaries may be lost; confirm callers chunk at silence.
        return self.transcribe(audio_chunk, sample_rate=sample_rate)
# Convenience function for backward compatibility
def load_pipeline(**kwargs) -> ASRPipeline:
    """Load and return ASR pipeline with given configuration."""
    pipeline = ASRPipeline(**kwargs)
    return pipeline