""" STT Client for Discord Bot (RealtimeSTT Version) WebSocket client that connects to the RealtimeSTT server and handles: - Audio streaming to STT - Receiving partial/final transcripts Protocol: - Client sends: binary audio data (16kHz, 16-bit mono PCM) - Client sends: JSON {"command": "reset"} to reset state - Server sends: JSON {"type": "partial", "text": "...", "timestamp": float} - Server sends: JSON {"type": "final", "text": "...", "timestamp": float} """ import aiohttp import asyncio import logging from typing import Optional, Callable import json logger = logging.getLogger('stt_client') class STTClient: """ WebSocket client for RealtimeSTT server communication. Handles audio streaming and receives transcription events. """ def __init__( self, user_id: str, stt_url: str = "ws://miku-stt:8766", on_partial_transcript: Optional[Callable] = None, on_final_transcript: Optional[Callable] = None, ): """ Initialize STT client. Args: user_id: Discord user ID (for logging purposes) stt_url: WebSocket URL for STT server on_partial_transcript: Callback for partial transcripts (text, timestamp) on_final_transcript: Callback for final transcripts (text, timestamp) """ self.user_id = user_id self.stt_url = stt_url # Callbacks self.on_partial_transcript = on_partial_transcript self.on_final_transcript = on_final_transcript # Connection state self.websocket: Optional[aiohttp.ClientWebSocketResponse] = None self.session: Optional[aiohttp.ClientSession] = None self.connected = False self.running = False # Receive task self._receive_task: Optional[asyncio.Task] = None logger.info(f"STT client initialized for user {user_id}") async def connect(self): """Connect to RealtimeSTT WebSocket server.""" if self.connected: logger.warning(f"Already connected for user {self.user_id}") return try: self.session = aiohttp.ClientSession() self.websocket = await self.session.ws_connect( self.stt_url, heartbeat=30, receive_timeout=60 ) self.connected = True self.running = True # Start background task to receive messages self._receive_task = asyncio.create_task(self._receive_loop()) logger.info(f"Connected to STT server at {self.stt_url} for user {self.user_id}") except Exception as e: logger.error(f"Failed to connect to STT server: {e}") await self._cleanup() raise async def disconnect(self): """Disconnect from STT server.""" self.running = False if self._receive_task: self._receive_task.cancel() try: await self._receive_task except asyncio.CancelledError: pass self._receive_task = None await self._cleanup() logger.info(f"Disconnected from STT server for user {self.user_id}") async def _cleanup(self): """Clean up WebSocket and session.""" if self.websocket: try: await self.websocket.close() except Exception: pass self.websocket = None if self.session: try: await self.session.close() except Exception: pass self.session = None self.connected = False async def send_audio(self, audio_data: bytes): """ Send raw audio data to STT server. Args: audio_data: Raw PCM audio (16kHz, 16-bit mono, little-endian) """ if not self.connected or not self.websocket: return try: await self.websocket.send_bytes(audio_data) except Exception as e: logger.error(f"Failed to send audio: {e}") await self._cleanup() async def reset(self): """Reset STT state (clear any pending transcription).""" if not self.connected or not self.websocket: return try: await self.websocket.send_json({"command": "reset"}) logger.debug(f"Sent reset command for user {self.user_id}") except Exception as e: logger.error(f"Failed to send reset: {e}") def is_connected(self) -> bool: """Check if connected to STT server.""" return self.connected and self.websocket is not None async def _receive_loop(self): """Background task to receive messages from STT server.""" try: while self.running and self.websocket: try: msg = await asyncio.wait_for( self.websocket.receive(), timeout=5.0 ) if msg.type == aiohttp.WSMsgType.TEXT: await self._handle_message(msg.data) elif msg.type == aiohttp.WSMsgType.CLOSED: logger.warning(f"STT WebSocket closed for user {self.user_id}") break elif msg.type == aiohttp.WSMsgType.ERROR: logger.error(f"STT WebSocket error for user {self.user_id}") break except asyncio.TimeoutError: # Timeout is fine, just continue continue except asyncio.CancelledError: pass except Exception as e: logger.error(f"Error in STT receive loop: {e}") finally: self.connected = False async def _handle_message(self, data: str): """Handle a message from the STT server.""" try: message = json.loads(data) msg_type = message.get("type") text = message.get("text", "") timestamp = message.get("timestamp", 0) if msg_type == "partial": if self.on_partial_transcript and text: await self._call_callback( self.on_partial_transcript, text, timestamp ) elif msg_type == "final": if self.on_final_transcript and text: await self._call_callback( self.on_final_transcript, text, timestamp ) elif msg_type == "connected": logger.info(f"STT server confirmed connection for user {self.user_id}") elif msg_type == "error": error_msg = message.get("error", "Unknown error") logger.error(f"STT server error: {error_msg}") except json.JSONDecodeError: logger.warning(f"Invalid JSON from STT server: {data[:100]}") except Exception as e: logger.error(f"Error handling STT message: {e}") async def _call_callback(self, callback, *args): """Safely call a callback, handling both sync and async functions.""" try: result = callback(*args) if asyncio.iscoroutine(result): await result except Exception as e: logger.error(f"Error in STT callback: {e}")