Phase 2 implemented and tested. Added warmup to pipeline and Miku queues tokens while the pipeline is warming up
This commit is contained in:
@@ -125,7 +125,7 @@ async def on_message(message):
|
||||
if message.author == globals.client.user:
|
||||
return
|
||||
|
||||
# Check for voice commands first (!miku join, !miku leave, !miku voice-status)
|
||||
# Check for voice commands first (!miku join, !miku leave, !miku voice-status, !miku test)
|
||||
if not isinstance(message.channel, discord.DMChannel) and message.content.strip().lower().startswith('!miku '):
|
||||
from commands.voice import handle_voice_command
|
||||
|
||||
@@ -134,7 +134,7 @@ async def on_message(message):
|
||||
cmd = parts[1].lower()
|
||||
args = parts[2:] if len(parts) > 2 else []
|
||||
|
||||
if cmd in ['join', 'leave', 'voice-status']:
|
||||
if cmd in ['join', 'leave', 'voice-status', 'test']:
|
||||
await handle_voice_command(message, cmd, args)
|
||||
return
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ async def handle_voice_command(message, cmd, args):
|
||||
|
||||
Args:
|
||||
message: Discord message object
|
||||
cmd: Command name (join, leave, voice-status)
|
||||
cmd: Command name (join, leave, voice-status, test)
|
||||
args: Command arguments
|
||||
"""
|
||||
|
||||
@@ -30,6 +30,9 @@ async def handle_voice_command(message, cmd, args):
|
||||
elif cmd == 'voice-status':
|
||||
await _handle_status(message)
|
||||
|
||||
elif cmd == 'test':
|
||||
await _handle_test(message, args)
|
||||
|
||||
else:
|
||||
await message.channel.send(f"❌ Unknown voice command: `{cmd}`")
|
||||
|
||||
@@ -227,3 +230,36 @@ async def _handle_status(message):
|
||||
embed.set_footer(text="Use !miku leave to end the session")
|
||||
|
||||
await message.channel.send(embed=embed)
|
||||
|
||||
|
||||
async def _handle_test(message, args):
|
||||
"""
|
||||
Handle !miku test command.
|
||||
Test TTS audio playback in the current voice session.
|
||||
"""
|
||||
session = voice_manager.active_session
|
||||
|
||||
if not session:
|
||||
await message.channel.send("❌ No active voice session! Use `!miku join` first.")
|
||||
return
|
||||
|
||||
if not session.audio_source:
|
||||
await message.channel.send("❌ Audio source not connected!")
|
||||
return
|
||||
|
||||
# Get test text from args or use default
|
||||
test_text = " ".join(args) if args else "Hello! This is a test of my voice chat system."
|
||||
|
||||
try:
|
||||
await message.channel.send(f"🎤 Speaking: *\"{test_text}\"*")
|
||||
logger.info(f"Testing voice playback: {test_text}")
|
||||
|
||||
# Stream text to TTS via the audio source
|
||||
await session.audio_source.stream_text(test_text)
|
||||
|
||||
await message.add_reaction("✅")
|
||||
logger.info("✓ Test audio sent to TTS")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to test voice playback: {e}", exc_info=True)
|
||||
await message.channel.send(f"❌ Error testing voice: {e}")
|
||||
|
||||
@@ -61,6 +61,7 @@ COMPONENTS = {
|
||||
'apscheduler': 'Job scheduler logs (APScheduler)',
|
||||
'voice_manager': 'Voice channel session management',
|
||||
'voice_commands': 'Voice channel commands',
|
||||
'voice_audio': 'Voice audio streaming and TTS',
|
||||
}
|
||||
|
||||
# Global configuration
|
||||
|
||||
373
bot/utils/voice_audio.py
Normal file
373
bot/utils/voice_audio.py
Normal file
@@ -0,0 +1,373 @@
|
||||
# voice_audio.py
|
||||
"""
|
||||
Audio streaming bridge between RVC TTS and Discord voice.
|
||||
Uses aiohttp for WebSocket communication (compatible with FastAPI).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import numpy as np
|
||||
from typing import Optional
|
||||
import discord
|
||||
import aiohttp
|
||||
from utils.logger import get_logger
|
||||
|
||||
import asyncio
|
||||
import struct
|
||||
import json
|
||||
import websockets
|
||||
import discord
|
||||
import numpy as np
|
||||
from typing import Optional
|
||||
from utils.logger import get_logger
|
||||
|
||||
logger = get_logger('voice_audio')
|
||||
|
||||
# Audio format constants
|
||||
SAMPLE_RATE = 48000 # 48kHz
|
||||
CHANNELS = 2 # Stereo for Discord
|
||||
FRAME_LENGTH = 0.02 # 20ms frames
|
||||
SAMPLES_PER_FRAME = int(SAMPLE_RATE * FRAME_LENGTH) # 960 samples
|
||||
|
||||
|
||||
class MikuVoiceSource(discord.AudioSource):
|
||||
"""
|
||||
Audio source that receives audio from RVC TTS WebSocket and feeds it to Discord voice.
|
||||
Single WebSocket connection handles both token sending and audio receiving.
|
||||
Uses aiohttp for WebSocket communication (compatible with FastAPI).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.websocket_url = "ws://172.25.0.1:8765/ws/stream"
|
||||
self.health_url = "http://172.25.0.1:8765/health"
|
||||
self.session = None
|
||||
self.websocket = None
|
||||
self.audio_buffer = bytearray()
|
||||
self.buffer_lock = asyncio.Lock()
|
||||
self.running = False
|
||||
self._receive_task = None
|
||||
self.warmed_up = False # Track if TTS pipeline is warmed up
|
||||
self.token_queue = [] # Queue tokens while warming up or connecting
|
||||
|
||||
async def _check_rvc_ready(self) -> bool:
|
||||
"""Check if RVC is initialized and warmed up via health endpoint"""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(self.health_url, timeout=aiohttp.ClientTimeout(total=2)) as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
return data.get("warmed_up", False)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.debug(f"Health check failed: {e}")
|
||||
return False
|
||||
|
||||
async def connect(self):
|
||||
"""Connect to RVC TTS WebSocket using aiohttp"""
|
||||
try:
|
||||
# First, check if RVC is warmed up
|
||||
logger.info("Checking if RVC is ready...")
|
||||
is_ready = await self._check_rvc_ready()
|
||||
|
||||
if not is_ready:
|
||||
logger.warning("⏳ RVC is warming up, will queue tokens until ready")
|
||||
self.warmed_up = False
|
||||
# Don't connect yet - we'll connect later when ready
|
||||
# Start a background task to poll and connect when ready
|
||||
self._receive_task = asyncio.create_task(self._wait_for_warmup_and_connect())
|
||||
return
|
||||
|
||||
# RVC is ready, connect immediately
|
||||
logger.info("RVC is ready, connecting...")
|
||||
await self._do_connect()
|
||||
self.warmed_up = True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize connection: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def _do_connect(self):
|
||||
"""Actually establish the WebSocket connection"""
|
||||
self.session = aiohttp.ClientSession()
|
||||
self.websocket = await self.session.ws_connect(self.websocket_url)
|
||||
self.running = True
|
||||
logger.info("✓ Connected to RVC TTS WebSocket")
|
||||
|
||||
# Always start background task to receive audio after connecting
|
||||
# (Don't check if _receive_task exists - it might be the warmup polling task)
|
||||
self._receive_task = asyncio.create_task(self._receive_audio())
|
||||
|
||||
async def _wait_for_warmup_and_connect(self):
|
||||
"""Poll RVC health until warmed up, then connect and flush queue"""
|
||||
try:
|
||||
logger.info("Polling RVC for warmup completion...")
|
||||
max_wait = 60 # 60 seconds max
|
||||
poll_interval = 1.0 # Check every second
|
||||
|
||||
for _ in range(int(max_wait / poll_interval)):
|
||||
if await self._check_rvc_ready():
|
||||
logger.info("✅ RVC warmup complete! Connecting and flushing queue...")
|
||||
await self._do_connect()
|
||||
self.warmed_up = True
|
||||
|
||||
# Flush queued tokens
|
||||
if self.token_queue and self.websocket:
|
||||
logger.info(f"Sending {len(self.token_queue)} queued tokens")
|
||||
for token, pitch_shift in self.token_queue:
|
||||
await self.websocket.send_json({
|
||||
"token": token,
|
||||
"pitch_shift": pitch_shift
|
||||
})
|
||||
# Small delay to avoid overwhelming RVC
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
# Send flush command to ensure any buffered text is synthesized
|
||||
await self.websocket.send_json({"flush": True})
|
||||
logger.info("✓ Queue flushed with explicit flush command")
|
||||
self.token_queue.clear()
|
||||
return
|
||||
|
||||
await asyncio.sleep(poll_interval)
|
||||
|
||||
# Timeout
|
||||
logger.error("❌ RVC warmup timeout! Connecting anyway...")
|
||||
await self._do_connect()
|
||||
self.warmed_up = True
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.debug("Warmup polling cancelled")
|
||||
except Exception as e:
|
||||
logger.error(f"Error during warmup wait: {e}", exc_info=True)
|
||||
|
||||
async def _reconnect(self):
|
||||
"""Attempt to reconnect after connection failure"""
|
||||
try:
|
||||
logger.info("Reconnection attempt starting...")
|
||||
max_retries = 5
|
||||
retry_delay = 3.0
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
# Clean up old connection
|
||||
if self.websocket:
|
||||
try:
|
||||
await self.websocket.close()
|
||||
except:
|
||||
pass
|
||||
self.websocket = None
|
||||
|
||||
if self.session:
|
||||
try:
|
||||
await self.session.close()
|
||||
except:
|
||||
pass
|
||||
self.session = None
|
||||
|
||||
# Wait before retry
|
||||
if attempt > 0:
|
||||
logger.info(f"Retry {attempt + 1}/{max_retries} in {retry_delay}s...")
|
||||
await asyncio.sleep(retry_delay)
|
||||
|
||||
# Check if RVC is ready
|
||||
if not await self._check_rvc_ready():
|
||||
logger.warning("RVC not ready, will retry...")
|
||||
continue
|
||||
|
||||
# Try to connect
|
||||
await self._do_connect()
|
||||
self.warmed_up = True
|
||||
|
||||
# Flush queued tokens
|
||||
if self.token_queue and self.websocket:
|
||||
logger.info(f"✓ Reconnected! Flushing {len(self.token_queue)} queued tokens")
|
||||
for token, pitch_shift in self.token_queue:
|
||||
await self.websocket.send_json({
|
||||
"token": token,
|
||||
"pitch_shift": pitch_shift
|
||||
})
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
# Send flush command to ensure any buffered text is synthesized
|
||||
await self.websocket.send_json({"flush": True})
|
||||
self.token_queue.clear()
|
||||
logger.info("✓ Queue flushed with explicit flush command")
|
||||
logger.info("✓ Reconnection successful")
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Reconnection attempt {attempt + 1} failed: {e}")
|
||||
|
||||
logger.error(f"❌ Failed to reconnect after {max_retries} attempts")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.debug("Reconnection cancelled")
|
||||
except Exception as e:
|
||||
logger.error(f"Error during reconnection: {e}", exc_info=True)
|
||||
|
||||
async def disconnect(self):
|
||||
"""Disconnect from WebSocket"""
|
||||
self.running = False
|
||||
|
||||
if self._receive_task:
|
||||
self._receive_task.cancel()
|
||||
try:
|
||||
await self._receive_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._receive_task = None
|
||||
|
||||
if self.websocket:
|
||||
await self.websocket.close()
|
||||
self.websocket = None
|
||||
|
||||
if self.session:
|
||||
await self.session.close()
|
||||
self.session = None
|
||||
|
||||
logger.info("Disconnected from RVC TTS WebSocket")
|
||||
|
||||
async def send_token(self, token: str, pitch_shift: int = 0):
|
||||
"""
|
||||
Send a text token to TTS for voice generation.
|
||||
Queues tokens if pipeline is still warming up or connection failed.
|
||||
|
||||
Args:
|
||||
token: Text token to synthesize
|
||||
pitch_shift: Pitch adjustment (-12 to +12 semitones)
|
||||
"""
|
||||
# If not warmed up yet or no connection, queue the token
|
||||
if not self.warmed_up or not self.websocket:
|
||||
self.token_queue.append((token, pitch_shift))
|
||||
if not self.warmed_up:
|
||||
logger.debug(f"Queued token (warming up): '{token}' (queue size: {len(self.token_queue)})")
|
||||
else:
|
||||
logger.debug(f"Queued token (no connection): '{token}' (queue size: {len(self.token_queue)})")
|
||||
# Try to reconnect in background if not already trying
|
||||
if not self._receive_task or self._receive_task.done():
|
||||
logger.info("Attempting to reconnect to RVC...")
|
||||
self._receive_task = asyncio.create_task(self._reconnect())
|
||||
return
|
||||
|
||||
try:
|
||||
message = {
|
||||
"token": token,
|
||||
"pitch_shift": pitch_shift
|
||||
}
|
||||
await self.websocket.send_json(message)
|
||||
logger.debug(f"Sent token to TTS: '{token}'")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send token: {e}")
|
||||
# Queue the failed token and mark as not warmed up to trigger reconnection
|
||||
self.token_queue.append((token, pitch_shift))
|
||||
self.warmed_up = False
|
||||
if self.websocket:
|
||||
try:
|
||||
await self.websocket.close()
|
||||
except:
|
||||
pass
|
||||
self.websocket = None
|
||||
|
||||
async def stream_text(self, text: str, pitch_shift: int = 0):
|
||||
"""
|
||||
Stream entire text to TTS word-by-word.
|
||||
|
||||
Args:
|
||||
text: Full text to synthesize
|
||||
pitch_shift: Pitch adjustment
|
||||
"""
|
||||
words = text.split()
|
||||
for word in words:
|
||||
await self.send_token(word + " ", pitch_shift)
|
||||
# Small delay to avoid overwhelming the TTS
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
|
||||
async def _receive_audio(self):
|
||||
"""Background task to receive audio from WebSocket and buffer it."""
|
||||
try:
|
||||
while self.running and self.websocket:
|
||||
try:
|
||||
# Receive message from WebSocket
|
||||
msg = await self.websocket.receive()
|
||||
|
||||
if msg.type == aiohttp.WSMsgType.BINARY:
|
||||
# Convert float32 mono → int16 stereo
|
||||
converted = self._convert_audio(msg.data)
|
||||
self.audio_buffer.extend(converted)
|
||||
|
||||
logger.debug(f"Received {len(msg.data)} bytes, buffer: {len(self.audio_buffer)} bytes")
|
||||
|
||||
elif msg.type == aiohttp.WSMsgType.CLOSED:
|
||||
logger.warning("TTS WebSocket connection closed")
|
||||
break
|
||||
elif msg.type == aiohttp.WSMsgType.ERROR:
|
||||
logger.error(f"WebSocket error: {self.websocket.exception()}")
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error receiving audio: {e}", exc_info=True)
|
||||
break
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.debug("Audio receive task cancelled")
|
||||
|
||||
def _convert_audio(self, float32_mono: bytes) -> bytes:
|
||||
"""
|
||||
Convert float32 mono PCM to int16 stereo PCM.
|
||||
|
||||
Args:
|
||||
float32_mono: Raw PCM audio (float32 values, mono channel)
|
||||
|
||||
Returns:
|
||||
int16 stereo PCM bytes
|
||||
"""
|
||||
# Parse float32 values
|
||||
num_samples = len(float32_mono) // 4 # 4 bytes per float32
|
||||
float_array = struct.unpack(f'{num_samples}f', float32_mono)
|
||||
|
||||
# Convert to numpy for easier processing
|
||||
audio_np = np.array(float_array, dtype=np.float32)
|
||||
|
||||
# Clamp to [-1.0, 1.0] range
|
||||
audio_np = np.clip(audio_np, -1.0, 1.0)
|
||||
|
||||
# Convert to int16 range [-32768, 32767]
|
||||
audio_int16 = (audio_np * 32767).astype(np.int16)
|
||||
|
||||
# Duplicate mono channel to stereo (L and R same)
|
||||
stereo = np.repeat(audio_int16, 2)
|
||||
|
||||
# Convert to bytes
|
||||
return stereo.tobytes()
|
||||
|
||||
def read(self) -> bytes:
|
||||
"""
|
||||
Read 20ms of audio (required by discord.py AudioSource interface).
|
||||
|
||||
Discord expects exactly 960 samples per channel (1920 samples total for stereo),
|
||||
which equals 3840 bytes (1920 samples * 2 bytes per int16).
|
||||
|
||||
Returns:
|
||||
3840 bytes of int16 stereo PCM, or empty bytes if no audio available
|
||||
"""
|
||||
# Calculate required bytes for 20ms frame
|
||||
bytes_needed = SAMPLES_PER_FRAME * CHANNELS * 2 # 960 * 2 * 2 = 3840 bytes
|
||||
|
||||
if len(self.audio_buffer) >= bytes_needed:
|
||||
# Extract frame from buffer
|
||||
frame = bytes(self.audio_buffer[:bytes_needed])
|
||||
del self.audio_buffer[:bytes_needed]
|
||||
return frame
|
||||
else:
|
||||
# Not enough audio yet, return silence
|
||||
return b'\x00' * bytes_needed
|
||||
|
||||
def is_opus(self) -> bool:
|
||||
"""Return False since we're providing raw PCM."""
|
||||
return False
|
||||
|
||||
def cleanup(self):
|
||||
"""Cleanup resources when AudioSource is done."""
|
||||
logger.info("MikuVoiceSource cleanup called")
|
||||
# Actual disconnect happens via disconnect() method
|
||||
@@ -107,6 +107,14 @@ class VoiceSessionManager:
|
||||
logger.error(f"Failed to connect to voice channel: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
# 12. Start audio streaming (Phase 2)
|
||||
try:
|
||||
await self.active_session.start_audio_streaming()
|
||||
logger.info(f"✓ Audio streaming started")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start audio streaming: {e}", exc_info=True)
|
||||
# Continue anyway - audio streaming is optional for Phase 2 testing
|
||||
|
||||
logger.info(f"✓ Voice session started successfully")
|
||||
|
||||
except Exception as e:
|
||||
@@ -127,7 +135,14 @@ class VoiceSessionManager:
|
||||
logger.info("Ending voice session...")
|
||||
|
||||
try:
|
||||
# 1. Disconnect from voice channel
|
||||
# 1. Stop audio streaming
|
||||
if self.active_session:
|
||||
try:
|
||||
await self.active_session.stop_audio_streaming()
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping audio streaming: {e}")
|
||||
|
||||
# 2. Disconnect from voice channel
|
||||
if self.active_session.voice_client:
|
||||
try:
|
||||
await self.active_session.voice_client.disconnect()
|
||||
@@ -135,28 +150,28 @@ class VoiceSessionManager:
|
||||
except Exception as e:
|
||||
logger.error(f"Error disconnecting from voice: {e}")
|
||||
|
||||
# 2. Resume text channel inference
|
||||
# 3. Resume text channel inference
|
||||
await self._resume_text_channels()
|
||||
|
||||
# 3. Unblock vision model
|
||||
# 4. Unblock vision model
|
||||
await self._unblock_vision_model()
|
||||
|
||||
# 4. Re-enable image generation
|
||||
# 5. Re-enable image generation
|
||||
await self._enable_image_generation()
|
||||
|
||||
# 5. Re-enable bipolar mode interactions
|
||||
# 6. Re-enable bipolar mode interactions
|
||||
await self._enable_bipolar_mode()
|
||||
|
||||
# 6. Re-enable profile picture switching
|
||||
# 7. Re-enable profile picture switching
|
||||
await self._enable_profile_picture_switching()
|
||||
|
||||
# 7. Resume autonomous engine
|
||||
# 8. Resume autonomous engine
|
||||
await self._resume_autonomous_engine()
|
||||
|
||||
# 8. Resume scheduled events
|
||||
# 9. Resume scheduled events
|
||||
await self._resume_scheduled_events()
|
||||
|
||||
# 9. Resume figurine notifier
|
||||
# 10. Resume figurine notifier
|
||||
await self._resume_figurine_notifier()
|
||||
|
||||
# 10. Clear active session
|
||||
@@ -362,8 +377,7 @@ class VoiceSessionManager:
|
||||
|
||||
class VoiceSession:
|
||||
"""
|
||||
Represents an active voice chat session.
|
||||
Phase 1: Basic structure only, voice connection in Phase 2.
|
||||
Represents an active voice chat session with audio streaming.
|
||||
"""
|
||||
|
||||
def __init__(self, guild_id: int, voice_channel: discord.VoiceChannel, text_channel: discord.TextChannel):
|
||||
@@ -371,11 +385,54 @@ class VoiceSession:
|
||||
self.voice_channel = voice_channel
|
||||
self.text_channel = text_channel
|
||||
self.voice_client: Optional[discord.VoiceClient] = None
|
||||
self.audio_source: Optional['MikuVoiceSource'] = None # Forward reference
|
||||
self.tts_streamer: Optional['TTSTokenStreamer'] = None # Forward reference
|
||||
self.active = False
|
||||
|
||||
logger.info(f"VoiceSession created for {voice_channel.name} in guild {guild_id}")
|
||||
|
||||
# Phase 2: Implement voice connection, audio streaming, TTS integration
|
||||
async def start_audio_streaming(self):
|
||||
"""
|
||||
Start audio streaming from TTS WebSocket to Discord voice.
|
||||
This should be called after voice_client is connected.
|
||||
"""
|
||||
from utils.voice_audio import MikuVoiceSource
|
||||
|
||||
try:
|
||||
# Create and connect audio source (handles both sending tokens and receiving audio)
|
||||
self.audio_source = MikuVoiceSource()
|
||||
await self.audio_source.connect()
|
||||
|
||||
# The audio_source now serves as both the audio source AND the token sender
|
||||
# Set tts_streamer to point to audio_source for backwards compatibility
|
||||
self.tts_streamer = self.audio_source
|
||||
|
||||
# Start playing audio to Discord
|
||||
if self.voice_client and not self.voice_client.is_playing():
|
||||
self.voice_client.play(self.audio_source)
|
||||
logger.info("✓ Started audio streaming to Discord")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start audio streaming: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def stop_audio_streaming(self):
|
||||
"""Stop audio streaming and cleanup resources."""
|
||||
try:
|
||||
# Stop Discord audio playback
|
||||
if self.voice_client and self.voice_client.is_playing():
|
||||
self.voice_client.stop()
|
||||
|
||||
# Disconnect audio source (which also handles token streaming)
|
||||
if self.audio_source:
|
||||
await self.audio_source.disconnect()
|
||||
self.audio_source = None
|
||||
self.tts_streamer = None # Clear reference since it pointed to audio_source
|
||||
|
||||
logger.info("✓ Stopped audio streaming")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping audio streaming: {e}", exc_info=True)
|
||||
|
||||
|
||||
# Global singleton instance
|
||||
|
||||
Reference in New Issue
Block a user