# miku-discord/stt/parakeet_transcriber.py
"""
NVIDIA Parakeet TDT Transcriber
Real-time streaming ASR using NVIDIA's Parakeet TDT (Token-and-Duration Transducer) model.
Supports streaming transcription with word-level timestamps for LLM pre-computation.
Model: nvidia/parakeet-tdt-0.6b-v3
- 600M parameters
- Real-time capable on GPU
- Word-level timestamps
- Streaming support via NVIDIA NeMo
"""
import numpy as np
import torch
from nemo.collections.asr.models import EncDecRNNTBPEModel
from typing import Any, Dict, List, Union
import logging
import asyncio
from concurrent.futures import ThreadPoolExecutor

logger = logging.getLogger('parakeet')


class ParakeetTranscriber:
    """
    NVIDIA Parakeet-based streaming transcription with word-level tokens.
    Uses NVIDIA NeMo for proper model loading and inference.
    """

    def __init__(
        self,
        model_name: str = "nvidia/parakeet-tdt-0.6b-v3",
        device: str = "cuda",
        language: str = "en"
    ):
        """
        Initialize Parakeet transcriber.

        Args:
            model_name: HuggingFace model identifier
            device: Device to run on (cuda or cpu)
            language: Language code (Parakeet primarily supports English)
        """
        self.model_name = model_name
        self.device = device
        self.language = language

        logger.info(f"Loading Parakeet model: {model_name} on {device}...")

        # Set PyTorch memory allocator settings for better memory management
        if device == "cuda":
            # Enable expandable segments to reduce fragmentation
            import os
            os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
            # Clear cache before loading model
            torch.cuda.empty_cache()

        # Load model via NeMo from HuggingFace
        self.model = EncDecRNNTBPEModel.from_pretrained(
            model_name=model_name,
            map_location=device
        )
        self.model.eval()

        if device == "cuda":
            self.model = self.model.cuda()
            # Enable memory efficient attention if available
            try:
                self.model.encoder.use_memory_efficient_attention = True
            except Exception:
                pass

        # Thread pool for blocking transcription calls
        self.executor = ThreadPoolExecutor(max_workers=2)

        logger.info(f"Parakeet model loaded on {device}")

    async def transcribe_async(
        self,
        audio: np.ndarray,
        sample_rate: int = 16000,
        return_timestamps: bool = False
    ) -> Union[str, Dict[str, Any]]:
        """
        Transcribe audio asynchronously (non-blocking).

        Args:
            audio: Audio data as numpy array (float32)
            sample_rate: Audio sample rate (Parakeet expects 16kHz)
            return_timestamps: Whether to return word-level timestamps

        Returns:
            Transcribed text, or a dict with text and word timestamps if return_timestamps=True
        """
        loop = asyncio.get_running_loop()

        # Run transcription in thread pool to avoid blocking the event loop
        result = await loop.run_in_executor(
            self.executor,
            self._transcribe_blocking,
            audio,
            sample_rate,
            return_timestamps
        )
        return result

    def _transcribe_blocking(
        self,
        audio: np.ndarray,
        sample_rate: int,
        return_timestamps: bool
    ):
        """
        Blocking transcription call (runs in thread pool).
        """
        # Convert to float32 in [-1.0, 1.0] if needed (int16 PCM is scaled; other dtypes are cast)
        if audio.dtype == np.int16:
            audio = audio.astype(np.float32) / 32768.0
        elif audio.dtype != np.float32:
            audio = audio.astype(np.float32)

        # Ensure correct sample rate (Parakeet expects 16kHz)
        if sample_rate != 16000:
            logger.warning(f"Audio sample rate is {sample_rate}Hz, Parakeet expects 16kHz. Resampling...")
            import torchaudio
            audio_tensor = torch.from_numpy(audio).unsqueeze(0)
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            audio_tensor = resampler(audio_tensor)
            audio = audio_tensor.squeeze(0).numpy()
            sample_rate = 16000

        # Transcribe using NeMo model
        # NeMo returns a list of Hypothesis objects
        # Note: timestamps=True causes significant VRAM usage (~1-2GB extra)
        # Only enable for final transcriptions, not streaming partials
        with torch.no_grad():
            transcriptions = self.model.transcribe(
                audio=[audio],  # Pass NumPy array directly (NeMo handles it efficiently)
                batch_size=1,
                timestamps=return_timestamps  # Only use timestamps when explicitly requested
            )

        # Extract text from Hypothesis object
        hypothesis = transcriptions[0] if transcriptions else None
        if hypothesis is None:
            text = ""
            words = []
        else:
            # Hypothesis object has a .text attribute
            text = hypothesis.text.strip() if hasattr(hypothesis, 'text') else str(hypothesis).strip()

            # Extract word-level timestamps if available and requested
            words = []
            if return_timestamps and hasattr(hypothesis, 'timestamp') and hypothesis.timestamp:
                # timestamp is a dict with a 'word' key containing a list of word timestamps
                word_timestamps = hypothesis.timestamp.get('word', [])
                for word_info in word_timestamps:
                    words.append({
                        "word": word_info.get('word', ''),
                        "start_time": word_info.get('start', 0.0),
                        "end_time": word_info.get('end', 0.0)
                    })

        logger.debug(f"Transcribed: '{text}' with {len(words)} words")

        # Note: We deliberately do NOT call torch.cuda.empty_cache() here.
        # That works against PyTorch's caching allocator and causes fragmentation;
        # let PyTorch manage its own memory pool.
        if return_timestamps:
            return {
                "text": text,
                "words": words
            }
        return text

    async def transcribe_streaming(
        self,
        audio_chunks: List[np.ndarray],
        sample_rate: int = 16000,
        chunk_size_ms: int = 500
    ) -> Dict[str, Any]:
        """
        Transcribe accumulated audio chunks with streaming support.

        Args:
            audio_chunks: List of audio chunks to process
            sample_rate: Audio sample rate
            chunk_size_ms: Size of each chunk in milliseconds (currently unused)

        Returns:
            Dict with partial text and word-level results
        """
        if not audio_chunks:
            return {"text": "", "words": []}

        # Concatenate all chunks into a single buffer
        audio_data = np.concatenate(audio_chunks)

        # Transcribe with timestamps for streaming
        result = await self.transcribe_async(
            audio_data,
            sample_rate,
            return_timestamps=True
        )
        return result

    def get_supported_languages(self) -> List[str]:
        """Get list of supported language codes."""
        # Parakeet TDT v3 primarily supports English
        return ["en"]

    def cleanup(self):
        """Cleanup resources."""
        self.executor.shutdown(wait=True)
        logger.info("Parakeet transcriber cleaned up")