Changed stt to parakeet — still experimental, though performance seems to be better
This commit is contained in:
209
stt/parakeet_transcriber.py
Normal file
209
stt/parakeet_transcriber.py
Normal file
@@ -0,0 +1,209 @@
|
||||
"""
|
||||
NVIDIA Parakeet TDT Transcriber
|
||||
|
||||
Real-time streaming ASR using NVIDIA's Parakeet TDT (Token-and-Duration Transducer) model.
|
||||
Supports streaming transcription with word-level timestamps for LLM pre-computation.
|
||||
|
||||
Model: nvidia/parakeet-tdt-0.6b-v3
|
||||
- 600M parameters
|
||||
- Real-time capable on GPU
|
||||
- Word-level timestamps
|
||||
- Streaming support via NVIDIA NeMo
|
||||
"""
|
||||
|
||||
import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional

import numpy as np
import torch
from nemo.collections.asr.models import EncDecRNNTBPEModel
|
||||
|
||||
logger = logging.getLogger('parakeet')


class ParakeetTranscriber:
    """
    NVIDIA Parakeet-based streaming transcription with word-level tokens.

    Uses NVIDIA NeMo for model loading and inference. Blocking NeMo calls
    are dispatched to a small thread pool so the asyncio event loop stays
    responsive.
    """

    def __init__(
        self,
        model_name: str = "nvidia/parakeet-tdt-0.6b-v3",
        device: str = "cuda",
        language: str = "en"
    ):
        """
        Initialize Parakeet transcriber.

        Args:
            model_name: HuggingFace model identifier
            device: Device to run on (cuda or cpu)
            language: Language code (Parakeet primarily supports English)
        """
        self.model_name = model_name
        self.device = device
        self.language = language

        logger.info(f"Loading Parakeet model: {model_name} on {device}...")

        # Load model via NeMo from HuggingFace; map_location places the
        # weights on the requested device at load time.
        self.model = EncDecRNNTBPEModel.from_pretrained(
            model_name=model_name,
            map_location=device
        )

        self.model.eval()
        if device == "cuda":
            self.model = self.model.cuda()

        # Thread pool for blocking transcription calls, so transcribe_async
        # never stalls the event loop.
        self.executor = ThreadPoolExecutor(max_workers=2)

        logger.info(f"Parakeet model loaded on {device}")

    async def transcribe_async(
        self,
        audio: np.ndarray,
        sample_rate: int = 16000,
        return_timestamps: bool = False
    ) -> "str | Dict[str, Any]":
        """
        Transcribe audio asynchronously (non-blocking).

        Args:
            audio: Audio data as numpy array (float32; int16 PCM is
                normalized automatically)
            sample_rate: Audio sample rate (Parakeet expects 16kHz)
            return_timestamps: Whether to return word-level timestamps

        Returns:
            Transcribed text, or a dict with "text" and "words" keys when
            return_timestamps is True.
        """
        # get_running_loop() is the supported API inside a coroutine;
        # get_event_loop() is deprecated for this use since Python 3.10.
        loop = asyncio.get_running_loop()

        # Run the blocking NeMo call in the thread pool to avoid blocking.
        return await loop.run_in_executor(
            self.executor,
            self._transcribe_blocking,
            audio,
            sample_rate,
            return_timestamps
        )

    def _transcribe_blocking(
        self,
        audio: np.ndarray,
        sample_rate: int,
        return_timestamps: bool
    ):
        """
        Blocking transcription call (runs in thread pool).

        Returns the same shape of result as transcribe_async.
        """
        # Normalize int16 PCM to float32 in [-1, 1).
        if audio.dtype != np.float32:
            audio = audio.astype(np.float32) / 32768.0

        # Ensure correct sample rate (Parakeet expects 16kHz).
        if sample_rate != 16000:
            logger.warning(f"Audio sample rate is {sample_rate}Hz, Parakeet expects 16kHz. Resampling...")
            import torchaudio
            audio_tensor = torch.from_numpy(audio).unsqueeze(0)
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            audio = resampler(audio_tensor).squeeze(0).numpy()
            sample_rate = 16000

        # Transcribe using the NeMo model. transcribe() is fed the numpy
        # array directly — the previous tensor round-trip (numpy -> cuda
        # tensor -> .cpu().numpy()) was pure overhead.
        with torch.no_grad():
            # NeMo returns a list of Hypothesis objects when timestamps=True
            transcriptions = self.model.transcribe(
                audio=[audio],
                batch_size=1,
                timestamps=True  # Enable timestamps to get word-level data
            )

        # Extract text from the Hypothesis object.
        hypothesis = transcriptions[0] if transcriptions else None
        if hypothesis is None:
            text = ""
            words = []
        else:
            # Hypothesis objects expose .text; fall back to str() defensively.
            text = hypothesis.text.strip() if hasattr(hypothesis, 'text') else str(hypothesis).strip()

            # Extract word-level timestamps if available.
            words = []
            if hasattr(hypothesis, 'timestamp') and hypothesis.timestamp:
                # timestamp is a dict whose 'word' key holds word timings
                # (NOTE(review): dict layout per observed NeMo output — verify
                # against the pinned NeMo version).
                for word_info in hypothesis.timestamp.get('word', []):
                    words.append({
                        "word": word_info.get('word', ''),
                        "start_time": word_info.get('start', 0.0),
                        "end_time": word_info.get('end', 0.0)
                    })

        logger.debug(f"Transcribed: '{text}' with {len(words)} words")

        if return_timestamps:
            return {
                "text": text,
                "words": words
            }
        return text

    async def transcribe_streaming(
        self,
        audio_chunks: List[np.ndarray],
        sample_rate: int = 16000,
        chunk_size_ms: int = 500
    ) -> "Dict[str, Any]":
        """
        Transcribe audio chunks with streaming support.

        Args:
            audio_chunks: List of audio chunks to process
            sample_rate: Audio sample rate
            chunk_size_ms: Size of each chunk in milliseconds (currently
                unused; chunks are concatenated and decoded in one pass)

        Returns:
            Dict with "text" and word-level "words" results
        """
        # Empty input: nothing to decode.
        if not audio_chunks:
            return {"text": "", "words": []}

        # Concatenate all chunks and decode with timestamps.
        audio_data = np.concatenate(audio_chunks)

        return await self.transcribe_async(
            audio_data,
            sample_rate,
            return_timestamps=True
        )

    def get_supported_languages(self) -> List[str]:
        """Get list of supported language codes."""
        # Parakeet TDT v3 primarily supports English
        return ["en"]

    def cleanup(self):
        """Cleanup resources (shuts down the worker thread pool)."""
        self.executor.shutdown(wait=True)
        logger.info("Parakeet transcriber cleaned up")
|
||||
Reference in New Issue
Block a user