Implemented experimental, real production-ready voice chat; relegated old flow to voice debug mode. New Web UI panel for Voice Chat.
This commit is contained in:
275
backups/2025-01-19-stt-parakeet/bot/utils/stt_client.py
Normal file
275
backups/2025-01-19-stt-parakeet/bot/utils/stt_client.py
Normal file
@@ -0,0 +1,275 @@
|
||||
"""
|
||||
STT Client for Discord Bot
|
||||
|
||||
WebSocket client that connects to the STT server and handles:
|
||||
- Audio streaming to STT
|
||||
- Receiving VAD events
|
||||
- Receiving partial/final transcripts
|
||||
- Interruption detection
|
||||
"""
|
||||
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional, Callable
|
||||
import json
|
||||
|
||||
logger = logging.getLogger('stt_client')


class STTClient:
    """
    WebSocket client for STT server communication.

    Streams audio to the STT server and dispatches the JSON events it sends
    back (transcripts, VAD, interruption, info, error) to optional callbacks.
    A callback may be a plain function or may return an awaitable; awaitables
    are awaited before the next event is processed.
    """

    def __init__(
        self,
        user_id: str,
        stt_url: str = "ws://miku-stt:8766/ws/stt",
        on_vad_event: Optional[Callable] = None,
        on_partial_transcript: Optional[Callable] = None,
        on_final_transcript: Optional[Callable] = None,
        on_interruption: Optional[Callable] = None
    ):
        """
        Initialize STT client.

        Args:
            user_id: Discord user ID
            stt_url: Base WebSocket URL for STT server; the user ID is appended
                as the last path segment
            on_vad_event: Callback for VAD events (event_dict)
            on_partial_transcript: Callback for partial transcripts (text, timestamp)
            on_final_transcript: Callback for final transcripts (text, timestamp)
            on_interruption: Callback for interruption detection (probability)
        """
        self.user_id = user_id
        # Tolerate a trailing slash on the base URL so we never build ".../stt//<id>".
        self.stt_url = f"{stt_url.rstrip('/')}/{user_id}"

        # Callbacks
        self.on_vad_event = on_vad_event
        self.on_partial_transcript = on_partial_transcript
        self.on_final_transcript = on_final_transcript
        self.on_interruption = on_interruption

        # Connection state
        self.websocket: Optional[aiohttp.ClientWebSocketResponse] = None
        self.session: Optional[aiohttp.ClientSession] = None
        self.connected = False
        self.running = False

        # Receive task
        self._receive_task: Optional[asyncio.Task] = None

        logger.info(f"STT client initialized for user {user_id}")

    async def connect(self):
        """
        Connect to the STT WebSocket server and start the receive task.

        Does nothing (beyond a warning) when already connected.

        Raises:
            Exception: Whatever the connection attempt raised, re-raised after
                any partially created resources have been released.
        """
        if self.connected:
            logger.warning(f"Already connected for user {self.user_id}")
            return

        # A failed send clears `connected` but leaves the old session/socket
        # open; release them first so reconnecting does not leak them.
        if self.session is not None or self.websocket is not None:
            await self.disconnect()

        try:
            self.session = aiohttp.ClientSession()
            self.websocket = await self.session.ws_connect(
                self.stt_url,
                heartbeat=30
            )

            # Wait for ready message
            ready_msg = await self.websocket.receive_json()
            logger.info(f"STT connected for user {self.user_id}: {ready_msg}")

            self.connected = True
            self.running = True

            # Start receive task
            self._receive_task = asyncio.create_task(self._receive_events())

            logger.info(f"✓ STT WebSocket connected for user {self.user_id}")

        except Exception as e:
            logger.error(f"Failed to connect STT for user {self.user_id}: {e}", exc_info=True)
            await self.disconnect()
            raise

    async def disconnect(self):
        """
        Disconnect from STT WebSocket.

        Stops the receive task, then closes the WebSocket and the HTTP session.
        Safe to call when nothing is connected.
        """
        logger.info(f"Disconnecting STT for user {self.user_id}")

        self.running = False
        self.connected = False

        # Cancel receive task (unless we *are* the receive task, e.g. when a
        # callback triggers the disconnect: a task must not await itself).
        task, self._receive_task = self._receive_task, None
        if task is not None and not task.done() and task is not asyncio.current_task():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

        websocket, self.websocket = self.websocket, None
        session, self.session = self.session, None
        try:
            # Close WebSocket
            if websocket is not None:
                await websocket.close()
        finally:
            # Close session even if closing the socket failed
            if session is not None:
                await session.close()

        logger.info(f"✓ STT disconnected for user {self.user_id}")

    async def send_audio(self, audio_data: bytes):
        """
        Send audio chunk to STT server.

        A send failure is logged and marks the client as disconnected; it is
        not raised to the caller.

        Args:
            audio_data: PCM audio (int16, 16kHz mono)
        """
        if not self.connected or not self.websocket:
            logger.warning(f"Cannot send audio, not connected for user {self.user_id}")
            return

        try:
            await self.websocket.send_bytes(audio_data)
            logger.debug(f"Sent {len(audio_data)} bytes to STT")
        except Exception as e:
            logger.error(f"Failed to send audio to STT: {e}")
            self.connected = False

    async def send_final(self):
        """
        Request final transcription from STT server.

        Call this when the user stops speaking to get the final transcript.
        """
        await self._send_command("final")

    async def send_reset(self):
        """
        Reset the STT server's audio buffer.

        Call this to clear any buffered audio.
        """
        await self._send_command("reset")

    async def _send_command(self, command_type: str):
        """
        Send a ``{"type": <command_type>}`` JSON control message.

        A send failure is logged and marks the client as disconnected; it is
        not raised to the caller.

        Args:
            command_type: Value of the message's "type" field
        """
        if not self.connected or not self.websocket:
            logger.warning(f"Cannot send {command_type} command, not connected for user {self.user_id}")
            return

        try:
            await self.websocket.send_str(json.dumps({"type": command_type}))
            logger.debug(f"Sent {command_type} command to STT")
        except Exception as e:
            logger.error(f"Failed to send {command_type} command to STT: {e}")
            self.connected = False

    async def _receive_events(self):
        """
        Background task to receive events from STT server.

        Runs until the client is stopped, the socket closes or errors, or
        reading from the socket fails. Errors raised while handling a single
        event are logged and do not stop the loop.
        """
        try:
            while self.running and self.websocket:
                try:
                    msg = await self.websocket.receive()
                except asyncio.CancelledError:
                    raise
                except Exception as e:
                    # A failing socket would fail again on the next read;
                    # retrying here would spin, so stop receiving.
                    logger.error(f"Error receiving STT event: {e}", exc_info=True)
                    break

                if msg.type == aiohttp.WSMsgType.TEXT:
                    try:
                        await self._handle_event(json.loads(msg.data))
                    except asyncio.CancelledError:
                        raise
                    except Exception as e:
                        # One bad payload or callback must not end the stream.
                        logger.error(f"Error receiving STT event: {e}", exc_info=True)

                elif msg.type in (
                    aiohttp.WSMsgType.CLOSE,
                    aiohttp.WSMsgType.CLOSING,
                    aiohttp.WSMsgType.CLOSED,
                ):
                    logger.info(f"STT WebSocket closed for user {self.user_id}")
                    break

                elif msg.type == aiohttp.WSMsgType.ERROR:
                    logger.error(f"STT WebSocket error for user {self.user_id}")
                    break

        finally:
            self.connected = False
            logger.info(f"STT receive task ended for user {self.user_id}")

    async def _handle_event(self, event: dict):
        """
        Handle incoming STT event.

        Supports the single 'transcript' type carrying an 'is_final' flag as
        well as the legacy 'partial' / 'final' / 'vad' / 'interruption' types.

        Args:
            event: Event dictionary from STT server
        """
        if not isinstance(event, dict):
            logger.warning(f"Ignoring malformed STT event: {event!r}")
            return

        event_type = event.get('type')

        if event_type == 'transcript':
            # New ONNX server protocol: single transcript type with is_final flag
            await self._emit_transcript(event, is_final=bool(event.get('is_final', False)))

        elif event_type == 'partial':
            # Legacy protocol support: partial transcript
            await self._emit_transcript(event, is_final=False)

        elif event_type == 'final':
            # Legacy protocol support: final transcript
            await self._emit_transcript(event, is_final=True)

        elif event_type == 'vad':
            # VAD event: speech detection (legacy support)
            logger.debug(f"VAD event: {event}")
            await self._invoke(self.on_vad_event, event)

        elif event_type == 'interruption':
            # Interruption detected (legacy support); a null probability
            # is treated as 0 so the log format below cannot fail on it.
            probability = event.get('probability') or 0
            logger.info(f"Interruption detected from user {self.user_id} (prob={probability:.3f})")
            await self._invoke(self.on_interruption, probability)

        elif event_type == 'info':
            logger.info(f"STT info: {event.get('message', '')}")

        elif event_type == 'error':
            logger.error(f"STT error: {event.get('message', '')}")

        else:
            logger.warning(f"Unknown STT event type: {event_type}")

    async def _emit_transcript(self, event: dict, is_final: bool):
        """Log a transcript event and pass (text, timestamp) to the matching callback."""
        text = event.get('text', '')
        timestamp = event.get('timestamp', 0)

        if is_final:
            logger.info(f"Final transcript [{self.user_id}]: {text}")
            await self._invoke(self.on_final_transcript, text, timestamp)
        else:
            logger.info(f"Partial transcript [{self.user_id}]: {text}")
            await self._invoke(self.on_partial_transcript, text, timestamp)

    @staticmethod
    async def _invoke(callback: Optional[Callable], *args):
        """Call an optional callback, awaiting its result only if it is awaitable."""
        if callback is None:
            return

        import inspect

        result = callback(*args)
        if inspect.isawaitable(result):
            await result

    def is_connected(self) -> bool:
        """Check if STT client is connected."""
        return self.connected
|
||||
518
backups/2025-01-19-stt-parakeet/bot/utils/voice_receiver.py
Normal file
518
backups/2025-01-19-stt-parakeet/bot/utils/voice_receiver.py
Normal file
@@ -0,0 +1,518 @@
|
||||
"""
|
||||
Discord Voice Receiver using discord-ext-voice-recv
|
||||
|
||||
Captures audio from Discord voice channels and streams to STT.
|
||||
Uses the discord-ext-voice-recv extension for proper audio receiving support.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import audioop
|
||||
import logging
|
||||
from typing import Dict, Optional
|
||||
from collections import deque
|
||||
|
||||
import discord
|
||||
from discord.ext import voice_recv
|
||||
|
||||
from utils.stt_client import STTClient
|
||||
|
||||
logger = logging.getLogger('voice_receiver')


class VoiceReceiverSink(voice_recv.AudioSink):
    """
    Audio sink that receives Discord audio and forwards to STT.

    This sink processes incoming audio from Discord voice channels,
    decodes/resamples as needed, and sends to STT clients for transcription.
    It also requests a final transcript after a period of silence and reports
    sustained user audio while Miku is speaking as an interruption.
    """

    # The STT side's VAD consumes 512-sample frames of 16 kHz int16 audio.
    VAD_FRAME_SAMPLES = 512
    VAD_FRAME_BYTES = VAD_FRAME_SAMPLES * 2  # int16 = 2 bytes per sample
    AUDIO_BUFFER_MAXLEN = 1000

    def __init__(self, voice_manager, stt_url: str = "ws://miku-stt:8766/ws/stt"):
        """
        Initialize Voice Receiver.

        Args:
            voice_manager: Object notified of transcripts, VAD events and
                interruptions; its `miku_speaking` attribute is read when present
            stt_url: Base URL for STT WebSocket server with path (port 8766 inside container)
        """
        super().__init__()
        self.voice_manager = voice_manager
        self.stt_url = stt_url

        # Store event loop for thread-safe async calls: write() schedules
        # coroutines onto it with run_coroutine_threadsafe().
        try:
            self.loop = asyncio.get_running_loop()
        except RuntimeError:
            # Fallback if not in async context yet
            self.loop = asyncio.get_event_loop()

        # Per-user STT clients
        self.stt_clients: Dict[int, STTClient] = {}

        # Per-user 16 kHz PCM waiting to be cut into VAD-sized frames
        self.audio_buffers: Dict[int, deque] = {}

        # User info (for logging)
        self.users: Dict[int, discord.User] = {}

        # Silence tracking for detecting end of speech
        self.last_audio_time: Dict[int, float] = {}
        self.silence_tasks: Dict[int, asyncio.Task] = {}
        self.silence_timeout = 1.0  # seconds of silence before sending "final"

        # Interruption detection
        self.interruption_start_time: Dict[int, float] = {}
        self.interruption_audio_count: Dict[int, int] = {}
        self.interruption_threshold_time = 0.8  # seconds of speech to count as interruption
        self.interruption_threshold_chunks = 8  # minimum audio chunks to count as interruption

        # Per-user Opus decoders and resampler state (kept between packets)
        self._opus_decoders: dict = {}
        self._resample_states: dict = {}

        # Strong references to fire-and-forget tasks so they cannot be
        # garbage-collected before they finish.
        self._background_tasks: set = set()

        # Active flag
        self.active = False

        logger.info("VoiceReceiverSink initialized")

    def wants_opus(self) -> bool:
        """
        Tell discord-ext-voice-recv we want Opus data, NOT decoded PCM.

        We'll decode it ourselves to avoid decoder errors from discord-ext-voice-recv.

        Returns:
            True - we want Opus packets, we'll handle decoding
        """
        return True  # Get Opus, decode ourselves to avoid packet router errors

    def write(self, user: Optional[discord.User], data: voice_recv.VoiceData):
        """
        Called by discord-ext-voice-recv when audio is received.

        Decodes the Opus packet, downmixes to mono, resamples 48 kHz -> 16 kHz
        and schedules the result onto the stored event loop for sending to STT.
        Packets from unknown users or users we are not listening to are dropped.

        Args:
            user: Discord user who sent the audio (None if unknown)
            data: Voice data container with pcm, opus, and packet info
        """
        if not user:
            return  # Skip packets from unknown users

        user_id = user.id

        # Check if we're listening to this user
        if user_id not in self.stt_clients:
            return

        try:
            # Get Opus data (we decode ourselves to avoid PacketRouter errors)
            opus_data = data.opus
            if not opus_data:
                return

            import discord.opus

            # One decoder per user: Opus decoding is stateful per stream
            decoder = self._opus_decoders.get(user_id)
            if decoder is None:
                decoder = self._opus_decoders[user_id] = discord.opus.Decoder()

            # Decode opus -> PCM (this can fail on corrupt packets, so catch it)
            try:
                pcm_data = decoder.decode(opus_data, fec=False)
            except discord.opus.OpusError as e:
                # Skip corrupted packets silently (common at stream start)
                logger.debug(f"Skipping corrupted opus packet for user {user_id}: {e}")
                return

            if not pcm_data:
                return

            # PCM from Discord is 48kHz stereo int16; convert stereo to mono
            if len(pcm_data) % 4 == 0:  # Stereo (2 channels * 2 bytes per sample)
                pcm_mono = audioop.tomono(pcm_data, 2, 0.5, 0.5)
            else:
                pcm_mono = pcm_data

            # Resample from 48kHz to 16kHz for STT, carrying the converter
            # state from packet to packet so chunk boundaries stay continuous.
            pcm_16k, resample_state = audioop.ratecv(
                pcm_mono, 2, 1, 48000, 16000, self._resample_states.get(user_id)
            )
            self._resample_states[user_id] = resample_state

            # Send to STT client (schedule on event loop thread-safely)
            asyncio.run_coroutine_threadsafe(
                self._send_audio_chunk(user_id, pcm_16k),
                self.loop
            )

        except Exception as e:
            logger.error(f"Error processing audio for user {user_id}: {e}", exc_info=True)

    def cleanup(self):
        """
        Called when the sink is stopped.

        Only logs; async cleanup is handled separately in stop_all().
        """
        logger.info("VoiceReceiverSink cleanup")

    async def start_listening(self, user_id: int, user: discord.User):
        """
        Start listening to a specific user.

        Creates an STT client connection for this user and registers callbacks.

        Args:
            user_id: Discord user ID
            user: Discord user object

        Raises:
            Exception: Re-raises a failure to connect to the STT server after
                removing the partial per-user state.
        """
        if user_id in self.stt_clients:
            logger.warning(f"Already listening to user {user.name} ({user_id})")
            return

        logger.info(f"Starting to listen to user {user.name} ({user_id})")

        # Store user info
        self.users[user_id] = user

        # Initialize audio buffer
        self.audio_buffers[user_id] = deque(maxlen=self.AUDIO_BUFFER_MAXLEN)

        # Create STT client with callbacks
        stt_client = STTClient(
            user_id=user_id,
            stt_url=self.stt_url,
            on_vad_event=lambda event: asyncio.create_task(
                self._on_vad_event(user_id, event)
            ),
            on_partial_transcript=lambda text, timestamp: asyncio.create_task(
                self._on_partial_transcript(user_id, text)
            ),
            on_final_transcript=lambda text, timestamp: asyncio.create_task(
                self._on_final_transcript(user_id, text, user)
            ),
            on_interruption=lambda prob: asyncio.create_task(
                self._on_interruption(user_id, prob)
            )
        )

        # Connect to STT server
        try:
            await stt_client.connect()
            self.stt_clients[user_id] = stt_client
            self.active = True
            logger.info(f"✓ STT connected for user {user.name}")
        except Exception as e:
            logger.error(f"Failed to connect STT for user {user.name}: {e}", exc_info=True)
            # Cleanup partial state
            self.audio_buffers.pop(user_id, None)
            self.users.pop(user_id, None)
            raise

    async def stop_listening(self, user_id: int):
        """
        Stop listening to a specific user.

        Disconnects the STT client and cleans up resources for this user. The
        per-user state is released even if the disconnect raises.

        Args:
            user_id: Discord user ID
        """
        # Unregister first so write() stops forwarding audio immediately.
        stt_client = self.stt_clients.pop(user_id, None)
        if stt_client is None:
            logger.warning(f"Not listening to user {user_id}")
            return

        user = self.users.get(user_id)
        display_name = user.name if user else user_id
        logger.info(f"Stopping listening to user {display_name}")

        try:
            await stt_client.disconnect()
        finally:
            self._forget_user(user_id)

            # Update active flag
            if not self.stt_clients:
                self.active = False

        logger.info(f"✓ Stopped listening to user {display_name}")

    def _forget_user(self, user_id: int):
        """Drop every piece of per-user state and cancel the user's silence timer."""
        self.audio_buffers.pop(user_id, None)
        self.users.pop(user_id, None)

        # Cancel silence detection task; always remove the entry, finished or not
        silence_task = self.silence_tasks.pop(user_id, None)
        if silence_task is not None and not silence_task.done():
            silence_task.cancel()
        self.last_audio_time.pop(user_id, None)

        # Clear interruption tracking
        self.interruption_start_time.pop(user_id, None)
        self.interruption_audio_count.pop(user_id, None)

        # Cleanup opus decoder and resampler state for this user
        self._opus_decoders.pop(user_id, None)
        self._resample_states.pop(user_id, None)

    async def stop_all(self):
        """Stop listening to all users and cleanup all resources."""
        logger.info("Stopping all voice receivers")

        for user_id in list(self.stt_clients):
            try:
                await self.stop_listening(user_id)
            except Exception as e:
                # Keep going so one failed disconnect does not strand the rest.
                logger.error(f"Failed to stop listening to user {user_id}: {e}", exc_info=True)

        self.active = False
        logger.info("✓ All voice receivers stopped")

    async def _send_audio_chunk(self, user_id: int, audio_data: bytes):
        """
        Send audio chunk to STT client.

        Buffers incoming audio and forwards it in frames of exactly
        VAD_FRAME_SAMPLES samples (512 samples = 1024 bytes), keeping any
        leftover bytes for the next call. Afterwards updates the silence
        timer and the interruption tracking for this user.

        Args:
            user_id: Discord user ID
            audio_data: PCM audio (int16, 16kHz mono; 320 samples = 640 bytes
                for a 20ms Discord packet)
        """
        stt_client = self.stt_clients.get(user_id)
        if not stt_client or not stt_client.is_connected():
            return

        try:
            # Get or create buffer for this user
            buffer = self.audio_buffers.get(user_id)
            if buffer is None:
                buffer = self.audio_buffers[user_id] = deque(maxlen=self.AUDIO_BUFFER_MAXLEN)
            buffer.append(audio_data)

            frame_bytes = self.VAD_FRAME_BYTES
            if sum(len(chunk) for chunk in buffer) >= frame_bytes:
                combined = b''.join(buffer)
                buffer.clear()

                # Put the leftover back BEFORE awaiting any send, so audio
                # appended by a concurrent call lands after it, not before.
                whole = len(combined) - (len(combined) % frame_bytes)
                if whole < len(combined):
                    buffer.append(combined[whole:])

                for offset in range(0, whole, frame_bytes):
                    await stt_client.send_audio(combined[offset:offset + frame_bytes])

            # The user may have been removed while we were sending; do not
            # recreate timers or tracking for them.
            if self.stt_clients.get(user_id) is not stt_client:
                return

            import time
            current_time = time.time()
            previous_time = self.last_audio_time.get(user_id)

            # Track audio time for silence detection
            self.last_audio_time[user_id] = current_time

            # A gap longer than the silence timeout means a new utterance:
            # do not let an old start time count towards the interruption window.
            if previous_time is not None and current_time - previous_time > self.silence_timeout:
                self.interruption_start_time.pop(user_id, None)
                self.interruption_audio_count.pop(user_id, None)

            self._track_interruption(user_id, current_time)
            self._restart_silence_timer(user_id)

        except Exception as e:
            logger.error(f"Failed to send audio chunk for user {user_id}: {e}")

    def _track_interruption(self, user_id: int, current_time: float):
        """
        Count user audio received while Miku is speaking and fire the
        interruption handler once both the time and chunk thresholds are met.
        """
        # Note: self.voice_manager IS the VoiceSession, not the VoiceManager singleton
        miku_speaking = getattr(self.voice_manager, 'miku_speaking', False)
        logger.debug(f"[INTERRUPTION CHECK] user={user_id}, miku_speaking={miku_speaking}")

        if not miku_speaking:
            # Miku not speaking, clear interruption tracking
            self.interruption_start_time.pop(user_id, None)
            self.interruption_audio_count.pop(user_id, None)
            return

        if user_id not in self.interruption_start_time:
            # First chunk during Miku's speech
            self.interruption_start_time[user_id] = current_time
            self.interruption_audio_count[user_id] = 1
        else:
            self.interruption_audio_count[user_id] = self.interruption_audio_count.get(user_id, 0) + 1

        interruption_duration = current_time - self.interruption_start_time[user_id]
        chunk_count = self.interruption_audio_count[user_id]

        if (interruption_duration >= self.interruption_threshold_time and
                chunk_count >= self.interruption_threshold_chunks):
            logger.info(f"🛑 User {user_id} interrupted Miku (duration={interruption_duration:.2f}s, chunks={chunk_count})")
            logger.info(f"   → Stopping Miku's TTS and LLM, will process user's speech when finished")

            # Reset interruption tracking
            self.interruption_start_time.pop(user_id, None)
            self.interruption_audio_count.pop(user_id, None)

            # NOTE(review): called with one argument here but with
            # (user_id, probability) in _on_interruption — confirm the
            # handler's signature accepts both.
            self._spawn(self.voice_manager.on_user_interruption(user_id))

    def _restart_silence_timer(self, user_id: int):
        """Cancel the user's pending silence timer, if any, and start a fresh one."""
        pending = self.silence_tasks.get(user_id)
        if pending is not None and not pending.done():
            pending.cancel()

        self.silence_tasks[user_id] = asyncio.create_task(
            self._detect_silence(user_id)
        )

    def _spawn(self, coro) -> asyncio.Task:
        """Run a coroutine as a background task while holding a reference to it."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._on_background_task_done)
        return task

    def _on_background_task_done(self, task: asyncio.Task):
        """Release a finished background task and log the error it raised, if any."""
        self._background_tasks.discard(task)
        if task.cancelled():
            return
        error = task.exception()
        if error is not None:
            logger.error(f"Background voice task failed: {error}", exc_info=error)

    async def _detect_silence(self, user_id: int):
        """
        Wait for silence timeout and send 'final' command to STT.

        This is started after each audio chunk. If no more audio arrives within
        the silence_timeout period (which would cancel this task), we send the
        'final' command to get the complete transcription.

        Args:
            user_id: Discord user ID
        """
        try:
            # Wait for silence timeout
            await asyncio.sleep(self.silence_timeout)

            # Check if we still have an active STT client
            stt_client = self.stt_clients.get(user_id)
            if not stt_client or not stt_client.is_connected():
                return

            # Send final command to get complete transcription
            logger.debug(f"Silence detected for user {user_id}, requesting final transcript")
            await stt_client.send_final()

        except asyncio.CancelledError:
            # Task was cancelled because new audio arrived
            pass
        except Exception as e:
            logger.error(f"Error in silence detection for user {user_id}: {e}")

    async def _on_vad_event(self, user_id: int, event: dict):
        """
        Handle VAD event from STT.

        Args:
            user_id: Discord user ID
            event: VAD event dictionary with 'event' and 'probability' keys
        """
        user = self.users.get(user_id)
        event_type = event.get('event', 'unknown')
        probability = event.get('probability', 0.0)

        logger.debug(f"VAD [{user.name if user else user_id}]: {event_type} (prob={probability:.3f})")

        # Notify voice manager - pass the full event dict
        if hasattr(self.voice_manager, 'on_user_vad_event'):
            await self.voice_manager.on_user_vad_event(user_id, event)

    async def _on_partial_transcript(self, user_id: int, text: str):
        """
        Handle partial transcript from STT.

        Args:
            user_id: Discord user ID
            text: Partial transcript text
        """
        user = self.users.get(user_id)
        logger.info(f"[VOICE_RECEIVER] Partial [{user.name if user else user_id}]: {text}")
        logger.debug(f"[DEBUG] PARTIAL TRANSCRIPT RECEIVED: {text}")

        # Notify voice manager
        if hasattr(self.voice_manager, 'on_partial_transcript'):
            await self.voice_manager.on_partial_transcript(user_id, text)

    async def _on_final_transcript(self, user_id: int, text: str, user: discord.User):
        """
        Handle final transcript from STT.

        This triggers the LLM response generation.

        Args:
            user_id: Discord user ID
            text: Final transcript text
            user: Discord user object
        """
        logger.info(f"[VOICE_RECEIVER] Final [{user.name if user else user_id}]: {text}")
        logger.debug(f"[DEBUG] FINAL TRANSCRIPT RECEIVED: {text}")

        # Notify voice manager - THIS TRIGGERS LLM RESPONSE
        if hasattr(self.voice_manager, 'on_final_transcript'):
            await self.voice_manager.on_final_transcript(user_id, text)

    async def _on_interruption(self, user_id: int, probability: float):
        """
        Handle interruption detection from STT.

        This cancels Miku's current speech if user interrupts.

        Args:
            user_id: Discord user ID
            probability: Interruption confidence probability
        """
        user = self.users.get(user_id)
        logger.info(f"Interruption from [{user.name if user else user_id}] (prob={probability:.3f})")

        # Notify voice manager - THIS CANCELS MIKU'S SPEECH
        if hasattr(self.voice_manager, 'on_user_interruption'):
            await self.voice_manager.on_user_interruption(user_id, probability)

    def get_listening_users(self) -> list:
        """
        Get list of users currently being listened to.

        Returns:
            List of dicts with user_id, username, and connection status
        """
        listeners = []
        for user_id, client in list(self.stt_clients.items()):
            user = self.users.get(user_id)
            listeners.append({
                'user_id': user_id,
                'username': user.name if user else 'Unknown',
                'connected': client.is_connected()
            })
        return listeners

    @voice_recv.AudioSink.listener()
    def on_voice_member_speaking_start(self, member: discord.Member):
        """
        Called when a member starts speaking (green circle appears).

        This is a virtual event from discord-ext-voice-recv based on packet activity.
        """
        if member.id in self.stt_clients:
            logger.debug(f"🎤 {member.name} started speaking")

    @voice_recv.AudioSink.listener()
    def on_voice_member_speaking_stop(self, member: discord.Member):
        """
        Called when a member stops speaking (green circle disappears).

        This is a virtual event from discord-ext-voice-recv based on packet activity.
        """
        if member.id in self.stt_clients:
            logger.debug(f"🔇 {member.name} stopped speaking")
|
||||
Reference in New Issue
Block a user