""" STT Server FastAPI WebSocket server for real-time speech-to-text. Combines Silero VAD (CPU) and Faster-Whisper (GPU) for efficient transcription. Architecture: - VAD runs continuously on every audio chunk (CPU) - Whisper transcribes only when VAD detects speech (GPU) - Supports multiple concurrent users - Sends partial and final transcripts via WebSocket """ from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException from fastapi.responses import JSONResponse import numpy as np import asyncio import logging from typing import Dict, Optional from datetime import datetime from vad_processor import VADProcessor from whisper_transcriber import WhisperTranscriber # Configure logging logging.basicConfig( level=logging.INFO, format='[%(levelname)s] [%(name)s] %(message)s' ) logger = logging.getLogger('stt_server') # Initialize FastAPI app app = FastAPI(title="Miku STT Server", version="1.0.0") # Global instances (initialized on startup) vad_processor: Optional[VADProcessor] = None whisper_transcriber: Optional[WhisperTranscriber] = None # User session tracking user_sessions: Dict[str, dict] = {} class UserSTTSession: """Manages STT state for a single user.""" def __init__(self, user_id: str, websocket: WebSocket): self.user_id = user_id self.websocket = websocket self.audio_buffer = [] self.is_speaking = False self.timestamp_ms = 0.0 self.transcript_buffer = [] self.last_transcript = "" logger.info(f"Created STT session for user {user_id}") async def process_audio_chunk(self, audio_data: bytes): """ Process incoming audio chunk. Args: audio_data: Raw PCM audio (int16, 16kHz mono) """ # Convert bytes to numpy array (int16) audio_np = np.frombuffer(audio_data, dtype=np.int16) # Calculate timestamp (assuming 16kHz, 20ms chunks = 320 samples) chunk_duration_ms = (len(audio_np) / 16000) * 1000 self.timestamp_ms += chunk_duration_ms # Run VAD on chunk vad_event = vad_processor.detect_speech_segment(audio_np, self.timestamp_ms) if vad_event: event_type = vad_event["event"] probability = vad_event["probability"] # Send VAD event to client await self.websocket.send_json({ "type": "vad", "event": event_type, "speaking": event_type in ["speech_start", "speaking"], "probability": probability, "timestamp": self.timestamp_ms }) # Handle speech events if event_type == "speech_start": self.is_speaking = True self.audio_buffer = [audio_np] logger.debug(f"User {self.user_id} started speaking") elif event_type == "speaking": if self.is_speaking: self.audio_buffer.append(audio_np) # Transcribe partial every ~2 seconds for streaming total_samples = sum(len(chunk) for chunk in self.audio_buffer) duration_s = total_samples / 16000 if duration_s >= 2.0: await self._transcribe_partial() elif event_type == "speech_end": self.is_speaking = False # Transcribe final await self._transcribe_final() # Clear buffer self.audio_buffer = [] logger.debug(f"User {self.user_id} stopped speaking") else: # Still accumulate audio if speaking if self.is_speaking: self.audio_buffer.append(audio_np) async def _transcribe_partial(self): """Transcribe accumulated audio and send partial result.""" if not self.audio_buffer: return # Concatenate audio audio_full = np.concatenate(self.audio_buffer) # Transcribe asynchronously try: text = await whisper_transcriber.transcribe_async( audio_full, sample_rate=16000, initial_prompt=self.last_transcript # Use previous for context ) if text and text != self.last_transcript: self.last_transcript = text # Send partial transcript await self.websocket.send_json({ "type": "partial", "text": 
text, "user_id": self.user_id, "timestamp": self.timestamp_ms }) logger.info(f"Partial [{self.user_id}]: {text}") except Exception as e: logger.error(f"Partial transcription failed: {e}", exc_info=True) async def _transcribe_final(self): """Transcribe final accumulated audio.""" if not self.audio_buffer: return # Concatenate all audio audio_full = np.concatenate(self.audio_buffer) try: text = await whisper_transcriber.transcribe_async( audio_full, sample_rate=16000 ) if text: self.last_transcript = text # Send final transcript await self.websocket.send_json({ "type": "final", "text": text, "user_id": self.user_id, "timestamp": self.timestamp_ms }) logger.info(f"Final [{self.user_id}]: {text}") except Exception as e: logger.error(f"Final transcription failed: {e}", exc_info=True) async def check_interruption(self, audio_data: bytes) -> bool: """ Check if user is interrupting (for use during Miku's speech). Args: audio_data: Raw PCM audio chunk Returns: True if interruption detected """ audio_np = np.frombuffer(audio_data, dtype=np.int16) speech_prob, is_speaking = vad_processor.process_chunk(audio_np) # Interruption: high probability sustained for threshold duration if speech_prob > 0.7: # Higher threshold for interruption await self.websocket.send_json({ "type": "interruption", "probability": speech_prob, "timestamp": self.timestamp_ms }) return True return False @app.on_event("startup") async def startup_event(): """Initialize models on server startup.""" global vad_processor, whisper_transcriber logger.info("=" * 50) logger.info("Initializing Miku STT Server") logger.info("=" * 50) # Initialize VAD (CPU) logger.info("Loading Silero VAD model (CPU)...") vad_processor = VADProcessor( sample_rate=16000, threshold=0.5, min_speech_duration_ms=250, # Conservative min_silence_duration_ms=500 # Conservative ) logger.info("✓ VAD ready") # Initialize Whisper (GPU with cuDNN) logger.info("Loading Faster-Whisper model (GPU)...") whisper_transcriber = WhisperTranscriber( model_size="small", device="cuda", compute_type="float16", language="en" ) logger.info("✓ Whisper ready") logger.info("=" * 50) logger.info("STT Server ready to accept connections") logger.info("=" * 50) @app.on_event("shutdown") async def shutdown_event(): """Cleanup on server shutdown.""" logger.info("Shutting down STT server...") if whisper_transcriber: whisper_transcriber.cleanup() logger.info("STT server shutdown complete") @app.get("/") async def root(): """Health check endpoint.""" return { "service": "Miku STT Server", "status": "running", "vad_ready": vad_processor is not None, "whisper_ready": whisper_transcriber is not None, "active_sessions": len(user_sessions) } @app.get("/health") async def health(): """Detailed health check.""" return { "status": "healthy", "models": { "vad": { "loaded": vad_processor is not None, "device": "cpu" }, "whisper": { "loaded": whisper_transcriber is not None, "model": "small", "device": "cuda" } }, "sessions": { "active": len(user_sessions), "users": list(user_sessions.keys()) } } @app.websocket("/ws/stt/{user_id}") async def websocket_stt(websocket: WebSocket, user_id: str): """ WebSocket endpoint for real-time STT. 

@app.websocket("/ws/stt/{user_id}")
async def websocket_stt(websocket: WebSocket, user_id: str):
    """
    WebSocket endpoint for real-time STT.

    Client sends:
        Raw PCM audio (int16, 16 kHz mono, 20 ms chunks)

    Server sends JSON events:
        - {"type": "vad", "event": "speech_start|speaking|speech_end", ...}
        - {"type": "partial", "text": "...", ...}
        - {"type": "final", "text": "...", ...}
        - {"type": "interruption", "probability": 0.xx}
    """
    await websocket.accept()
    logger.info(f"STT WebSocket connected: user {user_id}")

    # Create a session for this user
    session = UserSTTSession(user_id, websocket)
    user_sessions[user_id] = session

    try:
        # Tell the client the session is ready
        await websocket.send_json({
            "type": "ready",
            "user_id": user_id,
            "message": "STT session started"
        })

        # Main loop: receive and process audio chunks
        while True:
            data = await websocket.receive_bytes()
            await session.process_audio_chunk(data)

    except WebSocketDisconnect:
        logger.info(f"User {user_id} disconnected")
    except Exception as e:
        logger.error(f"Error in STT WebSocket for user {user_id}: {e}", exc_info=True)
    finally:
        # Clean up the session
        if user_id in user_sessions:
            del user_sessions[user_id]
        logger.info(f"STT session ended for user {user_id}")


@app.post("/interrupt/check")
async def check_interruption(user_id: str):
    """
    Check whether a user is interrupting (for use during Miku's speech).

    Query param:
        user_id: Discord user ID

    Returns:
        {"interrupting": bool, "user_id": str}
    """
    session = user_sessions.get(user_id)
    if not session:
        raise HTTPException(status_code=404, detail="User session not found")

    # Note: this reads the shared VAD state, not per-session state
    vad_state = vad_processor.get_state()

    return {
        "interrupting": vad_state["speaking"],
        "user_id": user_id
    }


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")
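
# --- Example client (illustrative, not part of the server) ------------------
# A minimal sketch of how a client could drive the /ws/stt/{user_id} endpoint,
# assuming the third-party `websockets` package (not a dependency of this
# server). Chunking follows the protocol above: int16, 16 kHz mono, 20 ms.
# Kept commented out so importing this module never executes client code.
#
#     import asyncio
#     import json
#     import numpy as np
#     import websockets
#
#     async def stream_audio(pcm: np.ndarray, user_id: str = "123"):
#         uri = f"ws://localhost:8000/ws/stt/{user_id}"
#         async with websockets.connect(uri) as ws:
#             chunk = 320  # 20 ms at 16 kHz
#             for i in range(0, len(pcm), chunk):
#                 await ws.send(pcm[i:i + chunk].astype(np.int16).tobytes())
#                 # Drain any pending server events without stalling the stream
#                 try:
#                     while True:
#                         msg = json.loads(await asyncio.wait_for(ws.recv(), timeout=0.001))
#                         print(msg["type"], msg.get("text", ""))
#                 except asyncio.TimeoutError:
#                     pass
#
#     asyncio.run(stream_audio(np.zeros(16000, dtype=np.int16)))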