# Files
# miku-discord/bot/utils/voice_audio.py
#
# 396 lines
# 15 KiB
# Python

# voice_audio.py
"""
Audio streaming bridge between RVC TTS and Discord voice.
Uses aiohttp for WebSocket communication (compatible with FastAPI).
"""
import asyncio
import json
import numpy as np
from typing import Optional
import discord
import aiohttp
from utils.logger import get_logger
import asyncio
import struct
import json
import websockets
import discord
import numpy as np
from typing import Optional
from utils.logger import get_logger
# Module-level logger, created through the project's get_logger helper under the 'voice_audio' name.
logger = get_logger('voice_audio')
# ---------------------------------------------------------------------------
# PCM format handed to Discord voice
# ---------------------------------------------------------------------------
# Sampling rate in Hz (48 kHz).
SAMPLE_RATE = 48000
# Number of interleaved channels (stereo for Discord).
CHANNELS = 2
# Duration of one audio frame in seconds (20 ms).
FRAME_LENGTH = 0.02
# Samples per channel in a single frame: 48000 * 0.02 = 960.
SAMPLES_PER_FRAME = int(SAMPLE_RATE * FRAME_LENGTH)
class MikuVoiceSource(discord.AudioSource):
    """
    Audio source that receives audio from RVC TTS WebSocket and feeds it to Discord voice.

    Single WebSocket connection handles both token sending and audio receiving.
    Uses aiohttp for WebSocket communication (compatible with FastAPI).

    Tokens sent while the pipeline is warming up or the connection is down are
    kept in ``token_queue`` and replayed in order once a connection exists.
    """

    def __init__(
        self,
        websocket_url: str = "ws://172.25.0.1:8765/ws/stream",
        health_url: str = "http://172.25.0.1:8765/health",
    ):
        """
        Args:
            websocket_url: RVC TTS streaming WebSocket endpoint.
            health_url: RVC HTTP health endpoint; its JSON body is expected to
                carry a ``warmed_up`` flag.
        """
        self.websocket_url = websocket_url
        self.health_url = health_url
        self.session = None
        self.websocket = None
        self.audio_buffer = bytearray()
        self.buffer_lock = asyncio.Lock()
        self.running = False
        # Task running _receive_audio() for the current websocket.
        self._receive_task: Optional[asyncio.Task] = None
        # Task running warmup polling or reconnection. Kept apart from
        # _receive_task so disconnect() can always cancel both.
        self._connect_task: Optional[asyncio.Task] = None
        self.warmed_up = False  # Track if TTS pipeline is warmed up
        self.token_queue = []  # Queue tokens while warming up or connecting

    # ------------------------------------------------------------------
    # Connection management
    # ------------------------------------------------------------------

    async def _check_rvc_ready(self) -> bool:
        """Check if RVC is initialized and warmed up via health endpoint.

        Returns:
            True only when the endpoint answers 200 with a truthy
            ``warmed_up`` field; any error counts as "not ready".
        """
        try:
            async with aiohttp.ClientSession() as session:
                timeout = aiohttp.ClientTimeout(total=2)
                async with session.get(self.health_url, timeout=timeout) as response:
                    if response.status == 200:
                        data = await response.json()
                        return bool(data.get("warmed_up", False))
                    return False
        except Exception as e:
            logger.debug(f"Health check failed: {e}")
            return False

    async def connect(self):
        """Connect to RVC TTS WebSocket using aiohttp.

        If RVC is not warmed up yet, returns immediately and keeps polling in
        a background task; tokens sent meanwhile are queued.

        Raises:
            Exception: whatever the connection attempt raised (after logging).
        """
        try:
            # First, check if RVC is warmed up
            logger.info("Checking if RVC is ready...")
            if not await self._check_rvc_ready():
                logger.warning("⏳ RVC is warming up, will queue tokens until ready")
                self.warmed_up = False
                # Don't connect yet - poll in the background and connect when ready
                self._connect_task = asyncio.create_task(self._wait_for_warmup_and_connect())
                return
            # RVC is ready, connect immediately
            logger.info("RVC is ready, connecting...")
            await self._do_connect()
            # Replay anything queued before connect() was called, then open
            # the direct-send path.
            await self._flush_token_queue()
            self.warmed_up = True
        except Exception as e:
            logger.error(f"Failed to initialize connection: {e}", exc_info=True)
            await self._close_connection()
            raise

    async def _do_connect(self):
        """Actually establish the WebSocket connection and start receiving audio."""
        # Drop any previous connection so its session is not leaked.
        await self._close_connection()
        session = aiohttp.ClientSession()
        try:
            websocket = await session.ws_connect(self.websocket_url)
        except BaseException:
            # Handshake failed or was cancelled: don't leak the session.
            await session.close()
            raise
        self.session = session
        self.websocket = websocket
        self.running = True
        logger.info("✓ Connected to RVC TTS WebSocket")
        # Only one receiver may read a websocket; retire the old one first.
        previous = self._receive_task
        if previous is not None and not previous.done():
            previous.cancel()
        self._receive_task = asyncio.create_task(self._receive_audio())

    async def _close_connection(self) -> None:
        """Close and forget the websocket and session (best effort, never raises Exception)."""
        websocket, self.websocket = self.websocket, None
        session, self.session = self.session, None
        try:
            if websocket is not None:
                try:
                    await websocket.close()
                except Exception as e:
                    logger.debug(f"Error closing websocket: {e}")
        finally:
            # Runs even if closing the websocket was cancelled.
            if session is not None:
                try:
                    await session.close()
                except Exception as e:
                    logger.debug(f"Error closing session: {e}")

    async def _flush_token_queue(self) -> None:
        """Send all queued tokens in order, followed by an explicit flush command.

        Callers keep ``warmed_up`` False while this runs so that concurrent
        send_token() calls append to the queue instead of overtaking it.
        """
        websocket = self.websocket
        if websocket is None or not self.token_queue:
            return
        logger.info(f"Sending {len(self.token_queue)} queued tokens")
        # Outer loop: tokens may be queued while the flush command is in flight.
        while self.token_queue:
            while self.token_queue:
                token, pitch_shift = self.token_queue[0]
                await websocket.send_json({
                    "token": token,
                    "pitch_shift": pitch_shift
                })
                # Remove only after a successful send so a failure loses nothing.
                del self.token_queue[:1]
                # Small delay to avoid overwhelming RVC
                await asyncio.sleep(0.05)
            # Send flush command to ensure any buffered text is synthesized
            await websocket.send_json({"flush": True})
        logger.info("✓ Queue flushed with explicit flush command")

    async def _wait_for_warmup_and_connect(self):
        """Poll RVC health until warmed up (or timeout), then connect and flush queue."""
        try:
            logger.info("Polling RVC for warmup completion...")
            max_wait = 60  # 60 polls max
            poll_interval = 1.0  # Pause between polls, in seconds
            for _ in range(int(max_wait / poll_interval)):
                if await self._check_rvc_ready():
                    logger.info("✅ RVC warmup complete! Connecting and flushing queue...")
                    break
                await asyncio.sleep(poll_interval)
            else:
                # Timeout
                logger.error("❌ RVC warmup timeout! Connecting anyway...")
            await self._do_connect()
            # Flush on both paths; the timeout path must not strand the queue.
            await self._flush_token_queue()
            self.warmed_up = True
        except asyncio.CancelledError:
            logger.debug("Warmup polling cancelled")
            raise
        except Exception as e:
            logger.error(f"Error during warmup wait: {e}", exc_info=True)
            # Leave no half-open connection behind; send_token() can then
            # trigger a reconnect.
            await self._close_connection()

    async def _reconnect(self):
        """Attempt to reconnect after connection failure (up to 5 tries)."""
        try:
            logger.info("Reconnection attempt starting...")
            max_retries = 5
            retry_delay = 3.0
            for attempt in range(max_retries):
                try:
                    # Clean up old connection
                    await self._close_connection()
                    # Wait before retry
                    if attempt > 0:
                        logger.info(f"Retry {attempt + 1}/{max_retries} in {retry_delay}s...")
                        await asyncio.sleep(retry_delay)
                    # Check if RVC is ready
                    if not await self._check_rvc_ready():
                        logger.warning("RVC not ready, will retry...")
                        continue
                    # Try to connect
                    await self._do_connect()
                    if self.token_queue:
                        logger.info(f"✓ Reconnected! Flushing {len(self.token_queue)} queued tokens")
                    await self._flush_token_queue()
                    self.warmed_up = True
                    logger.info("✓ Reconnection successful")
                    return
                except Exception as e:
                    logger.error(f"Reconnection attempt {attempt + 1} failed: {e}")
            logger.error(f"❌ Failed to reconnect after {max_retries} attempts")
            await self._close_connection()
        except asyncio.CancelledError:
            logger.debug("Reconnection cancelled")
            raise
        except Exception as e:
            logger.error(f"Error during reconnection: {e}", exc_info=True)

    def _schedule_reconnect(self) -> None:
        """Start a background reconnect unless warmup polling or a reconnect is already running."""
        if self._connect_task is not None and not self._connect_task.done():
            return
        logger.info("Attempting to reconnect to RVC...")
        self._connect_task = asyncio.create_task(self._reconnect())

    async def _cancel_task(self, attr: str) -> None:
        """Cancel the task stored in attribute ``attr`` and wait for it to finish."""
        task = getattr(self, attr)
        setattr(self, attr, None)
        if task is None or task.done():
            return
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass

    async def disconnect(self):
        """Disconnect from WebSocket and stop all background tasks."""
        self.running = False
        await self._cancel_task("_connect_task")
        await self._cancel_task("_receive_task")
        await self._close_connection()
        logger.info("Disconnected from RVC TTS WebSocket")

    # ------------------------------------------------------------------
    # Sending text
    # ------------------------------------------------------------------

    async def send_token(self, token: str, pitch_shift: int = 0):
        """
        Send a text token to TTS for voice generation.

        Queues tokens if pipeline is still warming up or connection failed.

        Args:
            token: Text token to synthesize
            pitch_shift: Pitch adjustment (-12 to +12 semitones)
        """
        # If not warmed up yet or no connection, queue the token
        if not self.warmed_up or not self.websocket:
            self.token_queue.append((token, pitch_shift))
            if not self.warmed_up:
                logger.debug(f"Queued token (warming up): '{token}' (queue size: {len(self.token_queue)})")
            else:
                logger.debug(f"Queued token (no connection): '{token}' (queue size: {len(self.token_queue)})")
            # Reconnect when a connection was lost (running) or was up before
            # (warmed_up). Before connect() / during warmup both are False.
            if not self.websocket and (self.warmed_up or self.running):
                self._schedule_reconnect()
            return
        try:
            message = {
                "token": token,
                "pitch_shift": pitch_shift
            }
            await self.websocket.send_json(message)
            logger.debug(f"Sent token to TTS: '{token}'")
        except Exception as e:
            logger.error(f"Failed to send token: {e}")
            # Queue the failed token, route later tokens to the queue, and
            # start reconnecting right away.
            self.token_queue.append((token, pitch_shift))
            self.warmed_up = False
            await self._close_connection()
            self._schedule_reconnect()

    async def stream_text(self, text: str, pitch_shift: int = 0):
        """
        Stream entire text to TTS word-by-word.

        Args:
            text: Full text to synthesize
            pitch_shift: Pitch adjustment
        """
        for word in text.split():
            await self.send_token(word + " ", pitch_shift)
            # Small delay to avoid overwhelming the TTS
            await asyncio.sleep(0.05)

    async def flush(self):
        """
        Send flush command to TTS to trigger synthesis of buffered tokens.

        This ensures any remaining text in the TTS buffer is synthesized.
        Does nothing while there is no connection.
        """
        if self.websocket:
            try:
                await self.websocket.send_json({"flush": True})
                logger.debug("Sent flush command to TTS")
            except Exception as e:
                logger.error(f"Failed to send flush command: {e}")

    # ------------------------------------------------------------------
    # Receiving and serving audio
    # ------------------------------------------------------------------

    async def clear_buffer(self):
        """
        Clear the audio buffer without disconnecting.

        Used when interrupting playback to avoid playing old audio.
        """
        async with self.buffer_lock:
            self.audio_buffer.clear()
        logger.debug("Audio buffer cleared")

    async def _receive_audio(self):
        """Background task to receive audio from WebSocket and buffer it."""
        # Bind to the websocket current at start so a replacement connection
        # is never read by two receivers.
        websocket = self.websocket
        closed_types = (
            aiohttp.WSMsgType.CLOSE,
            aiohttp.WSMsgType.CLOSING,
            aiohttp.WSMsgType.CLOSED,
        )
        try:
            while self.running and websocket is not None and self.websocket is websocket:
                try:
                    # Receive message from WebSocket
                    msg = await websocket.receive()
                    if msg.type == aiohttp.WSMsgType.BINARY:
                        # Convert float32 mono → int16 stereo
                        converted = self._convert_audio(msg.data)
                        self.audio_buffer.extend(converted)
                        logger.debug(f"Received {len(msg.data)} bytes, buffer: {len(self.audio_buffer)} bytes")
                    elif msg.type in closed_types:
                        logger.warning("TTS WebSocket connection closed")
                        break
                    elif msg.type == aiohttp.WSMsgType.ERROR:
                        logger.error(f"WebSocket error: {websocket.exception()}")
                        break
                except Exception as e:
                    logger.error(f"Error receiving audio: {e}", exc_info=True)
                    break
        except asyncio.CancelledError:
            logger.debug("Audio receive task cancelled")
            raise

    def _convert_audio(self, float32_mono: bytes) -> bytes:
        """
        Convert float32 mono PCM to int16 stereo PCM.

        Args:
            float32_mono: Raw PCM audio (native-endian float32 values, mono channel).
                Trailing bytes that do not form a whole sample are ignored.

        Returns:
            int16 stereo PCM bytes (each mono sample duplicated to L and R)
        """
        num_samples = len(float32_mono) // 4  # 4 bytes per float32
        if num_samples == 0:
            return b""
        # count= lets the buffer carry a partial trailing sample without raising.
        audio_np = np.frombuffer(float32_mono, dtype=np.float32, count=num_samples)
        # NaN has no defined int16 value; treat it as silence.
        audio_np = np.nan_to_num(audio_np, nan=0.0, posinf=1.0, neginf=-1.0)
        # Clamp to [-1.0, 1.0] range
        audio_np = np.clip(audio_np, -1.0, 1.0)
        # Scale to int16 range
        audio_int16 = (audio_np * 32767).astype(np.int16)
        # Duplicate mono channel to stereo (L and R same)
        return np.repeat(audio_int16, 2).tobytes()

    def read(self) -> bytes:
        """
        Read 20ms of audio (required by discord.py AudioSource interface).

        Discord expects exactly 960 samples per channel (1920 samples total for stereo),
        which equals 3840 bytes (1920 samples * 2 bytes per int16).

        Returns:
            Always 3840 bytes of int16 stereo PCM: buffered audio when a full
            frame is available, otherwise silence. Never returns empty bytes,
            because discord.py treats an empty read as end of stream.
        """
        # Calculate required bytes for 20ms frame
        bytes_needed = SAMPLES_PER_FRAME * CHANNELS * 2  # 960 * 2 * 2 = 3840 bytes
        if len(self.audio_buffer) < bytes_needed:
            # Not enough audio yet, return silence
            return b'\x00' * bytes_needed
        # Extract frame from buffer
        frame = bytes(self.audio_buffer[:bytes_needed])
        del self.audio_buffer[:bytes_needed]
        # discord.py calls read() from its player thread, so clear_buffer()
        # on the event loop can empty the buffer between the length check
        # and the slice. Pad so the frame is never short or empty.
        if len(frame) < bytes_needed:
            frame += b'\x00' * (bytes_needed - len(frame))
        return frame

    def is_opus(self) -> bool:
        """Return False since we're providing raw PCM."""
        return False

    def cleanup(self):
        """Cleanup resources when AudioSource is done."""
        logger.info("MikuVoiceSource cleanup called")
        # Actual disconnect happens via disconnect() method