"""
STT Client for Discord Bot

WebSocket client that connects to the STT server and handles:
- Audio streaming to STT
- Receiving VAD events
- Receiving partial/final transcripts
- Interruption detection
"""

import asyncio
import inspect
import json
import logging
from typing import Callable, Optional

import aiohttp

logger = logging.getLogger('stt_client')


class STTClient:
    """
    WebSocket client for STT server communication.

    Streams audio to the server over a WebSocket and dispatches the JSON
    events it sends back ('vad', 'partial', 'final', 'interruption') to
    user-supplied callbacks. Callbacks may be plain functions or coroutine
    functions; a callback that raises is logged and does not stop the
    receive loop.
    """

    def __init__(
        self,
        user_id: str,
        stt_url: str = "ws://miku-stt:8000/ws/stt",
        on_vad_event: Optional[Callable] = None,
        on_partial_transcript: Optional[Callable] = None,
        on_final_transcript: Optional[Callable] = None,
        on_interruption: Optional[Callable] = None
    ):
        """
        Initialize STT client (no network I/O happens here).

        Args:
            user_id: Discord user ID; appended to ``stt_url`` as the last path segment
            stt_url: Base WebSocket URL for STT server
            on_vad_event: Callback for VAD events (event_dict)
            on_partial_transcript: Callback for partial transcripts (text, timestamp)
            on_final_transcript: Callback for final transcripts (text, timestamp)
            on_interruption: Callback for interruption detection (probability)

        Each callback may be a regular function or a coroutine function; if
        calling it returns an awaitable, that awaitable is awaited.
        """
        self.user_id = user_id
        self.stt_url = f"{stt_url}/{user_id}"

        # Callbacks
        self.on_vad_event = on_vad_event
        self.on_partial_transcript = on_partial_transcript
        self.on_final_transcript = on_final_transcript
        self.on_interruption = on_interruption

        # Connection state
        self.websocket: Optional[aiohttp.ClientWebSocketResponse] = None
        self.session: Optional[aiohttp.ClientSession] = None
        self.connected = False
        self.running = False

        # Background task running _receive_events()
        self._receive_task: Optional[asyncio.Task] = None

        logger.info(f"STT client initialized for user {user_id}")

    async def connect(self):
        """
        Connect to the STT WebSocket server and start the receive loop.

        Waits for the server's first message (parsed as JSON and logged as the
        ready message) before marking the client connected. Does nothing if
        already connected.

        Raises:
            Exception: whatever the connection handshake raised; any partially
                created session/socket is released before re-raising.
        """
        if self.connected:
            logger.warning(f"Already connected for user {self.user_id}")
            return

        # A previous connection may have dropped (server close, send failure)
        # without disconnect() being called; release its session and socket so
        # they are not leaked when we overwrite the attributes below.
        await self._close_resources()

        try:
            self.session = aiohttp.ClientSession()
            self.websocket = await self.session.ws_connect(
                self.stt_url,
                heartbeat=30
            )

            # Wait for ready message
            ready_msg = await self.websocket.receive_json()
            logger.info(f"STT connected for user {self.user_id}: {ready_msg}")

            self.connected = True
            self.running = True

            # Start receive task
            self._receive_task = asyncio.create_task(self._receive_events())

            logger.info(f"✓ STT WebSocket connected for user {self.user_id}")

        except asyncio.CancelledError:
            # Cancelled mid-handshake: still release the session, then propagate.
            await self.disconnect()
            raise
        except Exception as e:
            logger.error(f"Failed to connect STT for user {self.user_id}: {e}", exc_info=True)
            await self.disconnect()
            raise

    async def disconnect(self):
        """Disconnect from STT WebSocket and release all connection resources."""
        logger.info(f"Disconnecting STT for user {self.user_id}")

        self.running = False
        self.connected = False

        await self._close_resources()

        logger.info(f"✓ STT disconnected for user {self.user_id}")

    async def _close_resources(self):
        """Stop the receive task and close the WebSocket and HTTP session (idempotent)."""
        task, self._receive_task = self._receive_task, None
        # A task cannot cancel-and-await itself; this happens when disconnect()
        # is invoked from a callback running on the receive task. In that case
        # the loop exits on its own once the socket below is detached.
        if task is not None and not task.done() and task is not asyncio.current_task():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

        ws, self.websocket = self.websocket, None
        session, self.session = self.session, None
        try:
            if ws is not None:
                await ws.close()
        finally:
            # Close the session even if closing the socket failed.
            if session is not None:
                await session.close()

    async def send_audio(self, audio_data: bytes):
        """
        Send audio chunk to STT server.

        Args:
            audio_data: PCM audio (int16, 16kHz mono)

        A send failure is logged and marks the client disconnected; it is not
        raised to the caller.
        """
        if not self.connected or not self.websocket:
            logger.warning(f"Cannot send audio, not connected for user {self.user_id}")
            return

        try:
            await self.websocket.send_bytes(audio_data)
            # Lazy %-formatting: this runs for every audio chunk.
            logger.debug("Sent %d bytes to STT", len(audio_data))
        except Exception as e:
            logger.error(f"Failed to send audio to STT: {e}")
            self.connected = False

    async def _receive_events(self):
        """
        Background task to receive events from STT server.

        The loop is bound to the socket that was current when the task started,
        so a stale task can never read from a socket created by a later
        connect(). It ends when ``running`` is cleared, the socket is replaced
        or closed, or receiving fails.
        """
        ws = self.websocket
        closed_types = (
            aiohttp.WSMsgType.CLOSE,
            aiohttp.WSMsgType.CLOSING,
            aiohttp.WSMsgType.CLOSED,
        )
        try:
            while self.running and ws is not None and self.websocket is ws:
                msg = await ws.receive()

                if msg.type == aiohttp.WSMsgType.TEXT:
                    await self._dispatch_text(msg.data)

                elif msg.type in closed_types:
                    logger.info(f"STT WebSocket closed for user {self.user_id}")
                    break

                elif msg.type == aiohttp.WSMsgType.ERROR:
                    logger.error(
                        f"STT WebSocket error for user {self.user_id}: {ws.exception()}"
                    )
                    break
                # Other frame types (e.g. BINARY) are ignored.

        except asyncio.CancelledError:
            # Propagate cancellation (disconnect() awaits and absorbs it). The
            # explicit clause keeps Python 3.7, where CancelledError subclasses
            # Exception, from logging it as an error below.
            raise
        except Exception as e:
            logger.error(f"Error receiving STT event: {e}", exc_info=True)
        finally:
            # Don't clobber the state of a newer connection that replaced ours.
            if self.websocket is ws:
                self.connected = False
            logger.info(f"STT receive task ended for user {self.user_id}")

    async def _dispatch_text(self, raw: str):
        """
        Decode one TEXT frame as JSON and pass it to _handle_event.

        Malformed messages and exceptions from event handling/callbacks are
        logged and swallowed, so a single bad event cannot end the receive loop.
        """
        try:
            event = json.loads(raw)
        except json.JSONDecodeError as e:
            logger.warning(f"Ignoring malformed STT message for user {self.user_id}: {e}")
            return

        if not isinstance(event, dict):
            logger.warning(f"Ignoring non-object STT message for user {self.user_id}: {event!r}")
            return

        try:
            await self._handle_event(event)
        except asyncio.CancelledError:
            raise
        except Exception:
            logger.exception(f"Error handling STT event for user {self.user_id}: {event!r}")

    @staticmethod
    async def _invoke(callback: Callable, *args):
        """Call *callback* with *args*, awaiting the result if it is awaitable."""
        result = callback(*args)
        if inspect.isawaitable(result):
            await result

    async def _handle_event(self, event: dict):
        """
        Handle incoming STT event.

        Args:
            event: Event dictionary from STT server; its 'type' key selects
                the handler. Unknown types are logged and ignored.
        """
        event_type = event.get('type')

        if event_type == 'vad':
            # VAD event: speech detection
            logger.debug("VAD event: %s", event)
            if self.on_vad_event:
                await self._invoke(self.on_vad_event, event)

        elif event_type == 'partial':
            # Partial transcript
            text = event.get('text', '')
            timestamp = event.get('timestamp', 0)
            logger.info(f"Partial transcript [{self.user_id}]: {text}")
            if self.on_partial_transcript:
                await self._invoke(self.on_partial_transcript, text, timestamp)

        elif event_type == 'final':
            # Final transcript
            text = event.get('text', '')
            timestamp = event.get('timestamp', 0)
            logger.info(f"Final transcript [{self.user_id}]: {text}")
            if self.on_final_transcript:
                await self._invoke(self.on_final_transcript, text, timestamp)

        elif event_type == 'interruption':
            # Interruption detected
            probability = event.get('probability', 0)
            logger.info(f"Interruption detected from user {self.user_id} (prob={probability:.3f})")
            if self.on_interruption:
                await self._invoke(self.on_interruption, probability)

        else:
            logger.warning(f"Unknown STT event type: {event_type}")

    def is_connected(self) -> bool:
        """Check if STT client is connected."""
        return self.connected