#!/usr/bin/env python3
"""
ASR WebSocket Server with VAD - Optimized for Discord Bots

This server uses Voice Activity Detection (VAD) to:
- Detect speech start and end automatically
- Only transcribe speech segments (ignore silence)
- Provide clean boundaries for Discord message formatting
- Minimize processing of silence/noise
"""

import asyncio
import json
import logging
import sys
from collections import deque
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Optional

import numpy as np
import websockets

# Add project root to path
sys.path.insert(0, str(Path(__file__).parent.parent))

from asr.asr_pipeline import ASRPipeline

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('vad_server.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)


@dataclass
class SpeechSegment:
    """Represents a segment of detected speech."""
    audio: np.ndarray
    start_time: float
    end_time: Optional[float] = None
    is_complete: bool = False


class VADState:
    """Manages VAD state for speech detection."""

    def __init__(self, sample_rate: int = 16000):
        self.sample_rate = sample_rate

        # Simple energy-based VAD parameters
        self.energy_threshold = 0.005  # Lower threshold for better detection
        self.speech_frames = 0
        self.silence_frames = 0
        self.min_speech_frames = 3   # 3 frames minimum (300ms with 100ms chunks)
        self.min_silence_frames = 5  # 5 frames of silence (500ms)

        self.is_speech = False
        self.speech_buffer = []

        # Pre-buffer to capture audio BEFORE speech detection.
        # This prevents cutting off the start of speech.
        self.pre_buffer_frames = 5  # Keep 5 frames (500ms) of pre-speech audio
        self.pre_buffer = deque(maxlen=self.pre_buffer_frames)

        # Progressive transcription tracking
        self.last_partial_samples = 0  # Track when we last sent a partial
        self.partial_interval_samples = int(sample_rate * 0.3)  # Partial every 0.3s (near real-time)

        logger.info(f"VAD initialized: energy_threshold={self.energy_threshold}, "
                    f"pre_buffer={self.pre_buffer_frames} frames")

    def calculate_energy(self, audio_chunk: np.ndarray) -> float:
        """Calculate RMS energy of audio chunk."""
        return np.sqrt(np.mean(audio_chunk ** 2))

    def process_audio(self, audio_chunk: np.ndarray) -> tuple[bool, Optional[np.ndarray], Optional[np.ndarray]]:
        """
        Process audio chunk and detect speech boundaries.

        Returns:
            (speech_detected, complete_segment, partial_segment)
            - speech_detected: True if currently in speech
            - complete_segment: Audio segment if speech ended, None otherwise
            - partial_segment: Audio for partial transcription, None otherwise
        """
        energy = self.calculate_energy(audio_chunk)
        chunk_is_speech = energy > self.energy_threshold

        logger.debug(f"Energy: {energy:.6f}, Is speech: {chunk_is_speech}")

        partial_segment = None

        if chunk_is_speech:
            self.speech_frames += 1
            self.silence_frames = 0

            if not self.is_speech and self.speech_frames >= self.min_speech_frames:
                # Speech started - add pre-buffer to capture the beginning!
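                # The voiced chunks that triggered detection are still sitting
                # in pre_buffer (they were appended there while is_speech was
                # False, see the else branch below), so the splice that follows
                # recovers those first chunks plus a little leading audio that
                # would otherwise be clipped off the start of the utterance.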
                self.is_speech = True
                logger.info("🎤 Speech started (including pre-buffer)")

                # Add pre-buffered audio to speech buffer
                if self.pre_buffer:
                    logger.debug(f"Adding {len(self.pre_buffer)} pre-buffered frames")
                    self.speech_buffer.extend(list(self.pre_buffer))
                    self.pre_buffer.clear()

            if self.is_speech:
                self.speech_buffer.append(audio_chunk)
            else:
                # Not in speech yet, keep in pre-buffer
                self.pre_buffer.append(audio_chunk)

            # Check if we should send a partial transcription
            current_samples = sum(len(chunk) for chunk in self.speech_buffer)
            samples_since_last_partial = current_samples - self.last_partial_samples

            # Send partial if enough NEW audio accumulated AND we have minimum duration
            min_duration_for_partial = int(self.sample_rate * 0.8)  # At least 0.8s of audio

            if (samples_since_last_partial >= self.partial_interval_samples
                    and current_samples >= min_duration_for_partial):
                # Time for a partial update
                partial_segment = np.concatenate(self.speech_buffer)
                self.last_partial_samples = current_samples
                logger.debug(f"📝 Partial update: {current_samples / self.sample_rate:.2f}s")
        else:
            if self.is_speech:
                self.silence_frames += 1

                # Add some trailing silence (up to limit)
                if self.silence_frames < self.min_silence_frames:
                    self.speech_buffer.append(audio_chunk)
                else:
                    # Speech ended
                    logger.info(f"🛑 Speech ended after {self.silence_frames} silence frames")
                    self.is_speech = False
                    self.speech_frames = 0
                    self.silence_frames = 0
                    self.last_partial_samples = 0  # Reset partial counter

                    if self.speech_buffer:
                        complete_segment = np.concatenate(self.speech_buffer)
                        segment_duration = len(complete_segment) / self.sample_rate
                        self.speech_buffer = []
                        self.pre_buffer.clear()  # Clear pre-buffer after speech ends
                        logger.info(f"✅ Complete segment: {segment_duration:.2f}s")
                        return False, complete_segment, None
            else:
                self.speech_frames = 0
                # Keep adding to pre-buffer when not in speech
                self.pre_buffer.append(audio_chunk)

        return self.is_speech, None, partial_segment


class VADServer:
    """
    WebSocket server with VAD for Discord bot integration.
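
    Wire protocol (as implemented in handle_client below):

    - Client -> server, binary frames: raw int16 mono PCM at the configured
      sample rate (16 kHz by default).
    - Client -> server, text frames: JSON commands:
        {"type": "force_transcribe"}                  transcribe the buffer now
        {"type": "reset"}                             discard all VAD state
        {"type": "set_threshold", "threshold": 0.01}  adjust the energy threshold
    - Server -> client, text frames: JSON messages whose "type" is "info",
      "vad_status", "error", or "transcript"; a transcript looks like
        {"type": "transcript", "text": "...", "is_final": true, "duration": 1.42}
      where is_final=false marks a progressive partial while speech continues.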
""" def __init__( self, host: str = "0.0.0.0", port: int = 8766, model_path: str = "models/parakeet", sample_rate: int = 16000, ): """Initialize server.""" self.host = host self.port = port self.sample_rate = sample_rate self.active_connections = set() # Terminal control codes self.BOLD = '\033[1m' self.GREEN = '\033[92m' self.YELLOW = '\033[93m' self.BLUE = '\033[94m' self.RED = '\033[91m' self.RESET = '\033[0m' # Initialize ASR pipeline logger.info("Loading ASR model...") self.pipeline = ASRPipeline(model_path=model_path) logger.info("ASR Pipeline ready") def print_header(self): """Print server header.""" print("\n" + "=" * 80) print(f"{self.BOLD}{self.BLUE}ASR Server with VAD - Discord Bot Ready{self.RESET}") print("=" * 80) print(f"Server: ws://{self.host}:{self.port}") print(f"Sample Rate: {self.sample_rate} Hz") print(f"Model: Parakeet TDT 0.6B V3") print(f"VAD: Energy-based speech detection") print("=" * 80 + "\n") def display_transcription(self, client_id: str, text: str, duration: float): """Display transcription in the terminal.""" timestamp = datetime.now().strftime("%H:%M:%S") print(f"{self.GREEN}{self.BOLD}[{timestamp}] {client_id}{self.RESET}") print(f"{self.GREEN} 📝 {text}{self.RESET}") print(f"{self.BLUE} ⏱️ Duration: {duration:.2f}s{self.RESET}\n") sys.stdout.flush() async def handle_client(self, websocket): """Handle WebSocket client connection.""" client_id = f"{websocket.remote_address[0]}:{websocket.remote_address[1]}" logger.info(f"Client connected: {client_id}") self.active_connections.add(websocket) print(f"\n{self.BOLD}{'='*80}{self.RESET}") print(f"{self.GREEN}✓ Client connected: {client_id}{self.RESET}") print(f"{self.BOLD}{'='*80}{self.RESET}\n") sys.stdout.flush() # Initialize VAD state for this client vad_state = VADState(sample_rate=self.sample_rate) try: # Send welcome message await websocket.send(json.dumps({ "type": "info", "message": "Connected to ASR server with VAD", "sample_rate": self.sample_rate, "vad_enabled": True, })) async for message in websocket: try: if isinstance(message, bytes): # Binary audio data audio_data = np.frombuffer(message, dtype=np.int16) audio_data = audio_data.astype(np.float32) / 32768.0 # Process through VAD is_speech, complete_segment, partial_segment = vad_state.process_audio(audio_data) # Send VAD status to client (only on state change) prev_speech_state = getattr(vad_state, '_prev_speech_state', False) if is_speech != prev_speech_state: vad_state._prev_speech_state = is_speech await websocket.send(json.dumps({ "type": "vad_status", "is_speech": is_speech, })) # Handle partial transcription (progressive updates while speaking) if partial_segment is not None: try: text = self.pipeline.transcribe( partial_segment, sample_rate=self.sample_rate ) if text and text.strip(): duration = len(partial_segment) / self.sample_rate # Display on server timestamp = datetime.now().strftime("%H:%M:%S") print(f"{self.YELLOW}[{timestamp}] {client_id}{self.RESET}") print(f"{self.YELLOW} → PARTIAL: {text}{self.RESET}\n") sys.stdout.flush() # Send to client response = { "type": "transcript", "text": text, "is_final": False, "duration": duration, } await websocket.send(json.dumps(response)) except Exception as e: logger.error(f"Partial transcription error: {e}") # If we have a complete speech segment, transcribe it if complete_segment is not None: try: text = self.pipeline.transcribe( complete_segment, sample_rate=self.sample_rate ) if text and text.strip(): duration = len(complete_segment) / self.sample_rate # Display on server 
                                    self.display_transcription(client_id, text, duration)

                                    # Send to client
                                    response = {
                                        "type": "transcript",
                                        "text": text,
                                        "is_final": True,
                                        "duration": duration,
                                    }
                                    await websocket.send(json.dumps(response))
                            except Exception as e:
                                logger.error(f"Transcription error: {e}")
                                await websocket.send(json.dumps({
                                    "type": "error",
                                    "message": f"Transcription failed: {str(e)}"
                                }))

                    elif isinstance(message, str):
                        # JSON command
                        try:
                            command = json.loads(message)

                            if command.get("type") == "force_transcribe":
                                # Force transcribe current buffer
                                if vad_state.speech_buffer:
                                    audio_chunk = np.concatenate(vad_state.speech_buffer)
                                    vad_state.speech_buffer = []
                                    vad_state.is_speech = False
                                    vad_state.last_partial_samples = 0  # Keep partial tracking consistent

                                    text = self.pipeline.transcribe(
                                        audio_chunk,
                                        sample_rate=self.sample_rate
                                    )

                                    if text and text.strip():
                                        duration = len(audio_chunk) / self.sample_rate
                                        self.display_transcription(client_id, text, duration)

                                        response = {
                                            "type": "transcript",
                                            "text": text,
                                            "is_final": True,
                                            "duration": duration,
                                        }
                                        await websocket.send(json.dumps(response))

                            elif command.get("type") == "reset":
                                # Reset VAD state
                                vad_state = VADState(sample_rate=self.sample_rate)
                                await websocket.send(json.dumps({
                                    "type": "info",
                                    "message": "VAD state reset"
                                }))
                                print(f"{self.YELLOW}[{client_id}] VAD reset{self.RESET}\n")
                                sys.stdout.flush()

                            elif command.get("type") == "set_threshold":
                                # Adjust VAD threshold
                                threshold = command.get("threshold", 0.01)
                                vad_state.energy_threshold = threshold
                                await websocket.send(json.dumps({
                                    "type": "info",
                                    "message": f"VAD threshold set to {threshold}"
                                }))

                        except json.JSONDecodeError:
                            logger.warning(f"Invalid JSON from {client_id}: {message}")

                except Exception as e:
                    logger.error(f"Error processing message from {client_id}: {e}")
                    break

        except websockets.exceptions.ConnectionClosed:
            logger.info(f"Connection closed: {client_id}")
        except Exception as e:
            logger.error(f"Unexpected error with {client_id}: {e}")
        finally:
            self.active_connections.discard(websocket)
            print(f"\n{self.BOLD}{'=' * 80}{self.RESET}")
            print(f"{self.YELLOW}✗ Client disconnected: {client_id}{self.RESET}")
            print(f"{self.BOLD}{'=' * 80}{self.RESET}\n")
            sys.stdout.flush()
            logger.info(f"Connection closed: {client_id}")

    async def start(self):
        """Start the WebSocket server."""
        self.print_header()

        async with websockets.serve(self.handle_client, self.host, self.port):
            logger.info(f"Starting WebSocket server on {self.host}:{self.port}")
            print(f"{self.GREEN}{self.BOLD}Server is running with VAD enabled!{self.RESET}")
            print(f"{self.BOLD}Ready for Discord bot connections...{self.RESET}\n")
            sys.stdout.flush()

            # Keep server running
            await asyncio.Future()


def main():
    """Main entry point."""
    import argparse

    parser = argparse.ArgumentParser(description="ASR Server with VAD for Discord")
    parser.add_argument("--host", default="0.0.0.0", help="Host address")
    parser.add_argument("--port", type=int, default=8766, help="Port number")
    parser.add_argument("--model-path", default="models/parakeet", help="Model directory")
    parser.add_argument("--sample-rate", type=int, default=16000, help="Sample rate")

    args = parser.parse_args()

    server = VADServer(
        host=args.host,
        port=args.port,
        model_path=args.model_path,
        sample_rate=args.sample_rate,
    )

    try:
        asyncio.run(server.start())
    except KeyboardInterrupt:
        print(f"\n\n{server.YELLOW}Server stopped by user{server.RESET}")
        logger.info("Server stopped by user")


if __name__ == "__main__":
    main()
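
# ---------------------------------------------------------------------------
# Minimal client sketch (illustrative only, not part of the server). Assumes
# the `websockets` package and some source of 16 kHz mono int16 PCM chunks;
# a real Discord bot would adapt this to its own audio capture and add a
# receive timeout instead of blocking forever.
#
#   import asyncio, json, websockets
#
#   async def transcribe(chunks):
#       async with websockets.connect("ws://localhost:8766") as ws:
#           for chunk in chunks:          # each chunk: bytes of int16 PCM
#               await ws.send(chunk)
#           await ws.send(json.dumps({"type": "force_transcribe"}))
#           while True:                   # skip info/vad_status/partial messages
#               msg = json.loads(await ws.recv())
#               if msg.get("type") == "transcript" and msg.get("is_final"):
#                   return msg["text"]
# ---------------------------------------------------------------------------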