Files
miku-discord/bot/utils/stt_client.py

228 lines
7.7 KiB
Python

"""
STT Client for Discord Bot (RealtimeSTT Version)
WebSocket client that connects to the RealtimeSTT server and handles:
- Audio streaming to STT
- Receiving partial/final transcripts
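Protocol:
- Client sends: binary audio data (16kHz, 16-bit mono PCM, little-endian)
- Client sends: JSON {"command": "reset"} to reset state
- Server sends: JSON {"type": "partial", "text": "...", "timestamp": float}
- Server sends: JSON {"type": "final", "text": "...", "timestamp": float}
- Server sends: JSON {"type": "connected"} to confirm the connection
- Server sends: JSON {"type": "error", "error": "..."} to report an error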
"""
import aiohttp
import asyncio
import logging
from typing import Optional, Callable
import json
# Module-wide logger, registered under the fixed name "stt_client".
logger = logging.getLogger(name="stt_client")
class STTClient:
    """
    WebSocket client for RealtimeSTT server communication.

    Streams raw PCM audio to the server and dispatches the transcription
    events it sends back to the optional partial/final callbacks.
    """

    def __init__(
        self,
        user_id: str,
        stt_url: str = "ws://miku-stt:8766",
        on_partial_transcript: Optional[Callable] = None,
        on_final_transcript: Optional[Callable] = None,
    ):
        """
        Initialize STT client. No network I/O happens until connect().

        Args:
            user_id: Discord user ID (for logging purposes)
            stt_url: WebSocket URL for STT server
            on_partial_transcript: Callback for partial transcripts
                (text, timestamp); may be a plain or an async function.
            on_final_transcript: Callback for final transcripts
                (text, timestamp); may be a plain or an async function.
        """
        self.user_id = user_id
        self.stt_url = stt_url

        # Callbacks
        self.on_partial_transcript = on_partial_transcript
        self.on_final_transcript = on_final_transcript

        # Connection state
        self.websocket: Optional[aiohttp.ClientWebSocketResponse] = None
        self.session: Optional[aiohttp.ClientSession] = None
        self.connected = False
        self.running = False

        # Background task running _receive_loop while connected
        self._receive_task: Optional[asyncio.Task] = None

        logger.info(f"STT client initialized for user {user_id}")

    async def connect(self):
        """
        Connect to the RealtimeSTT WebSocket server and start receiving.

        Anything left over from a previous connection that dropped without
        disconnect() being called (receive task, socket, session) is released
        first, so reconnecting does not leak the old ClientSession.

        Raises:
            Exception: whatever aiohttp raises when the connection fails; the
                partially created session is closed before re-raising.
        """
        if self.connected:
            logger.warning(f"Already connected for user {self.user_id}")
            return

        # `connected` is cleared when the receive loop exits or a send fails,
        # but the task/socket/session objects are left in place. Drop them
        # before creating a fresh ClientSession below.
        await self._stop_receive_task()
        await self._cleanup()

        try:
            self.session = aiohttp.ClientSession()
            self.websocket = await self.session.ws_connect(
                self.stt_url,
                heartbeat=30,
                receive_timeout=60,
            )
            self.connected = True
            self.running = True

            # Start background task to receive messages
            self._receive_task = asyncio.create_task(self._receive_loop())
            logger.info(f"Connected to STT server at {self.stt_url} for user {self.user_id}")
        except asyncio.CancelledError:
            # Cancelled mid-handshake: still close the session, then propagate.
            await self._cleanup()
            raise
        except Exception as e:
            logger.error(f"Failed to connect to STT server: {e}")
            await self._cleanup()
            raise

    async def disconnect(self):
        """
        Disconnect from STT server.

        Stops the receive loop, then closes the WebSocket and session. Safe to
        call when already disconnected, and from inside a transcript callback
        (which executes on the receive task itself).
        """
        self.running = False
        await self._stop_receive_task()
        await self._cleanup()
        logger.info(f"Disconnected from STT server for user {self.user_id}")

    async def _stop_receive_task(self):
        """Cancel the background receive task (if any) and wait for it to end."""
        task, self._receive_task = self._receive_task, None
        if task is None:
            return
        if task is asyncio.current_task():
            # Called from a callback running inside the receive loop: awaiting
            # our own task would raise RuntimeError. The loop re-checks its
            # exit condition as soon as the callback returns, so let it unwind.
            return
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass

    async def _cleanup(self):
        """Close the WebSocket and session (best effort) and mark disconnected."""
        if self.websocket:
            try:
                await self.websocket.close()
            except Exception:
                # Deliberately best-effort: the socket may already be dead.
                pass
            self.websocket = None

        if self.session:
            try:
                await self.session.close()
            except Exception:
                pass
            self.session = None

        self.connected = False

    async def send_audio(self, audio_data: bytes):
        """
        Send raw audio data to STT server.

        Does nothing when not connected. A send failure is logged and tears
        the connection down (socket and session closed, `connected` cleared).

        Args:
            audio_data: Raw PCM audio (16kHz, 16-bit mono, little-endian)
        """
        if not self.connected or not self.websocket:
            return
        try:
            await self.websocket.send_bytes(audio_data)
        except Exception as e:
            logger.error(f"Failed to send audio: {e}")
            await self._cleanup()

    async def reset(self):
        """
        Reset STT state (clear any pending transcription).

        Sends {"command": "reset"}; does nothing when not connected. A send
        failure is logged only and does not tear the connection down.
        """
        if not self.connected or not self.websocket:
            return
        try:
            await self.websocket.send_json({"command": "reset"})
            logger.debug(f"Sent reset command for user {self.user_id}")
        except Exception as e:
            logger.error(f"Failed to send reset: {e}")

    def is_connected(self) -> bool:
        """Check if connected to STT server."""
        return self.connected and self.websocket is not None

    async def _receive_loop(self):
        """
        Background task: read server messages and dispatch text frames.

        Runs until disconnect() clears `running`, the socket this loop was
        started for is replaced or dropped, or the server closes or errors the
        connection. Cancellation is propagated to whoever awaits the task.
        """
        # Bind the socket this loop serves. After a reconnect swaps
        # self.websocket, a stale loop exits instead of reading the new
        # socket or clearing the new connection's `connected` flag.
        ws = self.websocket
        try:
            while self.running and ws is not None and self.websocket is ws:
                try:
                    # Short poll so the exit condition above is re-checked
                    # even while the server is silent.
                    msg = await asyncio.wait_for(ws.receive(), timeout=5.0)
                except asyncio.TimeoutError:
                    continue

                if msg.type == aiohttp.WSMsgType.TEXT:
                    await self._handle_message(msg.data)
                elif msg.type in (
                    aiohttp.WSMsgType.CLOSE,
                    aiohttp.WSMsgType.CLOSING,
                    aiohttp.WSMsgType.CLOSED,
                ):
                    logger.warning(f"STT WebSocket closed for user {self.user_id}")
                    break
                elif msg.type == aiohttp.WSMsgType.ERROR:
                    logger.error(f"STT WebSocket error for user {self.user_id}")
                    break
        except asyncio.CancelledError:
            # Let the canceller (see _stop_receive_task) observe cancellation.
            raise
        except Exception as e:
            logger.error(f"Error in STT receive loop: {e}")
        finally:
            # Only clear the flag if a newer connection hasn't replaced ours.
            if self.websocket is ws:
                self.connected = False

    async def _handle_message(self, data: str):
        """
        Decode one JSON text frame from the server and dispatch it by "type".

        "partial"/"final" messages with non-empty text go to the matching
        callback; "connected" and "error" are logged. Decode and dispatch
        errors are logged rather than raised.
        """
        try:
            message = json.loads(data)
            msg_type = message.get("type")
            text = message.get("text", "")
            timestamp = message.get("timestamp", 0)

            if msg_type == "partial":
                if self.on_partial_transcript and text:
                    await self._call_callback(
                        self.on_partial_transcript,
                        text,
                        timestamp,
                    )
            elif msg_type == "final":
                if self.on_final_transcript and text:
                    await self._call_callback(
                        self.on_final_transcript,
                        text,
                        timestamp,
                    )
            elif msg_type == "connected":
                logger.info(f"STT server confirmed connection for user {self.user_id}")
            elif msg_type == "error":
                error_msg = message.get("error", "Unknown error")
                logger.error(f"STT server error: {error_msg}")
        except json.JSONDecodeError:
            logger.warning(f"Invalid JSON from STT server: {data[:100]}")
        except Exception as e:
            logger.error(f"Error handling STT message: {e}")

    async def _call_callback(self, callback, *args):
        """Safely call a callback, handling both sync and async functions."""
        try:
            result = callback(*args)
            # Async callbacks return a coroutine that must be awaited here.
            if asyncio.iscoroutine(result):
                await result
        except Exception as e:
            logger.error(f"Error in STT callback: {e}")