Implemented an experimental, production-ready voice chat and relegated the old flow to voice debug mode. Added a new Web UI panel for Voice Chat.
This commit is contained in:
@@ -1,11 +1,15 @@
|
||||
"""
|
||||
STT Client for Discord Bot
|
||||
STT Client for Discord Bot (RealtimeSTT Version)
|
||||
|
||||
WebSocket client that connects to the STT server and handles:
|
||||
WebSocket client that connects to the RealtimeSTT server and handles:
|
||||
- Audio streaming to STT
|
||||
- Receiving VAD events
|
||||
- Receiving partial/final transcripts
|
||||
- Interruption detection
|
||||
|
||||
Protocol:
|
||||
- Client sends: binary audio data (16kHz, 16-bit mono PCM)
|
||||
- Client sends: JSON {"command": "reset"} to reset state
|
||||
- Server sends: JSON {"type": "partial", "text": "...", "timestamp": float}
|
||||
- Server sends: JSON {"type": "final", "text": "...", "timestamp": float}
|
||||
"""
|
||||
|
||||
import aiohttp
|
||||
@@ -19,7 +23,7 @@ logger = logging.getLogger('stt_client')
|
||||
|
||||
class STTClient:
|
||||
"""
|
||||
WebSocket client for STT server communication.
|
||||
WebSocket client for RealtimeSTT server communication.
|
||||
|
||||
Handles audio streaming and receives transcription events.
|
||||
"""
|
||||
@@ -27,34 +31,28 @@ class STTClient:
|
||||
def __init__(
|
||||
self,
|
||||
user_id: str,
|
||||
stt_url: str = "ws://miku-stt:8766/ws/stt",
|
||||
on_vad_event: Optional[Callable] = None,
|
||||
stt_url: str = "ws://miku-stt:8766",
|
||||
on_partial_transcript: Optional[Callable] = None,
|
||||
on_final_transcript: Optional[Callable] = None,
|
||||
on_interruption: Optional[Callable] = None
|
||||
):
|
||||
"""
|
||||
Initialize STT client.
|
||||
|
||||
Args:
|
||||
user_id: Discord user ID
|
||||
stt_url: Base WebSocket URL for STT server
|
||||
on_vad_event: Callback for VAD events (event_dict)
|
||||
user_id: Discord user ID (for logging purposes)
|
||||
stt_url: WebSocket URL for STT server
|
||||
on_partial_transcript: Callback for partial transcripts (text, timestamp)
|
||||
on_final_transcript: Callback for final transcripts (text, timestamp)
|
||||
on_interruption: Callback for interruption detection (probability)
|
||||
"""
|
||||
self.user_id = user_id
|
||||
self.stt_url = f"{stt_url}/{user_id}"
|
||||
self.stt_url = stt_url
|
||||
|
||||
# Callbacks
|
||||
self.on_vad_event = on_vad_event
|
||||
self.on_partial_transcript = on_partial_transcript
|
||||
self.on_final_transcript = on_final_transcript
|
||||
self.on_interruption = on_interruption
|
||||
|
||||
# Connection state
|
||||
self.websocket: Optional[aiohttp.ClientWebSocket] = None
|
||||
self.websocket: Optional[aiohttp.ClientWebSocketResponse] = None
|
||||
self.session: Optional[aiohttp.ClientSession] = None
|
||||
self.connected = False
|
||||
self.running = False
|
||||
@@ -65,7 +63,7 @@ class STTClient:
|
||||
logger.info(f"STT client initialized for user {user_id}")
|
||||
|
||||
async def connect(self):
|
||||
"""Connect to STT WebSocket server."""
|
||||
"""Connect to RealtimeSTT WebSocket server."""
|
||||
if self.connected:
|
||||
logger.warning(f"Already connected for user {self.user_id}")
|
||||
return
|
||||
@@ -74,202 +72,156 @@ class STTClient:
|
||||
self.session = aiohttp.ClientSession()
|
||||
self.websocket = await self.session.ws_connect(
|
||||
self.stt_url,
|
||||
heartbeat=30
|
||||
heartbeat=30,
|
||||
receive_timeout=60
|
||||
)
|
||||
|
||||
# Wait for ready message
|
||||
ready_msg = await self.websocket.receive_json()
|
||||
logger.info(f"STT connected for user {self.user_id}: {ready_msg}")
|
||||
|
||||
self.connected = True
|
||||
self.running = True
|
||||
|
||||
# Start receive task
|
||||
self._receive_task = asyncio.create_task(self._receive_events())
|
||||
# Start background task to receive messages
|
||||
self._receive_task = asyncio.create_task(self._receive_loop())
|
||||
|
||||
logger.info(f"✓ STT WebSocket connected for user {self.user_id}")
|
||||
|
||||
logger.info(f"Connected to STT server at {self.stt_url} for user {self.user_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect STT for user {self.user_id}: {e}", exc_info=True)
|
||||
await self.disconnect()
|
||||
logger.error(f"Failed to connect to STT server: {e}")
|
||||
await self._cleanup()
|
||||
raise
|
||||
|
||||
async def disconnect(self):
|
||||
"""Disconnect from STT WebSocket."""
|
||||
logger.info(f"Disconnecting STT for user {self.user_id}")
|
||||
|
||||
"""Disconnect from STT server."""
|
||||
self.running = False
|
||||
self.connected = False
|
||||
|
||||
# Cancel receive task
|
||||
if self._receive_task and not self._receive_task.done():
|
||||
if self._receive_task:
|
||||
self._receive_task.cancel()
|
||||
try:
|
||||
await self._receive_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._receive_task = None
|
||||
|
||||
# Close WebSocket
|
||||
await self._cleanup()
|
||||
logger.info(f"Disconnected from STT server for user {self.user_id}")
|
||||
|
||||
async def _cleanup(self):
|
||||
"""Clean up WebSocket and session."""
|
||||
if self.websocket:
|
||||
await self.websocket.close()
|
||||
try:
|
||||
await self.websocket.close()
|
||||
except Exception:
|
||||
pass
|
||||
self.websocket = None
|
||||
|
||||
# Close session
|
||||
if self.session:
|
||||
await self.session.close()
|
||||
try:
|
||||
await self.session.close()
|
||||
except Exception:
|
||||
pass
|
||||
self.session = None
|
||||
|
||||
logger.info(f"✓ STT disconnected for user {self.user_id}")
|
||||
self.connected = False
|
||||
|
||||
async def send_audio(self, audio_data: bytes):
|
||||
"""
|
||||
Send audio chunk to STT server.
|
||||
Send raw audio data to STT server.
|
||||
|
||||
Args:
|
||||
audio_data: PCM audio (int16, 16kHz mono)
|
||||
audio_data: Raw PCM audio (16kHz, 16-bit mono, little-endian)
|
||||
"""
|
||||
if not self.connected or not self.websocket:
|
||||
logger.warning(f"Cannot send audio, not connected for user {self.user_id}")
|
||||
return
|
||||
|
||||
try:
|
||||
await self.websocket.send_bytes(audio_data)
|
||||
logger.debug(f"Sent {len(audio_data)} bytes to STT")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send audio to STT: {e}")
|
||||
self.connected = False
|
||||
logger.error(f"Failed to send audio: {e}")
|
||||
await self._cleanup()
|
||||
|
||||
async def send_final(self):
|
||||
"""
|
||||
Request final transcription from STT server.
|
||||
|
||||
Call this when the user stops speaking to get the final transcript.
|
||||
"""
|
||||
async def reset(self):
|
||||
"""Reset STT state (clear any pending transcription)."""
|
||||
if not self.connected or not self.websocket:
|
||||
logger.warning(f"Cannot send final command, not connected for user {self.user_id}")
|
||||
return
|
||||
|
||||
try:
|
||||
command = json.dumps({"type": "final"})
|
||||
await self.websocket.send_str(command)
|
||||
logger.debug(f"Sent final command to STT")
|
||||
|
||||
await self.websocket.send_json({"command": "reset"})
|
||||
logger.debug(f"Sent reset command for user {self.user_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send final command to STT: {e}")
|
||||
self.connected = False
|
||||
logger.error(f"Failed to send reset: {e}")
|
||||
|
||||
async def send_reset(self):
|
||||
"""
|
||||
Reset the STT server's audio buffer.
|
||||
|
||||
Call this to clear any buffered audio.
|
||||
"""
|
||||
if not self.connected or not self.websocket:
|
||||
logger.warning(f"Cannot send reset command, not connected for user {self.user_id}")
|
||||
return
|
||||
|
||||
try:
|
||||
command = json.dumps({"type": "reset"})
|
||||
await self.websocket.send_str(command)
|
||||
logger.debug(f"Sent reset command to STT")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send reset command to STT: {e}")
|
||||
self.connected = False
|
||||
def is_connected(self) -> bool:
    """Report whether the client currently has a usable STT connection.

    Returns:
        True only when the ``connected`` flag is set and a WebSocket
        object is present on the instance.
    """
    has_socket = self.websocket is not None
    return self.connected and has_socket
|
||||
|
||||
async def _receive_events(self):
|
||||
"""Background task to receive events from STT server."""
|
||||
async def _receive_loop(self):
|
||||
"""Background task to receive messages from STT server."""
|
||||
try:
|
||||
while self.running and self.websocket:
|
||||
try:
|
||||
msg = await self.websocket.receive()
|
||||
msg = await asyncio.wait_for(
|
||||
self.websocket.receive(),
|
||||
timeout=5.0
|
||||
)
|
||||
|
||||
if msg.type == aiohttp.WSMsgType.TEXT:
|
||||
event = json.loads(msg.data)
|
||||
await self._handle_event(event)
|
||||
|
||||
await self._handle_message(msg.data)
|
||||
elif msg.type == aiohttp.WSMsgType.CLOSED:
|
||||
logger.info(f"STT WebSocket closed for user {self.user_id}")
|
||||
logger.warning(f"STT WebSocket closed for user {self.user_id}")
|
||||
break
|
||||
|
||||
elif msg.type == aiohttp.WSMsgType.ERROR:
|
||||
logger.error(f"STT WebSocket error for user {self.user_id}")
|
||||
break
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error receiving STT event: {e}", exc_info=True)
|
||||
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# Timeout is fine, just continue
|
||||
continue
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Error in STT receive loop: {e}")
|
||||
finally:
|
||||
self.connected = False
|
||||
logger.info(f"STT receive task ended for user {self.user_id}")
|
||||
|
||||
async def _handle_event(self, event: dict):
|
||||
"""
|
||||
Handle incoming STT event.
|
||||
|
||||
Args:
|
||||
event: Event dictionary from STT server
|
||||
"""
|
||||
event_type = event.get('type')
|
||||
|
||||
if event_type == 'transcript':
|
||||
# New ONNX server protocol: single transcript type with is_final flag
|
||||
text = event.get('text', '')
|
||||
is_final = event.get('is_final', False)
|
||||
timestamp = event.get('timestamp', 0)
|
||||
async def _handle_message(self, data: str):
|
||||
"""Handle a message from the STT server."""
|
||||
try:
|
||||
message = json.loads(data)
|
||||
msg_type = message.get("type")
|
||||
text = message.get("text", "")
|
||||
timestamp = message.get("timestamp", 0)
|
||||
|
||||
if is_final:
|
||||
logger.info(f"Final transcript [{self.user_id}]: {text}")
|
||||
if self.on_final_transcript:
|
||||
await self.on_final_transcript(text, timestamp)
|
||||
else:
|
||||
logger.info(f"Partial transcript [{self.user_id}]: {text}")
|
||||
if self.on_partial_transcript:
|
||||
await self.on_partial_transcript(text, timestamp)
|
||||
|
||||
elif event_type == 'vad':
|
||||
# VAD event: speech detection (legacy support)
|
||||
logger.debug(f"VAD event: {event}")
|
||||
if self.on_vad_event:
|
||||
await self.on_vad_event(event)
|
||||
|
||||
elif event_type == 'partial':
|
||||
# Legacy protocol support: partial transcript
|
||||
text = event.get('text', '')
|
||||
timestamp = event.get('timestamp', 0)
|
||||
logger.info(f"Partial transcript [{self.user_id}]: {text}")
|
||||
if self.on_partial_transcript:
|
||||
await self.on_partial_transcript(text, timestamp)
|
||||
|
||||
elif event_type == 'final':
|
||||
# Legacy protocol support: final transcript
|
||||
text = event.get('text', '')
|
||||
timestamp = event.get('timestamp', 0)
|
||||
logger.info(f"Final transcript [{self.user_id}]: {text}")
|
||||
if self.on_final_transcript:
|
||||
await self.on_final_transcript(text, timestamp)
|
||||
|
||||
elif event_type == 'interruption':
|
||||
# Interruption detected (legacy support)
|
||||
probability = event.get('probability', 0)
|
||||
logger.info(f"Interruption detected from user {self.user_id} (prob={probability:.3f})")
|
||||
if self.on_interruption:
|
||||
await self.on_interruption(probability)
|
||||
|
||||
elif event_type == 'info':
|
||||
# Info message
|
||||
logger.info(f"STT info: {event.get('message', '')}")
|
||||
|
||||
elif event_type == 'error':
|
||||
# Error message
|
||||
logger.error(f"STT error: {event.get('message', '')}")
|
||||
|
||||
else:
|
||||
logger.warning(f"Unknown STT event type: {event_type}")
|
||||
if msg_type == "partial":
|
||||
if self.on_partial_transcript and text:
|
||||
await self._call_callback(
|
||||
self.on_partial_transcript,
|
||||
text,
|
||||
timestamp
|
||||
)
|
||||
|
||||
elif msg_type == "final":
|
||||
if self.on_final_transcript and text:
|
||||
await self._call_callback(
|
||||
self.on_final_transcript,
|
||||
text,
|
||||
timestamp
|
||||
)
|
||||
|
||||
elif msg_type == "connected":
|
||||
logger.info(f"STT server confirmed connection for user {self.user_id}")
|
||||
|
||||
elif msg_type == "error":
|
||||
error_msg = message.get("error", "Unknown error")
|
||||
logger.error(f"STT server error: {error_msg}")
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Invalid JSON from STT server: {data[:100]}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling STT message: {e}")
|
||||
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if STT client is connected."""
|
||||
return self.connected
|
||||
async def _call_callback(self, callback, *args):
|
||||
"""Safely call a callback, handling both sync and async functions."""
|
||||
try:
|
||||
result = callback(*args)
|
||||
if asyncio.iscoroutine(result):
|
||||
await result
|
||||
except Exception as e:
|
||||
logger.error(f"Error in STT callback: {e}")
|
||||
|
||||
Reference in New Issue
Block a user