#!/usr/bin/env python3
"""
Low-Latency STT WebSocket Server

Uses Silero VAD for speech detection + Faster-Whisper turbo for transcription.
Achieves sub-second latency after speech ends.

Architecture:
1. Silero VAD runs on every audio chunk to detect speech boundaries
2. When speech ends (silence detected), immediately transcribe the buffer
3. Send final transcript - no waiting for stability

Protocol:
- Client sends: binary audio data (16kHz, 16-bit mono PCM)
- Client sends: JSON {"command": "reset"} to reset state
- Server sends: JSON {"type": "partial", "text": "...", "timestamp": float}
- Server sends: JSON {"type": "final", "text": "...", "timestamp": float}
"""

import asyncio
import json
import logging
import time
import threading
import queue
from typing import Optional, Dict, Any

import numpy as np
import websockets
from websockets.server import serve
from aiohttp import web

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger('stt-realtime')

# Silero VAD
import torch
torch.set_num_threads(1)  # Prevent thread contention

# Faster-Whisper for transcription
from faster_whisper import WhisperModel

# Global models (shared across sessions for memory efficiency)
whisper_model: Optional[WhisperModel] = None
vad_model = None
warmup_complete = False


def load_vad_model():
    """Load the Silero VAD model (ONNX build) into the module-level global.

    Returns:
        The loaded VAD model callable: model(chunk_tensor, sample_rate) -> prob.
    """
    global vad_model
    model, _ = torch.hub.load(
        repo_or_dir='snakers4/silero-vad',
        model='silero_vad',
        force_reload=False,
        onnx=True  # Use ONNX for speed
    )
    vad_model = model
    logger.info("Silero VAD loaded (ONNX)")
    return model


def load_whisper_model(config: Dict[str, Any]):
    """Load the Faster-Whisper model into the module-level global.

    Args:
        config: Must contain 'model', 'device', and 'compute_type' keys.

    Returns:
        The loaded WhisperModel instance.
    """
    global whisper_model
    whisper_model = WhisperModel(
        config['model'],
        device=config['device'],
        compute_type=config['compute_type'],
    )
    logger.info(f"Faster-Whisper '{config['model']}' loaded on {config['device']}")
    return whisper_model


class STTSession:
    """
    Low-latency STT session using Silero VAD + Faster-Whisper.

    Audio flows in via feed_audio() on the event-loop thread; transcription
    runs on a dedicated background thread fed through a queue. A "speculative"
    transcription is kicked off after a short silence so the result is often
    already available by the time the final silence threshold is reached.
    """

    SAMPLE_RATE = 16000
    VAD_CHUNK_MS = 32  # Silero needs 512 samples at 16kHz = 32ms
    VAD_CHUNK_SAMPLES = 512  # Fixed: Silero requires exactly 512 samples at 16kHz

    def __init__(self, websocket, session_id: str, config: Dict[str, Any]):
        self.websocket = websocket
        self.session_id = session_id
        self.config = config
        self.running = False
        self.loop = None  # asyncio loop used to schedule sends from the worker thread

        # Audio state
        self.audio_buffer = []  # Float32 samples for current utterance
        self.vad_buffer = []    # Small buffer for VAD chunk alignment

        # Speech detection state
        self.is_speaking = False
        self.silence_start_time = 0
        self.speech_start_time = 0

        # Configurable thresholds
        self.vad_threshold = config.get('vad_threshold', 0.5)
        self.silence_duration_ms = config.get('silence_duration_ms', 400)
        self.min_speech_ms = config.get('min_speech_ms', 250)
        self.max_speech_duration = config.get('max_speech_duration', 30.0)

        # Speculative transcription settings
        self.speculative_silence_ms = config.get('speculative_silence_ms', 150)  # Start transcribing early
        self.speculative_pending = False  # Is a speculative transcription in flight?
        self.speculative_audio_snapshot = None  # Audio buffer snapshot for speculative
        self.speculative_result = None  # (text, elapsed) written by the worker thread
        self.speculative_result_ready = threading.Event()

        # Transcription queue: items are (audio_array, is_final, is_speculative)
        self.transcribe_queue = queue.Queue()
        self.transcribe_thread = None

        logger.info(f"[{session_id}] Session created (speculative: {self.speculative_silence_ms}ms, final: {self.silence_duration_ms}ms)")

    async def start(self, loop: asyncio.AbstractEventLoop):
        """Start the session: remember the loop and spawn the worker thread."""
        self.loop = loop
        self.running = True
        self.transcribe_thread = threading.Thread(target=self._transcription_worker, daemon=True)
        self.transcribe_thread.start()
        logger.info(f"[{self.session_id}] Session started")

    def _transcription_worker(self):
        """Background thread that processes transcription requests.

        Keeps draining the queue even after running goes False so that a
        final utterance queued during shutdown is still transcribed.
        """
        while self.running or not self.transcribe_queue.empty():
            try:
                item = self.transcribe_queue.get(timeout=0.1)
                if item is None:
                    continue
                audio_array, is_final, is_speculative = item
                start_time = time.time()
                segments, info = whisper_model.transcribe(
                    audio_array,
                    language=self.config.get('language', 'en'),
                    beam_size=1,
                    best_of=1,
                    temperature=0.0,
                    vad_filter=False,
                    without_timestamps=True,
                )
                text = " ".join(seg.text for seg in segments).strip()
                elapsed = time.time() - start_time
                if is_speculative:
                    # Store result for potential use by _finalize_utterance()
                    self.speculative_result = (text, elapsed)
                    self.speculative_result_ready.set()
                    logger.debug(f"[{self.session_id}] SPECULATIVE ({elapsed:.2f}s): {text}")
                elif text:
                    transcript_type = "final" if is_final else "partial"
                    logger.info(f"[{self.session_id}] {transcript_type.upper()} ({elapsed:.2f}s): {text}")
                    asyncio.run_coroutine_threadsafe(
                        self._send_transcript(transcript_type, text),
                        self.loop
                    )
            except queue.Empty:
                continue
            except Exception as e:
                logger.error(f"[{self.session_id}] Transcription error: {e}", exc_info=True)

    async def _send_transcript(self, transcript_type: str, text: str):
        """Send a transcript message to the client; swallow send errors (client may be gone)."""
        try:
            await self.websocket.send(json.dumps({
                "type": transcript_type,
                "text": text,
                "timestamp": time.time()
            }))
        except Exception as e:
            logger.error(f"[{self.session_id}] Send error: {e}")

    def feed_audio(self, audio_data: bytes):
        """Process incoming audio data (16kHz, 16-bit mono PCM bytes)."""
        if not self.running:
            return
        audio_int16 = np.frombuffer(audio_data, dtype=np.int16)
        audio_float = audio_int16.astype(np.float32) / 32768.0
        self.vad_buffer.extend(audio_float)
        # Feed the VAD in exact 512-sample chunks; keep the remainder buffered.
        while len(self.vad_buffer) >= self.VAD_CHUNK_SAMPLES:
            chunk = np.array(self.vad_buffer[:self.VAD_CHUNK_SAMPLES], dtype=np.float32)
            self.vad_buffer = self.vad_buffer[self.VAD_CHUNK_SAMPLES:]
            self._process_vad_chunk(chunk)

    def _process_vad_chunk(self, chunk: np.ndarray):
        """Run VAD on one chunk and advance the speech/silence state machine."""
        current_time = time.time()
        chunk_tensor = torch.from_numpy(chunk)
        speech_prob = vad_model(chunk_tensor, self.SAMPLE_RATE).item()
        is_speech = speech_prob >= self.vad_threshold

        if is_speech:
            if not self.is_speaking:
                self.is_speaking = True
                self.speech_start_time = current_time
                self.audio_buffer = []
                logger.debug(f"[{self.session_id}] Speech started")
            self.audio_buffer.extend(chunk)
            self.silence_start_time = 0
            # Cancel any speculative transcription if speech resumed
            if self.speculative_pending:
                logger.debug(f"[{self.session_id}] Speech resumed, canceling speculative")
                self.speculative_pending = False
                self.speculative_result = None
                self.speculative_result_ready.clear()
            speech_duration = current_time - self.speech_start_time
            if speech_duration >= self.max_speech_duration:
                logger.info(f"[{self.session_id}] Max duration reached")
                self._finalize_utterance()
        else:
            if self.is_speaking:
                # Keep trailing silence in the buffer so Whisper sees the word end.
                self.audio_buffer.extend(chunk)
                if self.silence_start_time == 0:
                    self.silence_start_time = current_time
                silence_duration_ms = (current_time - self.silence_start_time) * 1000
                speech_duration_ms = (self.silence_start_time - self.speech_start_time) * 1000

                # Trigger speculative transcription early
                if (not self.speculative_pending
                        and silence_duration_ms >= self.speculative_silence_ms
                        and speech_duration_ms >= self.min_speech_ms):
                    self._start_speculative_transcription()

                # Final silence threshold reached
                if silence_duration_ms >= self.silence_duration_ms:
                    if speech_duration_ms >= self.min_speech_ms:
                        logger.debug(f"[{self.session_id}] Speech ended ({speech_duration_ms:.0f}ms)")
                        self._finalize_utterance()
                    else:
                        logger.debug(f"[{self.session_id}] Discarding short utterance")
                        self._reset_state()

    def _start_speculative_transcription(self):
        """Start speculative transcription without waiting for full silence."""
        if self.audio_buffer:
            self.speculative_pending = True
            self.speculative_result = None
            self.speculative_result_ready.clear()
            # Snapshot current buffer
            audio_array = np.array(self.audio_buffer, dtype=np.float32)
            duration = len(audio_array) / self.SAMPLE_RATE
            logger.debug(f"[{self.session_id}] Starting speculative transcription ({duration:.1f}s)")
            # is_speculative=True
            self.transcribe_queue.put((audio_array, False, True))

    def _finalize_utterance(self):
        """Finalize the current utterance: use the speculative result if ready,
        otherwise queue a regular final transcription."""
        if not self.audio_buffer:
            self._reset_state()
            return
        audio_array = np.array(self.audio_buffer, dtype=np.float32)
        duration = len(audio_array) / self.SAMPLE_RATE

        # Check if we have a speculative result ready
        if self.speculative_pending and self.speculative_result_ready.wait(timeout=0.05):
            # Read into a local first: a racing cancel may have cleared it to None.
            result = self.speculative_result
            if result is not None:
                # Use speculative result immediately!
                text, elapsed = result
                if text:
                    logger.info(f"[{self.session_id}] FINAL [speculative] ({elapsed:.2f}s): {text}")
                    asyncio.run_coroutine_threadsafe(
                        self._send_transcript("final", text),
                        self.loop
                    )
                self._reset_state()
                return

        # No speculative result, do regular transcription
        logger.info(f"[{self.session_id}] Queuing transcription ({duration:.1f}s)")
        self.transcribe_queue.put((audio_array, True, False))
        self._reset_state()

    def _reset_state(self):
        """Reset speech detection state between utterances."""
        self.is_speaking = False
        self.audio_buffer = []
        self.silence_start_time = 0
        self.speech_start_time = 0
        self.speculative_pending = False
        self.speculative_result = None
        self.speculative_result_ready.clear()

    def reset(self):
        """Reset session state (client-requested)."""
        logger.info(f"[{self.session_id}] Resetting")
        self._reset_state()
        self.vad_buffer = []

    async def stop(self):
        """Stop the session, finalizing any in-progress utterance first.

        Finalization must happen BEFORE clearing self.running so the worker
        thread still drains the queued final transcription.
        """
        logger.info(f"[{self.session_id}] Stopping...")
        if self.audio_buffer and self.is_speaking:
            self._finalize_utterance()
        self.running = False
        if self.transcribe_thread and self.transcribe_thread.is_alive():
            self.transcribe_thread.join(timeout=2)
        logger.info(f"[{self.session_id}] Stopped")


class STTServer:
    """WebSocket server for low-latency STT."""

    def __init__(self, host: str, port: int, config: Dict[str, Any]):
        self.host = host
        self.port = port
        self.config = config
        self.sessions: Dict[str, STTSession] = {}
        self.session_counter = 0
        logger.info("=" * 60)
        logger.info("Low-Latency STT Server")
        logger.info(f"  Host: {host}:{port}")
        logger.info(f"  Model: {config['model']}")
        logger.info(f"  Language: {config.get('language', 'en')}")
        logger.info(f"  Silence: {config.get('silence_duration_ms', 400)}ms")
        logger.info("=" * 60)

    async def handle_client(self, websocket):
        """Handle one WebSocket client for the lifetime of its connection."""
        self.session_counter += 1
        session_id = f"session_{self.session_counter}"
        session = None
        try:
            logger.info(f"[{session_id}] Client connected")
            session = STTSession(websocket, session_id, self.config)
            self.sessions[session_id] = session
            await session.start(asyncio.get_running_loop())
            async for message in websocket:
                if isinstance(message, bytes):
                    session.feed_audio(message)
                else:
                    try:
                        data = json.loads(message)
                        cmd = data.get('command', '')
                        if cmd == 'reset':
                            session.reset()
                        elif cmd == 'ping':
                            await websocket.send(json.dumps({
                                'type': 'pong',
                                'timestamp': time.time()
                            }))
                    except json.JSONDecodeError:
                        pass  # Ignore malformed control messages
        except websockets.exceptions.ConnectionClosed:
            logger.info(f"[{session_id}] Client disconnected")
        except Exception as e:
            logger.error(f"[{session_id}] Error: {e}", exc_info=True)
        finally:
            if session:
                await session.stop()
            # pop() instead of del: the session may never have been registered
            # if construction failed partway through.
            self.sessions.pop(session_id, None)

    async def run(self):
        """Run the WebSocket server forever."""
        logger.info(f"Starting server on ws://{self.host}:{self.port}")
        async with serve(
            self.handle_client,
            self.host,
            self.port,
            ping_interval=30,
            ping_timeout=10,
            max_size=10 * 1024 * 1024,
        ):
            logger.info("Server ready")
            await asyncio.Future()


async def warmup(config: Dict[str, Any]):
    """Load models at startup and run one dummy transcription to prime caches."""
    global warmup_complete
    logger.info("Loading models...")
    load_vad_model()
    load_whisper_model(config)
    logger.info("Warming up transcription...")
    dummy_audio = np.zeros(16000, dtype=np.float32)  # 1 second of silence
    segments, _ = whisper_model.transcribe(
        dummy_audio,
        language=config.get('language', 'en'),
        beam_size=1,
    )
    list(segments)  # transcribe() is lazy; force it to run
    warmup_complete = True
    logger.info("Warmup complete")


async def health_handler(request):
    """Health check endpoint: 200 once models are warm, 503 while loading."""
    if warmup_complete:
        return web.json_response({"status": "ready"})
    return web.json_response({"status": "warming_up"}, status=503)


async def start_http_server(host: str, port: int):
    """Start the HTTP health server alongside the WebSocket server."""
    app = web.Application()
    app.router.add_get('/health', health_handler)
    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, host, port)
    await site.start()
    logger.info(f"Health server on http://{host}:{port}")


def main():
    """Main entry point."""
    import os
    host = os.environ.get('STT_HOST', '0.0.0.0')
    port = int(os.environ.get('STT_PORT', '8766'))
    http_port = int(os.environ.get('STT_HTTP_PORT', '8767'))

    config = {
        'model': 'small.en',
        'language': 'en',
        'compute_type': 'float16',
        'device': 'cuda',
        'vad_threshold': 0.5,
        'silence_duration_ms': 400,      # Final silence threshold
        'speculative_silence_ms': 150,   # Start transcribing early at 150ms
        'min_speech_ms': 250,
        'max_speech_duration': 30.0,
    }

    server = STTServer(host, port, config)

    async def run_all():
        await warmup(config)
        # Keep a reference so the task isn't garbage-collected mid-flight.
        http_task = asyncio.create_task(start_http_server(host, http_port))
        await server.run()

    try:
        asyncio.run(run_all())
    except KeyboardInterrupt:
        logger.info("Shutdown requested")
    except Exception as e:
        logger.error(f"Server error: {e}", exc_info=True)
        raise


if __name__ == '__main__':
    main()