Changed stt to parakeet — still experiemntal, though performance seems to be better

This commit is contained in:
2026-01-18 03:35:50 +02:00
parent 50e4f7a5f2
commit 0a8910fff8
10 changed files with 375 additions and 37 deletions

View File

@@ -2,13 +2,13 @@
STT Server
FastAPI WebSocket server for real-time speech-to-text.
Combines Silero VAD (CPU) and Faster-Whisper (GPU) for efficient transcription.
Combines Silero VAD (CPU) and NVIDIA Parakeet (GPU) for efficient transcription.
Architecture:
- VAD runs continuously on every audio chunk (CPU)
- Whisper transcribes only when VAD detects speech (GPU)
- Parakeet transcribes only when VAD detects speech (GPU)
- Supports multiple concurrent users
- Sends partial and final transcripts via WebSocket
- Sends partial and final transcripts via WebSocket with word-level tokens
"""
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
@@ -20,7 +20,7 @@ from typing import Dict, Optional
from datetime import datetime
from vad_processor import VADProcessor
from whisper_transcriber import WhisperTranscriber
from parakeet_transcriber import ParakeetTranscriber
# Configure logging
logging.basicConfig(
@@ -34,7 +34,7 @@ app = FastAPI(title="Miku STT Server", version="1.0.0")
# Global instances (initialized on startup)
vad_processor: Optional[VADProcessor] = None
whisper_transcriber: Optional[WhisperTranscriber] = None
parakeet_transcriber: Optional[ParakeetTranscriber] = None
# User session tracking
user_sessions: Dict[str, dict] = {}
@@ -117,39 +117,40 @@ class UserSTTSession:
self.audio_buffer.append(audio_np)
async def _transcribe_partial(self):
"""Transcribe accumulated audio and send partial result."""
"""Transcribe accumulated audio and send partial result with word tokens."""
if not self.audio_buffer:
return
# Concatenate audio
audio_full = np.concatenate(self.audio_buffer)
# Transcribe asynchronously
# Transcribe asynchronously with word-level timestamps
try:
text = await whisper_transcriber.transcribe_async(
result = await parakeet_transcriber.transcribe_async(
audio_full,
sample_rate=16000,
initial_prompt=self.last_transcript # Use previous for context
return_timestamps=True
)
if text and text != self.last_transcript:
self.last_transcript = text
if result and result.get("text") and result["text"] != self.last_transcript:
self.last_transcript = result["text"]
# Send partial transcript
# Send partial transcript with word tokens for LLM pre-computation
await self.websocket.send_json({
"type": "partial",
"text": text,
"text": result["text"],
"words": result.get("words", []), # Word-level tokens
"user_id": self.user_id,
"timestamp": self.timestamp_ms
})
logger.info(f"Partial [{self.user_id}]: {text}")
logger.info(f"Partial [{self.user_id}]: {result['text']}")
except Exception as e:
logger.error(f"Partial transcription failed: {e}", exc_info=True)
async def _transcribe_final(self):
"""Transcribe final accumulated audio."""
"""Transcribe final accumulated audio with word tokens."""
if not self.audio_buffer:
return
@@ -157,23 +158,25 @@ class UserSTTSession:
audio_full = np.concatenate(self.audio_buffer)
try:
text = await whisper_transcriber.transcribe_async(
result = await parakeet_transcriber.transcribe_async(
audio_full,
sample_rate=16000
sample_rate=16000,
return_timestamps=True
)
if text:
self.last_transcript = text
if result and result.get("text"):
self.last_transcript = result["text"]
# Send final transcript
# Send final transcript with word tokens
await self.websocket.send_json({
"type": "final",
"text": text,
"text": result["text"],
"words": result.get("words", []), # Word-level tokens for LLM
"user_id": self.user_id,
"timestamp": self.timestamp_ms
})
logger.info(f"Final [{self.user_id}]: {text}")
logger.info(f"Final [{self.user_id}]: {result['text']}")
except Exception as e:
logger.error(f"Final transcription failed: {e}", exc_info=True)
@@ -206,7 +209,7 @@ class UserSTTSession:
@app.on_event("startup")
async def startup_event():
"""Initialize models on server startup."""
global vad_processor, whisper_transcriber
global vad_processor, parakeet_transcriber
logger.info("=" * 50)
logger.info("Initializing Miku STT Server")
@@ -222,15 +225,14 @@ async def startup_event():
)
logger.info("✓ VAD ready")
# Initialize Whisper (GPU with cuDNN)
logger.info("Loading Faster-Whisper model (GPU)...")
whisper_transcriber = WhisperTranscriber(
model_size="small",
# Initialize Parakeet (GPU)
logger.info("Loading NVIDIA Parakeet TDT model (GPU)...")
parakeet_transcriber = ParakeetTranscriber(
model_name="nvidia/parakeet-tdt-0.6b-v3",
device="cuda",
compute_type="float16",
language="en"
)
logger.info("Whisper ready")
logger.info("Parakeet ready")
logger.info("=" * 50)
logger.info("STT Server ready to accept connections")
@@ -242,8 +244,8 @@ async def shutdown_event():
"""Cleanup on server shutdown."""
logger.info("Shutting down STT server...")
if whisper_transcriber:
whisper_transcriber.cleanup()
if parakeet_transcriber:
parakeet_transcriber.cleanup()
logger.info("STT server shutdown complete")