Decided on Parakeet ONNX Runtime. Works pretty great. Realtime voice chat possible now. UX lacking.
@@ -27,7 +27,7 @@ class STTClient:
     def __init__(
         self,
         user_id: str,
-        stt_url: str = "ws://miku-stt:8000/ws/stt",
+        stt_url: str = "ws://miku-stt:8766/ws/stt",
         on_vad_event: Optional[Callable] = None,
         on_partial_transcript: Optional[Callable] = None,
         on_final_transcript: Optional[Callable] = None,
@@ -140,6 +140,44 @@ class STTClient:
             logger.error(f"Failed to send audio to STT: {e}")
             self.connected = False
 
+    async def send_final(self):
+        """
+        Request final transcription from STT server.
+
+        Call this when the user stops speaking to get the final transcript.
+        """
+        if not self.connected or not self.websocket:
+            logger.warning(f"Cannot send final command, not connected for user {self.user_id}")
+            return
+
+        try:
+            command = json.dumps({"type": "final"})
+            await self.websocket.send_str(command)
+            logger.debug(f"Sent final command to STT")
+
+        except Exception as e:
+            logger.error(f"Failed to send final command to STT: {e}")
+            self.connected = False
+
+    async def send_reset(self):
+        """
+        Reset the STT server's audio buffer.
+
+        Call this to clear any buffered audio.
+        """
+        if not self.connected or not self.websocket:
+            logger.warning(f"Cannot send reset command, not connected for user {self.user_id}")
+            return
+
+        try:
+            command = json.dumps({"type": "reset"})
+            await self.websocket.send_str(command)
+            logger.debug(f"Sent reset command to STT")
+
+        except Exception as e:
+            logger.error(f"Failed to send reset command to STT: {e}")
+            self.connected = False
+
     async def _receive_events(self):
         """Background task to receive events from STT server."""
         try:
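
The two commands above ride the same WebSocket that carries the raw audio: binary frames are audio, text frames are JSON control messages. A minimal client-side sketch of that protocol, assuming aiohttp and the in-container endpoint from this commit; the 16 kHz mono PCM format is an assumption, since the audio path isn't shown in this diff:

```python
import asyncio
import json

import aiohttp

STT_URL = "ws://miku-stt:8766/ws/stt"  # port updated in this commit

async def main():
    async with aiohttp.ClientSession() as session:
        async with session.ws_connect(STT_URL) as ws:
            # Binary frames: raw PCM audio (format assumed, not shown in the diff).
            await ws.send_bytes(b"\x00\x00" * 1600)  # ~100 ms of 16 kHz silence
            # Text frames: JSON commands, exactly as send_final()/send_reset() build them.
            await ws.send_str(json.dumps({"type": "final"}))  # request the final transcript
            await ws.send_str(json.dumps({"type": "reset"}))  # clear the buffer for the next utterance

asyncio.run(main())
```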
@@ -177,14 +215,29 @@ class STTClient:
         """
         event_type = event.get('type')
 
-        if event_type == 'vad':
-            # VAD event: speech detection
+        if event_type == 'transcript':
+            # New ONNX server protocol: single transcript type with is_final flag
+            text = event.get('text', '')
+            is_final = event.get('is_final', False)
+            timestamp = event.get('timestamp', 0)
+
+            if is_final:
+                logger.info(f"Final transcript [{self.user_id}]: {text}")
+                if self.on_final_transcript:
+                    await self.on_final_transcript(text, timestamp)
+            else:
+                logger.info(f"Partial transcript [{self.user_id}]: {text}")
+                if self.on_partial_transcript:
+                    await self.on_partial_transcript(text, timestamp)
+
+        elif event_type == 'vad':
+            # VAD event: speech detection (legacy support)
             logger.debug(f"VAD event: {event}")
             if self.on_vad_event:
                 await self.on_vad_event(event)
 
         elif event_type == 'partial':
-            # Partial transcript
+            # Legacy protocol support: partial transcript
            text = event.get('text', '')
             timestamp = event.get('timestamp', 0)
             logger.info(f"Partial transcript [{self.user_id}]: {text}")
@@ -192,7 +245,7 @@ class STTClient:
                 await self.on_partial_transcript(text, timestamp)
 
         elif event_type == 'final':
-            # Final transcript
+            # Legacy protocol support: final transcript
             text = event.get('text', '')
             timestamp = event.get('timestamp', 0)
             logger.info(f"Final transcript [{self.user_id}]: {text}")
@@ -200,12 +253,20 @@ class STTClient:
                 await self.on_final_transcript(text, timestamp)
 
         elif event_type == 'interruption':
-            # Interruption detected
+            # Interruption detected (legacy support)
             probability = event.get('probability', 0)
             logger.info(f"Interruption detected from user {self.user_id} (prob={probability:.3f})")
             if self.on_interruption:
                 await self.on_interruption(probability)
 
+        elif event_type == 'info':
+            # Info message
+            logger.info(f"STT info: {event.get('message', '')}")
+
+        elif event_type == 'error':
+            # Error message
+            logger.error(f"STT error: {event.get('message', '')}")
+
         else:
             logger.warning(f"Unknown STT event type: {event_type}")
 
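
For reference, these are the event shapes the dispatcher above accepts, reconstructed from its `event.get(...)` calls; the field values are illustrative, not captured server output:

```python
# New ONNX server protocol: one event type, finality carried as a flag.
{"type": "transcript", "text": "hello miku", "is_final": False, "timestamp": 1.2}
{"type": "transcript", "text": "hello miku!", "is_final": True, "timestamp": 2.4}

# Legacy protocol events, still handled for compatibility:
{"type": "vad"}  # payload is passed through verbatim; its fields aren't shown in this diff
{"type": "partial", "text": "hello", "timestamp": 0.8}
{"type": "final", "text": "hello miku!", "timestamp": 2.4}
{"type": "interruption", "probability": 0.93}
{"type": "info", "message": "model loaded"}
{"type": "error", "message": "decode failed"}
```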
@@ -293,6 +293,15 @@ class MikuVoiceSource(discord.AudioSource):
                 logger.debug("Sent flush command to TTS")
             except Exception as e:
                 logger.error(f"Failed to send flush command: {e}")
 
+    async def clear_buffer(self):
+        """
+        Clear the audio buffer without disconnecting.
+        Used when interrupting playback to avoid playing old audio.
+        """
+        async with self.buffer_lock:
+            self.audio_buffer.clear()
+            logger.debug("Audio buffer cleared")
+
 
 
@@ -391,6 +391,12 @@ class VoiceSession:
         self.voice_receiver: Optional['VoiceReceiver'] = None  # STT receiver
         self.active = False
         self.miku_speaking = False  # Track if Miku is currently speaking
+        self.llm_stream_task: Optional[asyncio.Task] = None  # Track LLM streaming task for cancellation
+        self.last_interruption_time: float = 0  # Track when last interruption occurred
+        self.interruption_silence_duration = 0.8  # Seconds of silence after interruption before next response
+
+        # Voice chat conversation history (last 8 exchanges)
+        self.conversation_history = []  # List of {"role": "user"/"assistant", "content": str}
 
         logger.info(f"VoiceSession created for {voice_channel.name} in guild {guild_id}")
 
@@ -496,8 +502,23 @@ class VoiceSession:
         """
         Called when final transcript is received.
         This triggers LLM response and TTS.
+
+        Note: If user interrupted Miku, miku_speaking will already be False
+        by the time this is called, so the response will proceed normally.
         """
-        logger.info(f"Final from user {user_id}: {text}")
+        logger.info(f"📝 Final transcript from user {user_id}: {text}")
+
+        # Check if Miku is STILL speaking (not interrupted)
+        # This prevents queueing if user speaks briefly but not long enough to interrupt
+        if self.miku_speaking:
+            logger.info(f"⏭️ Ignoring short input while Miku is speaking (user didn't interrupt long enough)")
+            # Get user info for notification
+            user = self.voice_channel.guild.get_member(user_id)
+            user_name = user.name if user else f"User {user_id}"
+            await self.text_channel.send(f"💬 *{user_name} said: \"{text}\" (interrupted but too brief - talk longer to interrupt)*")
+            return
+
+        logger.info(f"✓ Processing final transcript (miku_speaking={self.miku_speaking})")
 
         # Get user info
         user = self.voice_channel.guild.get_member(user_id)
@@ -505,26 +526,79 @@ class VoiceSession:
             logger.warning(f"User {user_id} not found in guild")
             return
 
+        # Check for stop commands (don't generate response if user wants silence)
+        stop_phrases = ["stop talking", "be quiet", "shut up", "stop speaking", "silence"]
+        if any(phrase in text.lower() for phrase in stop_phrases):
+            logger.info(f"🤫 Stop command detected: {text}")
+            await self.text_channel.send(f"🎤 {user.name}: *\"{text}\"*")
+            await self.text_channel.send(f"🤫 *Miku goes quiet*")
+            return
+
         # Show what user said
         await self.text_channel.send(f"🎤 {user.name}: *\"{text}\"*")
 
         # Generate LLM response and speak it
         await self._generate_voice_response(user, text)
 
-    async def on_user_interruption(self, user_id: int, probability: float):
+    async def on_user_interruption(self, user_id: int):
         """
         Called when user interrupts Miku's speech.
         Cancel TTS and switch to listening.
+
+        This is triggered when user speaks over Miku for long enough (0.8s+ with 8+ chunks).
+        Immediately cancels LLM streaming, TTS synthesis, and clears audio buffers.
+
+        Args:
+            user_id: Discord user ID who interrupted
         """
         if not self.miku_speaking:
             return
 
-        logger.info(f"User {user_id} interrupted Miku (prob={probability:.3f})")
+        logger.info(f"🛑 User {user_id} interrupted Miku - canceling everything immediately")
 
-        # Cancel Miku's speech
+        # Get user info
+        user = self.voice_channel.guild.get_member(user_id)
+        user_name = user.name if user else f"User {user_id}"
+
+        # 1. Mark that Miku is no longer speaking (stops LLM streaming loop check)
+        self.miku_speaking = False
+
+        # 2. Cancel LLM streaming task if it's running
+        if self.llm_stream_task and not self.llm_stream_task.done():
+            self.llm_stream_task.cancel()
+            try:
+                await self.llm_stream_task
+            except asyncio.CancelledError:
+                logger.info("✓ LLM streaming task cancelled")
+            except Exception as e:
+                logger.error(f"Error cancelling LLM task: {e}")
+
+        # 3. Cancel TTS/RVC synthesis and playback
         await self._cancel_tts()
 
+        # 4. Add a brief pause to create audible separation
+        # This gives a fade-out effect and makes the interruption less jarring
+        import time
+        self.last_interruption_time = time.time()
+        logger.info(f"⏸️ Pausing for {self.interruption_silence_duration}s after interruption")
+        await asyncio.sleep(self.interruption_silence_duration)
+
+        # 5. Add interruption marker to conversation history
+        self.conversation_history.append({
+            "role": "assistant",
+            "content": "[INTERRUPTED - user started speaking]"
+        })
+
+        # Show interruption in chat
+        await self.text_channel.send(f"⚠️ *{user_name} interrupted Miku*")
+
+        logger.info(f"✓ Interruption handled, ready for next input")
+
+    async def on_user_interruption_old(self, user_id: int, probability: float):
+        """
+        Legacy interruption handler (kept for compatibility).
+        Called when VAD-based interruption detection is used.
+        """
+        await self.on_user_interruption(user_id)
         user = self.voice_channel.guild.get_member(user_id)
         await self.text_channel.send(f"⚠️ *{user.name if user else 'User'} interrupted Miku*")
 
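
The numbered steps hinge on the cancel-then-await idiom: flip `miku_speaking` first so the streaming loop can exit cooperatively between chunks, then cancel the task and await it so the CancelledError is raised, observed, and absorbed in one place. A self-contained sketch of that pattern:

```python
import asyncio

async def fake_llm_stream():
    """Stand-in for the LLM-to-TTS streaming coroutine."""
    while True:
        await asyncio.sleep(0.1)  # pretend to read one SSE chunk

async def main():
    task = asyncio.create_task(fake_llm_stream())
    await asyncio.sleep(0.35)

    # Cancel, then await: awaiting gives the task a chance to unwind
    # and lets us absorb the CancelledError here instead of leaking it.
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        print("stream cancelled cleanly")

asyncio.run(main())
```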
@@ -537,7 +611,18 @@ class VoiceSession:
             text: Transcribed text
         """
         try:
+            # Check if we need to wait due to recent interruption
+            import time
+            if self.last_interruption_time > 0:
+                time_since_interruption = time.time() - self.last_interruption_time
+                remaining_pause = self.interruption_silence_duration - time_since_interruption
+                if remaining_pause > 0:
+                    logger.info(f"⏸️ Waiting {remaining_pause:.2f}s more before responding (interruption cooldown)")
+                    await asyncio.sleep(remaining_pause)
+
+            logger.info(f"🎙️ Starting voice response generation (setting miku_speaking=True)")
             self.miku_speaking = True
+            logger.info(f"  → miku_speaking is now: {self.miku_speaking}")
 
             # Show processing
             await self.text_channel.send(f"💭 *Miku is thinking...*")
@@ -547,17 +632,53 @@
             import aiohttp
             import globals
 
-            # Simple system prompt for voice
-            system_prompt = """You are Hatsune Miku, the virtual singer.
-Respond naturally and concisely as Miku would in a voice conversation.
-Keep responses short (1-3 sentences) since they will be spoken aloud."""
+            # Load personality and lore
+            miku_lore = ""
+            miku_prompt = ""
+            try:
+                with open('/app/miku_lore.txt', 'r', encoding='utf-8') as f:
+                    miku_lore = f.read().strip()
+                with open('/app/miku_prompt.txt', 'r', encoding='utf-8') as f:
+                    miku_prompt = f.read().strip()
+            except Exception as e:
+                logger.warning(f"Could not load personality files: {e}")
+
+            # Build voice chat system prompt
+            system_prompt = f"""{miku_prompt}
+
+{miku_lore}
+
+VOICE CHAT CONTEXT:
+- You are currently in a voice channel speaking with {user.name} and others
+- Your responses will be spoken aloud via text-to-speech
+- Keep responses natural and conversational - vary your length based on context:
+  * Quick reactions: 1 sentence ("Oh wow!" or "That's amazing!")
+  * Normal chat: 2-3 sentences (share a thought or feeling)
+  * Stories/explanations: 4-6 sentences when asked for details
+- Match the user's energy and conversation style
+- IMPORTANT: Only respond in ENGLISH! The TTS system cannot handle Japanese or other languages well.
+- Be expressive and use casual language, but stay in character as Miku
+- If user says "stop talking" or "be quiet", acknowledge briefly and stop
+
+Remember: This is a live voice conversation - be natural, not formulaic!"""
+
+            # Add user message to history
+            self.conversation_history.append({
+                "role": "user",
+                "content": f"{user.name}: {text}"
+            })
+
+            # Keep only last 8 exchanges (16 messages = 8 user + 8 assistant)
+            if len(self.conversation_history) > 16:
+                self.conversation_history = self.conversation_history[-16:]
+
+            # Build messages for LLM
+            messages = [{"role": "system", "content": system_prompt}]
+            messages.extend(self.conversation_history)
 
             payload = {
                 "model": globals.TEXT_MODEL,
-                "messages": [
-                    {"role": "system", "content": system_prompt},
-                    {"role": "user", "content": text}
-                ],
+                "messages": messages,
                 "stream": True,
                 "temperature": 0.8,
                 "max_tokens": 200
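
The history window is a plain slice, so the prompt never carries more than 8 user/assistant pairs on top of the system message. A quick sketch of the invariant (contents are illustrative):

```python
history = []
for i in range(20):  # simulate 20 exchanges, more than the window holds
    history.append({"role": "user", "content": f"message {i}"})
    history.append({"role": "assistant", "content": f"reply {i}"})
    if len(history) > 16:
        history = history[-16:]  # same trim as the code above

assert len(history) == 16                     # 8 user + 8 assistant messages
assert history[0]["content"] == "message 12"  # oldest exchanges fell off
messages = [{"role": "system", "content": "..."}] + history
```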
@@ -566,50 +687,74 @@ Keep responses short (1-3 sentences) since they will be spoken aloud."""
             headers = {'Content-Type': 'application/json'}
             llama_url = get_current_gpu_url()
 
-            # Stream LLM response to TTS
-            full_response = ""
-            async with aiohttp.ClientSession() as http_session:
-                async with http_session.post(
-                    f"{llama_url}/v1/chat/completions",
-                    json=payload,
-                    headers=headers,
-                    timeout=aiohttp.ClientTimeout(total=60)
-                ) as response:
-                    if response.status != 200:
-                        error_text = await response.text()
-                        raise Exception(f"LLM error {response.status}: {error_text}")
-
-                    # Stream tokens to TTS
-                    async for line in response.content:
-                        if not self.miku_speaking:
-                            # Interrupted
-                            break
-
-                        line = line.decode('utf-8').strip()
-                        if line.startswith('data: '):
-                            data_str = line[6:]
-                            if data_str == '[DONE]':
-                                break
-
-                            try:
-                                import json
-                                data = json.loads(data_str)
-                                if 'choices' in data and len(data['choices']) > 0:
-                                    delta = data['choices'][0].get('delta', {})
-                                    content = delta.get('content', '')
-                                    if content:
-                                        await self.audio_source.send_token(content)
-                                        full_response += content
-                            except json.JSONDecodeError:
-                                continue
+            # Create streaming task so we can cancel it if interrupted
+            async def stream_llm_to_tts():
+                """Stream LLM tokens to TTS. Can be cancelled."""
+                full_response = ""
+                async with aiohttp.ClientSession() as http_session:
+                    async with http_session.post(
+                        f"{llama_url}/v1/chat/completions",
+                        json=payload,
+                        headers=headers,
+                        timeout=aiohttp.ClientTimeout(total=60)
+                    ) as response:
+                        if response.status != 200:
+                            error_text = await response.text()
+                            raise Exception(f"LLM error {response.status}: {error_text}")
+
+                        # Stream tokens to TTS
+                        async for line in response.content:
+                            if not self.miku_speaking:
+                                # Interrupted - exit gracefully
+                                logger.info("🛑 LLM streaming stopped (miku_speaking=False)")
+                                break
+
+                            line = line.decode('utf-8').strip()
+                            if line.startswith('data: '):
+                                data_str = line[6:]
+                                if data_str == '[DONE]':
+                                    break
+
+                                try:
+                                    import json
+                                    data = json.loads(data_str)
+                                    if 'choices' in data and len(data['choices']) > 0:
+                                        delta = data['choices'][0].get('delta', {})
+                                        content = delta.get('content', '')
+                                        if content:
+                                            await self.audio_source.send_token(content)
+                                            full_response += content
+                                except json.JSONDecodeError:
+                                    continue
+                return full_response
+
+            # Run streaming as a task that can be cancelled
+            self.llm_stream_task = asyncio.create_task(stream_llm_to_tts())
+
+            try:
+                full_response = await self.llm_stream_task
+            except asyncio.CancelledError:
+                logger.info("✓ LLM streaming cancelled by interruption")
+                # Don't re-raise - just return early to avoid breaking STT client
+                return
 
             # Flush TTS
             if self.miku_speaking:
                 await self.audio_source.flush()
 
+                # Add Miku's complete response to history
+                self.conversation_history.append({
+                    "role": "assistant",
+                    "content": full_response.strip()
+                })
+
                 # Show response
                 await self.text_channel.send(f"🎤 Miku: *\"{full_response.strip()}\"*")
                 logger.info(f"✓ Voice response complete: {full_response.strip()}")
+            else:
+                # Interrupted - don't add incomplete response to history
+                # (interruption marker already added by on_user_interruption)
+                logger.info(f"✓ Response interrupted after {len(full_response)} chars")
 
         except Exception as e:
             logger.error(f"Voice response failed: {e}", exc_info=True)
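
The inner loop parses OpenAI-style SSE lines from the llama.cpp-compatible `/v1/chat/completions` endpoint. The same parsing logic against canned chunks, as a standalone check (payloads abbreviated; real chunks carry more fields):

```python
import json

raw_lines = [
    b'data: {"choices":[{"delta":{"content":"Hi"}}]}\n',
    b'data: {"choices":[{"delta":{"content":" there!"}}]}\n',
    b'data: [DONE]\n',
]

full_response = ""
for line in raw_lines:
    line = line.decode('utf-8').strip()
    if not line.startswith('data: '):
        continue
    data_str = line[6:]
    if data_str == '[DONE]':
        break
    try:
        data = json.loads(data_str)
        if 'choices' in data and len(data['choices']) > 0:
            delta = data['choices'][0].get('delta', {})
            full_response += delta.get('content', '')
    except json.JSONDecodeError:
        continue

assert full_response == "Hi there!"
```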
@@ -619,24 +764,50 @@ Keep responses short (1-3 sentences) since they will be spoken aloud."""
             self.miku_speaking = False
 
     async def _cancel_tts(self):
-        """Cancel current TTS synthesis."""
-        logger.info("Canceling TTS synthesis")
-
-        # Stop Discord playback
-        if self.voice_client and self.voice_client.is_playing():
-            self.voice_client.stop()
-
-        # Send interrupt to RVC
+        """
+        Immediately cancel TTS synthesis and clear all audio buffers.
+
+        This sends interrupt signals to:
+        1. Local audio buffer (clears queued audio)
+        2. RVC TTS server (stops synthesis pipeline)
+
+        Does NOT stop voice_client (that would disconnect voice receiver).
+        """
+        logger.info("🛑 Canceling TTS synthesis immediately")
+
+        # 1. FIRST: Clear local audio buffer to stop playing queued audio
+        if self.audio_source:
+            try:
+                await self.audio_source.clear_buffer()
+                logger.info("✓ Audio buffer cleared")
+            except Exception as e:
+                logger.error(f"Failed to clear audio buffer: {e}")
+
+        # 2. SECOND: Send interrupt to RVC to stop synthesis pipeline
         try:
             import aiohttp
             async with aiohttp.ClientSession() as session:
-                async with session.post("http://172.25.0.1:8765/interrupt") as resp:
-                    if resp.status == 200:
-                        logger.info("✓ TTS interrupted")
+                # Send interrupt multiple times rapidly to ensure it's received
+                for i in range(3):
+                    try:
+                        async with session.post(
+                            "http://172.25.0.1:8765/interrupt",
+                            timeout=aiohttp.ClientTimeout(total=2.0)
+                        ) as resp:
+                            if resp.status == 200:
+                                data = await resp.json()
+                                logger.info(f"✓ TTS interrupted (flushed {data.get('zmq_chunks_flushed', 0)} chunks)")
+                                break
+                    except asyncio.TimeoutError:
+                        if i < 2:  # Don't warn on last attempt
+                            logger.warning("Interrupt request timed out, retrying...")
+                        continue
         except Exception as e:
             logger.error(f"Failed to interrupt TTS: {e}")
 
-        self.miku_speaking = False
+        # Note: We do NOT call voice_client.stop() because that would
+        # stop the entire voice system including the receiver!
+        # The audio source will just play silence until new tokens arrive.
 
 
 # Global singleton instance
 
@@ -27,13 +27,13 @@ class VoiceReceiverSink(voice_recv.AudioSink):
     decodes/resamples as needed, and sends to STT clients for transcription.
     """
 
-    def __init__(self, voice_manager, stt_url: str = "ws://miku-stt:8000/ws/stt"):
+    def __init__(self, voice_manager, stt_url: str = "ws://miku-stt:8766/ws/stt"):
         """
-        Initialize voice receiver sink.
+        Initialize Voice Receiver.
 
         Args:
-            voice_manager: Reference to VoiceManager for callbacks
-            stt_url: Base URL for STT WebSocket server with path (port 8000 inside container)
+            voice_manager: The voice manager instance
+            stt_url: Base URL for STT WebSocket server with path (port 8766 inside container)
         """
         super().__init__()
         self.voice_manager = voice_manager
@@ -56,6 +56,17 @@ class VoiceReceiverSink(voice_recv.AudioSink):
         # User info (for logging)
         self.users: Dict[int, discord.User] = {}
 
+        # Silence tracking for detecting end of speech
+        self.last_audio_time: Dict[int, float] = {}
+        self.silence_tasks: Dict[int, asyncio.Task] = {}
+        self.silence_timeout = 1.0  # seconds of silence before sending "final"
+
+        # Interruption detection
+        self.interruption_start_time: Dict[int, float] = {}
+        self.interruption_audio_count: Dict[int, int] = {}
+        self.interruption_threshold_time = 0.8  # seconds of speech to count as interruption
+        self.interruption_threshold_chunks = 8  # minimum audio chunks to count as interruption
+
         # Active flag
         self.active = False
 
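
These fields implement a per-user debounce: every audio chunk cancels the pending silence timer and starts a fresh one, so `send_final()` only fires after a full `silence_timeout` of quiet. The pattern in isolation, as a runnable sketch (the class and callback names are hypothetical):

```python
import asyncio
from typing import Callable, Optional

class SilenceDebouncer:
    """Fire a callback once no audio has arrived for `timeout` seconds."""

    def __init__(self, timeout: float, on_silence: Callable):
        self.timeout = timeout
        self.on_silence = on_silence
        self._task: Optional[asyncio.Task] = None

    def audio_received(self) -> None:
        # Each chunk cancels the pending timer and starts a fresh one.
        if self._task and not self._task.done():
            self._task.cancel()
        self._task = asyncio.create_task(self._wait())

    async def _wait(self) -> None:
        try:
            await asyncio.sleep(self.timeout)
            await self.on_silence()
        except asyncio.CancelledError:
            pass  # superseded by a newer chunk

async def main():
    async def on_silence():
        print("silence -> request final transcript")

    debouncer = SilenceDebouncer(timeout=1.0, on_silence=on_silence)
    for _ in range(5):            # chunks every 0.2 s keep resetting the timer
        debouncer.audio_received()
        await asyncio.sleep(0.2)
    await asyncio.sleep(1.5)      # real silence: the timer fires exactly once

asyncio.run(main())
```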
@@ -232,6 +243,17 @@ class VoiceReceiverSink(voice_recv.AudioSink):
         if user_id in self.users:
             del self.users[user_id]
 
+        # Cancel silence detection task
+        if user_id in self.silence_tasks and not self.silence_tasks[user_id].done():
+            self.silence_tasks[user_id].cancel()
+            del self.silence_tasks[user_id]
+        if user_id in self.last_audio_time:
+            del self.last_audio_time[user_id]
+
+        # Clear interruption tracking
+        self.interruption_start_time.pop(user_id, None)
+        self.interruption_audio_count.pop(user_id, None)
+
         # Cleanup opus decoder for this user
         if hasattr(self, '_opus_decoders') and user_id in self._opus_decoders:
             del self._opus_decoders[user_id]
@@ -299,10 +321,95 @@ class VoiceReceiverSink(voice_recv.AudioSink):
                 else:
                     # Put remaining partial chunk back in buffer
                     buffer.append(chunk)
 
+            # Track audio time for silence detection
+            import time
+            current_time = time.time()
+            self.last_audio_time[user_id] = current_time
+
+            # ===== INTERRUPTION DETECTION =====
+            # Check if Miku is speaking and user is interrupting
+            # Note: self.voice_manager IS the VoiceSession, not the VoiceManager singleton
+            miku_speaking = self.voice_manager.miku_speaking
+            logger.debug(f"[INTERRUPTION CHECK] user={user_id}, miku_speaking={miku_speaking}")
+
+            if miku_speaking:
+                # Track interruption
+                if user_id not in self.interruption_start_time:
+                    # First chunk during Miku's speech
+                    self.interruption_start_time[user_id] = current_time
+                    self.interruption_audio_count[user_id] = 1
+                else:
+                    # Increment chunk count
+                    self.interruption_audio_count[user_id] += 1
+
+                # Calculate interruption duration
+                interruption_duration = current_time - self.interruption_start_time[user_id]
+                chunk_count = self.interruption_audio_count[user_id]
+
+                # Check if interruption threshold is met
+                if (interruption_duration >= self.interruption_threshold_time and
+                        chunk_count >= self.interruption_threshold_chunks):
+
+                    # Trigger interruption!
+                    logger.info(f"🛑 User {user_id} interrupted Miku (duration={interruption_duration:.2f}s, chunks={chunk_count})")
+                    logger.info(f"  → Stopping Miku's TTS and LLM, will process user's speech when finished")
+
+                    # Reset interruption tracking
+                    self.interruption_start_time.pop(user_id, None)
+                    self.interruption_audio_count.pop(user_id, None)
+
+                    # Call interruption handler (this sets miku_speaking=False)
+                    asyncio.create_task(
+                        self.voice_manager.on_user_interruption(user_id)
+                    )
+            else:
+                # Miku not speaking, clear interruption tracking
+                self.interruption_start_time.pop(user_id, None)
+                self.interruption_audio_count.pop(user_id, None)
+
+            # Cancel existing silence task if any
+            if user_id in self.silence_tasks and not self.silence_tasks[user_id].done():
+                self.silence_tasks[user_id].cancel()
+
+            # Start new silence detection task
+            self.silence_tasks[user_id] = asyncio.create_task(
+                self._detect_silence(user_id)
+            )
+
         except Exception as e:
             logger.error(f"Failed to send audio chunk for user {user_id}: {e}")
 
+    async def _detect_silence(self, user_id: int):
+        """
+        Wait for silence timeout and send 'final' command to STT.
+
+        This is called after each audio chunk. If no more audio arrives within
+        the silence_timeout period, we send the 'final' command to get the
+        complete transcription.
+
+        Args:
+            user_id: Discord user ID
+        """
+        try:
+            # Wait for silence timeout
+            await asyncio.sleep(self.silence_timeout)
+
+            # Check if we still have an active STT client
+            stt_client = self.stt_clients.get(user_id)
+            if not stt_client or not stt_client.is_connected():
+                return
+
+            # Send final command to get complete transcription
+            logger.debug(f"Silence detected for user {user_id}, requesting final transcript")
+            await stt_client.send_final()
+
+        except asyncio.CancelledError:
+            # Task was cancelled because new audio arrived
+            pass
+        except Exception as e:
+            logger.error(f"Error in silence detection for user {user_id}: {e}")
+
     async def _on_vad_event(self, user_id: int, event: dict):
         """
         Handle VAD event from STT.
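
The interruption gate above reduces to two per-user counters checked against fixed thresholds. Factored out as a pure function for unit testing (same 0.8 s / 8-chunk thresholds; the function name is hypothetical):

```python
def should_interrupt(
    duration_s: float,
    chunk_count: int,
    threshold_time: float = 0.8,
    threshold_chunks: int = 8,
) -> bool:
    """True once the user has spoken over Miku long enough to count as an interruption."""
    return duration_s >= threshold_time and chunk_count >= threshold_chunks

# A brief blip over Miku's speech is ignored; sustained speech triggers the cancel path.
assert not should_interrupt(0.4, 4)
assert not should_interrupt(1.0, 3)   # long enough, but too few chunks
assert should_interrupt(0.9, 12)
```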