"""
Discord Voice Receiver using discord-ext-voice-recv
Captures audio from Discord voice channels and streams to STT.
Uses the discord-ext-voice-recv extension for proper audio receiving support.
"""
import asyncio
import audioop
import logging
import time
import array
from typing import Dict, Optional
from collections import deque

import discord
import discord.opus
from discord.ext import voice_recv

from utils.stt_client import STTClient

logger = logging.getLogger('voice_receiver')


class VoiceReceiverSink(voice_recv.AudioSink):
"""
Audio sink that receives Discord audio and forwards to STT.
This sink processes incoming audio from Discord voice channels,
decodes/resamples as needed, and sends to STT clients for transcription.
"""

    def __init__(self, voice_manager, stt_url: str = "ws://miku-stt:8766"):
        """
        Initialize the voice receiver.

        Args:
            voice_manager: The voice session this sink belongs to
            stt_url: WebSocket URL for the RealtimeSTT server (port 8766 inside the container)
        """
        super().__init__()
        self.voice_manager = voice_manager
        self.stt_url = stt_url

        # Store the event loop for thread-safe async calls from write(),
        # which runs on the voice receive thread.
        try:
            self.loop = asyncio.get_running_loop()
        except RuntimeError:
            # Fallback if not constructed in an async context yet
            self.loop = asyncio.get_event_loop()

        # Per-user STT clients
        self.stt_clients: Dict[int, STTClient] = {}
        # Per-user audio chunk buffers
        self.audio_buffers: Dict[int, deque] = {}
        # Per-user Opus decoders (created lazily in write())
        self._opus_decoders: Dict[int, discord.opus.Decoder] = {}
        # Per-user resampler state, so audioop.ratecv stays continuous across chunks
        self._ratecv_state: Dict[int, Optional[tuple]] = {}
        # User info (for logging)
        self.users: Dict[int, discord.User] = {}

        # Silence tracking for detecting end of speech
        self.last_audio_time: Dict[int, float] = {}
        self.silence_tasks: Dict[int, asyncio.Task] = {}
        self.silence_timeout = 1.0  # seconds of silence before sending "final"

        # Interruption detection
        self.interruption_start_time: Dict[int, float] = {}
        self.interruption_audio_count: Dict[int, int] = {}
        self.interruption_threshold_time = 0.8  # seconds of speech to count as an interruption
        self.interruption_threshold_chunks = 8  # minimum audio chunks to count as an interruption

        # Active flag
        self.active = False

        logger.info("VoiceReceiverSink initialized")

    @staticmethod
    def _preprocess_audio(pcm_data: bytes) -> bytes:
        """
        Preprocess audio for better STT accuracy.

        Applies:
            1. DC offset removal
            2. High-pass filter (~80 Hz) to remove rumble
            3. RMS normalization

        Args:
            pcm_data: Raw PCM audio (16-bit mono, 16 kHz)

        Returns:
            Preprocessed PCM audio
        """
        try:
            # Convert bytes to an array of int16 samples
            samples = array.array('h', pcm_data)

            # 1. Remove the DC offset (mean), clamping to the int16 range
            #    so the shifted samples can't overflow the array
            mean = sum(samples) / len(samples) if samples else 0
            samples = array.array('h', [
                max(-32768, min(32767, int(s - mean))) for s in samples
            ])

            # 2. Simple first-order high-pass filter (~80 Hz @ 16 kHz):
            #    y[n] = x[n] - x[n-1] + alpha * y[n-1]
            alpha = 0.95  # Filter coefficient (roughly 80 Hz cutoff at 16 kHz)
            filtered = array.array('h')
            prev_input = 0
            prev_output = 0
            for sample in samples:
                output = sample - prev_input + alpha * prev_output
                filtered.append(int(max(-32768, min(32767, output))))  # Clamp to int16 range
                prev_input = sample
                prev_output = output

            # 3. RMS normalization to a target level
            sum_squares = sum(s * s for s in filtered)
            rms = (sum_squares / len(filtered)) ** 0.5 if filtered else 1.0

            # Target RMS (roughly -20 dBFS): 10% of the max int16 range
            target_rms = 3276.8

            # Normalize only if there's an actual signal, not just noise floor
            if rms > 100:
                gain = target_rms / rms
                # Limit gain to prevent over-amplification of noise (4x ~= +12 dB)
                gain = min(gain, 4.0)
                normalized = array.array('h', [
                    int(max(-32768, min(32767, s * gain))) for s in filtered
                ])
                return normalized.tobytes()
            else:
                # Signal too weak; return the filtered audio without normalization
                return filtered.tobytes()
        except Exception as e:
            logger.debug(f"Audio preprocessing failed, using raw audio: {e}")
            return pcm_data
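
    # Offline sanity check (a sketch, not part of the runtime path): run a
    # DC-biased 1 kHz tone through the preprocessor and confirm the mean lands
    # near zero while the RMS settles near the ~3277 target.
    #
    #   import math
    #   tone = array.array('h', [int(2000 * math.sin(2 * math.pi * 1000 * n / 16000)) + 500
    #                            for n in range(16000)])
    #   out = array.array('h', VoiceReceiverSink._preprocess_audio(tone.tobytes()))
    #   print(sum(out) / len(out), (sum(s * s for s in out) / len(out)) ** 0.5)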

    def wants_opus(self) -> bool:
        """
        Tell discord-ext-voice-recv that we want Opus data, NOT decoded PCM.

        We decode it ourselves to avoid decoder errors from
        discord-ext-voice-recv's packet router.

        Returns:
            True - we want Opus packets and will handle decoding
        """
        return True

    def write(self, user: Optional[discord.User], data: voice_recv.VoiceData):
        """
        Called by discord-ext-voice-recv when audio is received.

        This is the main callback that receives audio packets from Discord.
        We get Opus data, decode it ourselves, downmix, resample, and forward to STT.

        Args:
            user: Discord user who sent the audio (None if unknown)
            data: Voice data container with pcm, opus, and packet info
        """
        if not user:
            return  # Skip packets from unknown users

        user_id = user.id

        # Check if we're listening to this user
        if user_id not in self.stt_clients:
            return

        try:
            # Get Opus data (we decode ourselves to avoid PacketRouter errors)
            opus_data = data.opus
            if not opus_data:
                return

            # Create a decoder for this user if needed
            if user_id not in self._opus_decoders:
                self._opus_decoders[user_id] = discord.opus.Decoder()
            decoder = self._opus_decoders[user_id]

            # Decode Opus -> PCM (48 kHz stereo int16); this can fail on corrupt packets
            try:
                pcm_data = decoder.decode(opus_data, fec=False)
            except discord.opus.OpusError as e:
                # Skip corrupted packets silently (common at stream start)
                logger.debug(f"Skipping corrupted opus packet for user {user_id}: {e}")
                return

            if not pcm_data:
                return

            # PCM from Discord is 48 kHz stereo int16; convert stereo to mono
            if len(pcm_data) % 4 == 0:  # Stereo (2 channels * 2 bytes per sample)
                pcm_mono = audioop.tomono(pcm_data, 2, 0.5, 0.5)
            else:
                pcm_mono = pcm_data

            # Resample from 48 kHz to 16 kHz for STT, carrying per-user resampler
            # state so the conversion stays continuous across chunk boundaries.
            # Discord sends 20 ms chunks: 960 samples @ 48 kHz -> 320 samples @ 16 kHz
            pcm_16k, self._ratecv_state[user_id] = audioop.ratecv(
                pcm_mono, 2, 1, 48000, 16000, self._ratecv_state.get(user_id)
            )

            # Preprocess audio for better STT accuracy
            # (DC offset removal, high-pass filter, RMS normalization)
            pcm_16k = self._preprocess_audio(pcm_16k)

            # Send to the STT client (scheduled thread-safely on the event loop,
            # since write() runs on the voice receive thread)
            asyncio.run_coroutine_threadsafe(
                self._send_audio_chunk(user_id, pcm_16k),
                self.loop
            )
        except Exception as e:
            logger.error(f"Error processing audio for user {user_id}: {e}", exc_info=True)

    def cleanup(self):
        """
        Called when the sink is stopped.

        Async cleanup is handled separately in stop_all().
        """
        logger.info("VoiceReceiverSink cleanup")

    async def start_listening(self, user_id: int, user: discord.User):
        """
        Start listening to a specific user.

        Creates an STT client connection for this user and registers callbacks.

        Args:
            user_id: Discord user ID
            user: Discord user object
        """
        if user_id in self.stt_clients:
            logger.warning(f"Already listening to user {user.name} ({user_id})")
            return

        logger.info(f"Starting to listen to user {user.name} ({user_id})")

        # Store user info
        self.users[user_id] = user

        # Initialize the audio buffer
        self.audio_buffers[user_id] = deque(maxlen=1000)

        # Create the STT client with callbacks.
        # RealtimeSTT handles VAD internally, so we only need partial/final callbacks.
        stt_client = STTClient(
            user_id=user_id,
            stt_url=self.stt_url,
            on_partial_transcript=lambda text, timestamp: asyncio.create_task(
                self._on_partial_transcript(user_id, text)
            ),
            on_final_transcript=lambda text, timestamp: asyncio.create_task(
                self._on_final_transcript(user_id, text, user)
            ),
        )

        # Connect to the STT server
        try:
            await stt_client.connect()
            self.stt_clients[user_id] = stt_client
            self.active = True
            logger.info(f"✓ STT connected for user {user.name}")
        except Exception as e:
            logger.error(f"Failed to connect STT for user {user.name}: {e}", exc_info=True)
            # Clean up partial state
            self.audio_buffers.pop(user_id, None)
            self.users.pop(user_id, None)
            raise

    async def stop_listening(self, user_id: int):
        """
        Stop listening to a specific user.

        Disconnects the STT client and cleans up resources for this user.

        Args:
            user_id: Discord user ID
        """
        if user_id not in self.stt_clients:
            logger.warning(f"Not listening to user {user_id}")
            return

        user = self.users.get(user_id)
        logger.info(f"Stopping listening to user {user.name if user else user_id}")

        # Disconnect the STT client
        stt_client = self.stt_clients[user_id]
        await stt_client.disconnect()

        # Cleanup
        del self.stt_clients[user_id]
        self.audio_buffers.pop(user_id, None)
        self.users.pop(user_id, None)

        # Cancel the silence detection task (and drop it even if already done)
        task = self.silence_tasks.pop(user_id, None)
        if task and not task.done():
            task.cancel()
        self.last_audio_time.pop(user_id, None)

        # Clear interruption tracking
        self.interruption_start_time.pop(user_id, None)
        self.interruption_audio_count.pop(user_id, None)

        # Drop this user's Opus decoder and resampler state
        self._opus_decoders.pop(user_id, None)
        self._ratecv_state.pop(user_id, None)

        # Update the active flag
        if not self.stt_clients:
            self.active = False

        logger.info(f"✓ Stopped listening to user {user.name if user else user_id}")

    async def stop_all(self):
        """Stop listening to all users and clean up all resources."""
        logger.info("Stopping all voice receivers")
        user_ids = list(self.stt_clients.keys())
        for user_id in user_ids:
            await self.stop_listening(user_id)
        self.active = False
        logger.info("✓ All voice receivers stopped")

    async def _send_audio_chunk(self, user_id: int, audio_data: bytes):
        """
        Send an audio chunk to the STT client.

        RealtimeSTT expects 16 kHz mono 16-bit PCM audio. We buffer audio and
        send larger chunks for efficiency; VAD and silence detection are
        handled by RealtimeSTT.

        Args:
            user_id: Discord user ID
            audio_data: PCM audio (int16, 16 kHz mono)
        """
        stt_client = self.stt_clients.get(user_id)
        if not stt_client or not stt_client.connected:
            return

        try:
            # Get or create the buffer for this user
            if user_id not in self.audio_buffers:
                self.audio_buffers[user_id] = deque()
            buffer = self.audio_buffers[user_id]
            buffer.append(audio_data)

            # Buffer and send in larger chunks for efficiency
            BYTES_NEEDED = 1024  # 512 samples * 2 bytes

            # Check whether we have enough buffered audio
            total_bytes = sum(len(chunk) for chunk in buffer)
            if total_bytes >= BYTES_NEEDED:
                # Concatenate the buffered chunks
                combined = b''.join(buffer)
                buffer.clear()

                # Send all audio to STT (RealtimeSTT handles VAD internally)
                await stt_client.send_audio(combined)

                # Track audio time for interruption detection
                current_time = time.time()
                self.last_audio_time[user_id] = current_time

                # ===== INTERRUPTION DETECTION =====
                # Check whether Miku is speaking and the user is interrupting.
                # Note: self.voice_manager IS the VoiceSession, not the VoiceManager singleton.
                miku_speaking = self.voice_manager.miku_speaking
                if miku_speaking:
                    # Calculate RMS to detect whether the user is actually speaking
                    # (not just silence/background noise)
                    rms = audioop.rms(combined, 2)
                    RMS_THRESHOLD = 500  # Higher = less sensitive

                    if rms > RMS_THRESHOLD:
                        # User is actually speaking - track as a potential interruption
                        if user_id not in self.interruption_start_time:
                            # First chunk with actual audio during Miku's speech
                            self.interruption_start_time[user_id] = current_time
                            self.interruption_audio_count[user_id] = 1
                            logger.debug(f"Potential interruption start (rms={rms})")
                        else:
                            # Increment the chunk count
                            self.interruption_audio_count[user_id] += 1

                        # Calculate the interruption duration
                        interruption_duration = current_time - self.interruption_start_time[user_id]
                        chunk_count = self.interruption_audio_count[user_id]

                        # Check whether the interruption thresholds are met
                        if (interruption_duration >= self.interruption_threshold_time and
                                chunk_count >= self.interruption_threshold_chunks):
                            # Trigger the interruption!
                            logger.info(f"🛑 User {user_id} interrupted Miku (duration={interruption_duration:.2f}s, chunks={chunk_count}, rms={rms})")
                            logger.info("  → Stopping Miku's TTS and LLM; the user's speech will be processed when finished")

                            # Reset interruption tracking
                            self.interruption_start_time.pop(user_id, None)
                            self.interruption_audio_count.pop(user_id, None)

                            # Call the interruption handler (this sets miku_speaking=False)
                            asyncio.create_task(
                                self.voice_manager.on_user_interruption(user_id)
                            )
                    else:
                        # Audio below the RMS threshold (silence) - reset interruption
                        # tracking so brief pauses in speech reset the counter
                        self.interruption_start_time.pop(user_id, None)
                        self.interruption_audio_count.pop(user_id, None)
                else:
                    # Miku isn't speaking; clear interruption tracking
                    self.interruption_start_time.pop(user_id, None)
                    self.interruption_audio_count.pop(user_id, None)
        except Exception as e:
            logger.error(f"Failed to send audio chunk for user {user_id}: {e}")

    async def _on_partial_transcript(self, user_id: int, text: str):
        """
        Handle a partial transcript from STT.

        Args:
            user_id: Discord user ID
            text: Partial transcript text
        """
        user = self.users.get(user_id)
        logger.info(f"[VOICE_RECEIVER] Partial [{user.name if user else user_id}]: {text}")

        # Notify the voice manager
        if hasattr(self.voice_manager, 'on_partial_transcript'):
            await self.voice_manager.on_partial_transcript(user_id, text)

    async def _on_final_transcript(self, user_id: int, text: str, user: discord.User):
        """
        Handle a final transcript from STT.

        This triggers the LLM response generation.

        Args:
            user_id: Discord user ID
            text: Final transcript text
            user: Discord user object
        """
        logger.info(f"[VOICE_RECEIVER] Final [{user.name if user else user_id}]: {text}")

        # Notify the voice manager - THIS TRIGGERS THE LLM RESPONSE
        if hasattr(self.voice_manager, 'on_final_transcript'):
            await self.voice_manager.on_final_transcript(user_id, text)

    def get_listening_users(self) -> list:
        """
        Get the list of users currently being listened to.

        Returns:
            List of dicts with user_id, username, and connection status
        """
        listening = []
        for user_id, client in self.stt_clients.items():
            user = self.users.get(user_id)
            listening.append({
                'user_id': user_id,
                'username': user.name if user else 'Unknown',
                'connected': client.connected,
            })
        return listening
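
    # Example return value (shape only; the values are illustrative):
    #   [{'user_id': 123456789012345678, 'username': 'alice', 'connected': True}]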

    # Discord VAD events removed - we rely entirely on RealtimeSTT's VAD for speech detection