# miku-discord/stt/stt_server.py
"""
STT Server
FastAPI WebSocket server for real-time speech-to-text.
Combines Silero VAD (CPU) and NVIDIA Parakeet (GPU) for efficient transcription.
Architecture:
- VAD runs continuously on every audio chunk (CPU)
- Parakeet transcribes only when VAD detects speech (GPU)
- Supports multiple concurrent users
- Sends partial and final transcripts via WebSocket with word-level tokens
"""
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
import numpy as np
import logging
from typing import Dict, Optional

from vad_processor import VADProcessor
from parakeet_transcriber import ParakeetTranscriber

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='[%(levelname)s] [%(name)s] %(message)s'
)
logger = logging.getLogger('stt_server')
# Initialize FastAPI app
app = FastAPI(title="Miku STT Server", version="1.0.0")
# Global instances (initialized on startup)
vad_processor: Optional[VADProcessor] = None
parakeet_transcriber: Optional[ParakeetTranscriber] = None
# User session tracking
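# (one UserSTTSession per connected client, keyed by the Discord user ID taken
# from the /ws/stt/{user_id} WebSocket path)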
user_sessions: Dict[str, dict] = {}


class UserSTTSession:
    """Manages STT state for a single user."""

    def __init__(self, user_id: str, websocket: WebSocket):
        self.user_id = user_id
        self.websocket = websocket
        self.audio_buffer = []
        self.is_speaking = False
        self.timestamp_ms = 0.0
        self.transcript_buffer = []
        self.last_transcript = ""
        self.last_partial_duration = 0.0  # Track when we last sent a partial
        self.last_speech_timestamp = 0.0  # Track last time we detected speech
        self.speech_timeout_ms = 3000  # Force finalization after 3s of no new speech
        logger.info(f"Created STT session for user {user_id}")

    async def process_audio_chunk(self, audio_data: bytes):
        """
        Process incoming audio chunk.

        Args:
            audio_data: Raw PCM audio (int16, 16kHz mono)
        """
        # Convert bytes to numpy array (int16)
        audio_np = np.frombuffer(audio_data, dtype=np.int16)
        # Calculate timestamp (assuming 16kHz, 20ms chunks = 320 samples)
        chunk_duration_ms = (len(audio_np) / 16000) * 1000
        self.timestamp_ms += chunk_duration_ms
        # Run VAD on chunk
        vad_event = vad_processor.detect_speech_segment(audio_np, self.timestamp_ms)
        if vad_event:
            event_type = vad_event["event"]
            probability = vad_event["probability"]
            logger.debug(f"VAD event for user {self.user_id}: {event_type} (prob={probability:.3f})")
            # Send VAD event to client
            await self.websocket.send_json({
                "type": "vad",
                "event": event_type,
                "speaking": event_type in ["speech_start", "speaking"],
                "probability": probability,
                "timestamp": self.timestamp_ms
            })
            # Handle speech events
            if event_type == "speech_start":
                self.is_speaking = True
                self.audio_buffer = [audio_np]
                self.last_partial_duration = 0.0
                self.last_speech_timestamp = self.timestamp_ms
                logger.info(f"[STT] User {self.user_id} SPEECH START")
            elif event_type == "speaking":
                if self.is_speaking:
                    self.audio_buffer.append(audio_np)
                    self.last_speech_timestamp = self.timestamp_ms  # Update speech timestamp
                    # Transcribe a partial roughly every 1s of speech for streaming (reduced from 2s)
                    total_samples = sum(len(chunk) for chunk in self.audio_buffer)
                    duration_s = total_samples / 16000
                    # Gate on audio accumulated since the last partial, so partials fire
                    # about once per second rather than on every subsequent chunk
                    if duration_s - self.last_partial_duration >= 1.0:
                        logger.debug(f"Triggering partial transcription at {duration_s:.1f}s")
                        await self._transcribe_partial()
                        # Keep buffer for final transcription, but mark progress
                        self.last_partial_duration = duration_s
            elif event_type == "speech_end":
                self.is_speaking = False
                logger.info(f"[STT] User {self.user_id} SPEECH END (VAD detected) - transcribing final")
                # Transcribe final
                await self._transcribe_final()
                # Clear buffer
                self.audio_buffer = []
                self.last_partial_duration = 0.0
                logger.debug(f"User {self.user_id} stopped speaking")
        else:
            # No VAD event - still accumulate audio if speaking
            if self.is_speaking:
                self.audio_buffer.append(audio_np)
                # Check for timeout
                time_since_speech = self.timestamp_ms - self.last_speech_timestamp
                if time_since_speech >= self.speech_timeout_ms:
                    # Timeout - user probably stopped but VAD didn't detect it
                    logger.warning(f"[STT] User {self.user_id} SPEECH TIMEOUT after {time_since_speech:.0f}ms - forcing finalization")
                    self.is_speaking = False
                    # Force final transcription
                    await self._transcribe_final()
                    # Clear buffer
                    self.audio_buffer = []
                    self.last_partial_duration = 0.0

    async def _transcribe_partial(self):
        """Transcribe accumulated audio and send partial result (no timestamps to save VRAM)."""
        if not self.audio_buffer:
            return
        # Concatenate audio
        audio_full = np.concatenate(self.audio_buffer)
        # Transcribe asynchronously WITHOUT timestamps for partials (saves 1-2GB VRAM)
        try:
            result = await parakeet_transcriber.transcribe_async(
                audio_full,
                sample_rate=16000,
                return_timestamps=False  # Disable timestamps for partials to reduce VRAM usage
            )
            # Result is just a string when timestamps=False
            text = result if isinstance(result, str) else result.get("text", "")
            if text and text != self.last_transcript:
                self.last_transcript = text
                # Send partial transcript without word tokens (saves memory)
                await self.websocket.send_json({
                    "type": "partial",
                    "text": text,
                    "words": [],  # No word tokens for partials
                    "user_id": self.user_id,
                    "timestamp": self.timestamp_ms
                })
                logger.info(f"Partial [{self.user_id}]: {text}")
        except Exception as e:
            logger.error(f"Partial transcription failed: {e}", exc_info=True)

    async def _transcribe_final(self):
        """Transcribe final accumulated audio with word tokens."""
        if not self.audio_buffer:
            return
        # Concatenate all audio
        audio_full = np.concatenate(self.audio_buffer)
        try:
            result = await parakeet_transcriber.transcribe_async(
                audio_full,
                sample_rate=16000,
                return_timestamps=True
            )
            if result and result.get("text"):
                self.last_transcript = result["text"]
                # Send final transcript with word tokens
                await self.websocket.send_json({
                    "type": "final",
                    "text": result["text"],
                    "words": result.get("words", []),  # Word-level tokens for LLM
                    "user_id": self.user_id,
                    "timestamp": self.timestamp_ms
                })
                logger.info(f"Final [{self.user_id}]: {result['text']}")
        except Exception as e:
            logger.error(f"Final transcription failed: {e}", exc_info=True)

    async def check_interruption(self, audio_data: bytes) -> bool:
        """
        Check if user is interrupting (for use during Miku's speech).

        Args:
            audio_data: Raw PCM audio chunk

        Returns:
            True if interruption detected
        """
        audio_np = np.frombuffer(audio_data, dtype=np.int16)
        speech_prob, is_speaking = vad_processor.process_chunk(audio_np)
        # Interruption: a single chunk above a higher threshold (0.7) than normal VAD (0.5)
        if speech_prob > 0.7:
            await self.websocket.send_json({
                "type": "interruption",
                "probability": speech_prob,
                "timestamp": self.timestamp_ms
            })
            return True
        return False
@app.on_event("startup")
async def startup_event():
"""Initialize models on server startup."""
global vad_processor, parakeet_transcriber
logger.info("=" * 50)
logger.info("Initializing Miku STT Server")
logger.info("=" * 50)
# Initialize VAD (CPU)
logger.info("Loading Silero VAD model (CPU)...")
vad_processor = VADProcessor(
sample_rate=16000,
threshold=0.5,
min_speech_duration_ms=250, # Conservative - wait 250ms before starting
min_silence_duration_ms=300 # Reduced from 500ms - detect silence faster
)
logger.info("✓ VAD ready")
# Initialize Parakeet (GPU)
logger.info("Loading NVIDIA Parakeet TDT model (GPU)...")
parakeet_transcriber = ParakeetTranscriber(
model_name="nvidia/parakeet-tdt-0.6b-v3",
device="cuda",
language="en"
)
logger.info("✓ Parakeet ready")
logger.info("=" * 50)
logger.info("STT Server ready to accept connections")
logger.info("=" * 50)
@app.on_event("shutdown")
async def shutdown_event():
"""Cleanup on server shutdown."""
logger.info("Shutting down STT server...")
if parakeet_transcriber:
parakeet_transcriber.cleanup()
logger.info("STT server shutdown complete")
@app.get("/")
async def root():
"""Health check endpoint."""
return {
"service": "Miku STT Server",
"status": "running",
"vad_ready": vad_processor is not None,
"whisper_ready": whisper_transcriber is not None,
"active_sessions": len(user_sessions)
}
@app.get("/health")
async def health():
"""Detailed health check."""
return {
"status": "healthy",
"models": {
"vad": {
"loaded": vad_processor is not None,
"device": "cpu"
},
"whisper": {
"loaded": whisper_transcriber is not None,
"model": "small",
"device": "cuda"
}
},
"sessions": {
"active": len(user_sessions),
"users": list(user_sessions.keys())
}
}
@app.websocket("/ws/stt/{user_id}")
async def websocket_stt(websocket: WebSocket, user_id: str):
"""
WebSocket endpoint for real-time STT.
Client sends: Raw PCM audio (int16, 16kHz mono, 20ms chunks)
Server sends: JSON events:
- {"type": "vad", "event": "speech_start|speaking|speech_end", ...}
- {"type": "partial", "text": "...", ...}
- {"type": "final", "text": "...", ...}
- {"type": "interruption", "probability": 0.xx}
"""
await websocket.accept()
logger.info(f"STT WebSocket connected: user {user_id}")
# Create session
session = UserSTTSession(user_id, websocket)
user_sessions[user_id] = session
try:
# Send ready message
await websocket.send_json({
"type": "ready",
"user_id": user_id,
"message": "STT session started"
})
# Main loop: receive audio chunks
while True:
# Receive binary audio data
data = await websocket.receive_bytes()
# Process audio chunk
await session.process_audio_chunk(data)
except WebSocketDisconnect:
logger.info(f"User {user_id} disconnected")
except Exception as e:
logger.error(f"Error in STT WebSocket for user {user_id}: {e}", exc_info=True)
finally:
# Cleanup session
if user_id in user_sessions:
del user_sessions[user_id]
logger.info(f"STT session ended for user {user_id}")
@app.post("/interrupt/check")
async def check_interruption(user_id: str):
"""
Check if user is interrupting (for use during Miku's speech).
Query param:
user_id: Discord user ID
Returns:
{"interrupting": bool, "probability": float}
"""
session = user_sessions.get(user_id)
if not session:
raise HTTPException(status_code=404, detail="User session not found")
# Get current VAD state
vad_state = vad_processor.get_state()
return {
"interrupting": vad_state["speaking"],
"user_id": user_id
}
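
# Example check (assumption: FastAPI binds the bare `user_id` parameter of this
# POST route to a query parameter, so no request body is needed):
#   curl -X POST "http://localhost:8000/interrupt/check?user_id=123456789"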
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")