"""
NVIDIA Parakeet TDT Transcriber

Real-time streaming ASR using NVIDIA's Parakeet TDT (Token-and-Duration Transducer) model.
Supports streaming transcription with word-level timestamps for LLM pre-computation.

Model: nvidia/parakeet-tdt-0.6b-v3
- 600M parameters
- Real-time capable on GPU
- Word-level timestamps
- Streaming support via NVIDIA NeMo
"""

import numpy as np
import torch
from nemo.collections.asr.models import EncDecRNNTBPEModel
from typing import Any, Dict, List, Optional, Union
import logging
import asyncio
from concurrent.futures import ThreadPoolExecutor

logger = logging.getLogger('parakeet')


class ParakeetTranscriber:
    """
    NVIDIA Parakeet-based streaming transcription with word-level tokens.

    Uses NVIDIA NeMo for proper model loading and inference.
    """

    def __init__(
        self,
        model_name: str = "nvidia/parakeet-tdt-0.6b-v3",
        device: str = "cuda",
        language: str = "en"
    ):
        """
        Initialize Parakeet transcriber.

        Args:
            model_name: HuggingFace model identifier
            device: Device to run on ("cuda" or "cpu")
            language: Language code (Parakeet primarily supports English)
        """
        self.model_name = model_name
        self.device = device
        self.language = language

        logger.info(f"Loading Parakeet model: {model_name} on {device}...")

        # Set PyTorch memory allocator settings for better memory management
        if device == "cuda":
            # Enable expandable segments to reduce fragmentation.
            # Note: PYTORCH_CUDA_ALLOC_CONF is read when the allocator first
            # initializes, so this must run before the first CUDA allocation.
            import os
            os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

            # Clear cache before loading the model
            torch.cuda.empty_cache()

        # Load model via NeMo from HuggingFace
        self.model = EncDecRNNTBPEModel.from_pretrained(
            model_name=model_name,
            map_location=device
        )

        self.model.eval()
        if device == "cuda":
            self.model = self.model.cuda()

            # Enable memory-efficient attention if the encoder supports it
            try:
                self.model.encoder.use_memory_efficient_attention = True
            except AttributeError:
                pass

        # Thread pool for blocking transcription calls
        self.executor = ThreadPoolExecutor(max_workers=2)

        logger.info(f"Parakeet model loaded on {device}")

    async def transcribe_async(
        self,
        audio: np.ndarray,
        sample_rate: int = 16000,
        return_timestamps: bool = False
    ) -> Union[str, Dict[str, Any]]:
        """
        Transcribe audio asynchronously (non-blocking).

        Args:
            audio: Audio data as numpy array (float32)
            sample_rate: Audio sample rate (Parakeet expects 16kHz)
            return_timestamps: Whether to return word-level timestamps

        Returns:
            Transcribed text, or a dict with "text" and "words" keys when
            return_timestamps=True
        """
        loop = asyncio.get_running_loop()

        # Run transcription in the thread pool to avoid blocking the event loop
        result = await loop.run_in_executor(
            self.executor,
            self._transcribe_blocking,
            audio,
            sample_rate,
            return_timestamps
        )

        return result

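    # Caller sketch (hypothetical names): from an async handler one might do
    #
    #     result = await transcriber.transcribe_async(buf, 16000, return_timestamps=True)
    #     for w in result["words"]:
    #         print(w["word"], w["start_time"], w["end_time"])
    #
    # The run_in_executor hand-off keeps the event loop responsive while NeMo
    # runs inference on the GPU.
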
    def _transcribe_blocking(
        self,
        audio: np.ndarray,
        sample_rate: int,
        return_timestamps: bool
    ) -> Union[str, Dict[str, Any]]:
        """
        Blocking transcription call (runs in the thread pool).
        """
        # Convert int16 PCM to float32 in [-1, 1]; pass other dtypes through as float32
        if audio.dtype == np.int16:
            audio = audio.astype(np.float32) / 32768.0
        elif audio.dtype != np.float32:
            audio = audio.astype(np.float32)

        # Ensure correct sample rate (Parakeet expects 16kHz)
        if sample_rate != 16000:
            logger.warning(f"Audio sample rate is {sample_rate}Hz, Parakeet expects 16kHz. Resampling...")
            import torchaudio
            audio_tensor = torch.from_numpy(audio).unsqueeze(0)
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            audio_tensor = resampler(audio_tensor)
            audio = audio_tensor.squeeze(0).numpy()
            sample_rate = 16000

        # Transcribe using the NeMo model
        with torch.no_grad():
            # NeMo returns a list of Hypothesis objects.
            # Note: timestamps=True causes significant VRAM usage (~1-2GB extra),
            # so only enable it for final transcriptions, not streaming partials.
            transcriptions = self.model.transcribe(
                audio=[audio],  # Pass the NumPy array directly; NeMo handles batching and device placement
                batch_size=1,
                timestamps=return_timestamps  # Only use timestamps when explicitly requested
            )

        # Extract text from the Hypothesis object
        hypothesis = transcriptions[0] if transcriptions else None
        if hypothesis is None:
            text = ""
            words = []
        else:
            # Hypothesis objects expose the decoded text via .text
            text = hypothesis.text.strip() if hasattr(hypothesis, 'text') else str(hypothesis).strip()

            # Extract word-level timestamps if available and requested
            words = []
            if return_timestamps and hasattr(hypothesis, 'timestamp') and hypothesis.timestamp:
                # timestamp is a dict whose 'word' key holds the word-level entries
                word_timestamps = hypothesis.timestamp.get('word', [])
                for word_info in word_timestamps:
                    words.append({
                        "word": word_info.get('word', ''),
                        "start_time": word_info.get('start', 0.0),
                        "end_time": word_info.get('end', 0.0)
                    })

        logger.debug(f"Transcribed: '{text}' with {len(words)} words")

        # Note: we deliberately do NOT call torch.cuda.empty_cache() here.
        # Doing so defeats PyTorch's caching allocator and causes fragmentation;
        # let PyTorch manage its own memory pool.
        if return_timestamps:
            return {
                "text": text,
                "words": words
            }
        else:
            return text

    async def transcribe_streaming(
        self,
        audio_chunks: List[np.ndarray],
        sample_rate: int = 16000,
        chunk_size_ms: int = 500
    ) -> Dict[str, Any]:
        """
        Transcribe audio chunks with streaming support.

        Args:
            audio_chunks: List of audio chunks to process
            sample_rate: Audio sample rate
            chunk_size_ms: Size of each chunk in milliseconds (currently
                informational only; chunks are concatenated before transcription)

        Returns:
            Dict with partial text and word-level results
        """
        if not audio_chunks:
            return {"text": "", "words": []}

        # Concatenate all chunks into a single buffer
        audio_data = np.concatenate(audio_chunks)

        # Transcribe with timestamps for streaming
        result = await self.transcribe_async(
            audio_data,
            sample_rate,
            return_timestamps=True
        )

        return result

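    # Streaming usage sketch (hypothetical names): accumulate chunks from an
    # audio source and re-run transcription over the growing buffer, e.g.
    #
    #     chunks = []
    #     async for chunk in mic_stream():       # hypothetical audio source
    #         chunks.append(chunk)
    #         partial = await transcriber.transcribe_streaming(chunks)
    #         on_partial(partial["text"])        # hypothetical UI callback
    #
    # Each call re-transcribes the whole buffer, so cost grows with utterance
    # length; the word timestamps let LLM pre-computation start early.
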
    def get_supported_languages(self) -> List[str]:
        """Get list of supported language codes."""
        # Parakeet TDT v3 primarily supports English
        return ["en"]

    def cleanup(self):
        """Cleanup resources."""
        self.executor.shutdown(wait=True)
        logger.info("Parakeet transcriber cleaned up")
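

# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the transcriber itself. Assumes NeMo and
# a CUDA-capable GPU are available; feeds one second of synthetic silence, so
# the expected transcript is empty. Swap `dummy` for real 16kHz float32
# speech samples to get meaningful output.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    async def _demo():
        transcriber = ParakeetTranscriber(device="cuda")
        try:
            dummy = np.zeros(16000, dtype=np.float32)  # 1s of silence @ 16kHz
            text = await transcriber.transcribe_async(dummy, sample_rate=16000)
            print(f"Transcript: {text!r}")
        finally:
            transcriber.cleanup()

    asyncio.run(_demo())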