# Source: miku-discord/stt/stt_server.py (Python, 364 lines, 11 KiB)

"""
STT Server
FastAPI WebSocket server for real-time speech-to-text.
Combines Silero VAD (CPU) and NVIDIA Parakeet (GPU) for efficient transcription.
Architecture:
- VAD runs continuously on every audio chunk (CPU)
- Parakeet transcribes only when VAD detects speech (GPU)
- Supports multiple concurrent users
- Sends partial and final transcripts via WebSocket with word-level tokens
"""
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
from fastapi.responses import JSONResponse
import numpy as np
import asyncio
import logging
from typing import Dict, Optional
from datetime import datetime
from vad_processor import VADProcessor
from parakeet_transcriber import ParakeetTranscriber
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='[%(levelname)s] [%(name)s] %(message)s',
)
logger = logging.getLogger('stt_server')

# Initialize FastAPI app
app = FastAPI(title="Miku STT Server", version="1.0.0")

# Global model instances. They stay None until startup_event() assigns them.
vad_processor: Optional[VADProcessor] = None
parakeet_transcriber: Optional[ParakeetTranscriber] = None

# Active sessions keyed by the user_id taken from the WebSocket path.
# The values are UserSTTSession objects (the class is defined further down,
# hence the string forward reference), not plain dicts.
user_sessions: Dict[str, "UserSTTSession"] = {}
class UserSTTSession:
    """Manages STT state for a single user.

    Incoming PCM chunks are run through the module-level ``vad_processor``.
    While the VAD reports speech, chunks are accumulated in ``audio_buffer``
    and handed to the module-level ``parakeet_transcriber``. VAD events and
    transcripts are pushed to the client as JSON messages on ``websocket``.
    """

    # Sample rate (Hz) assumed for all incoming audio.
    SAMPLE_RATE = 16000
    # Seconds of newly buffered speech required between partial transcripts.
    PARTIAL_INTERVAL_S = 2.0
    # Speech probability above which a single chunk counts as an interruption.
    INTERRUPTION_THRESHOLD = 0.7

    def __init__(self, user_id: str, websocket: WebSocket):
        self.user_id = user_id
        self.websocket = websocket
        self.audio_buffer = []
        self.is_speaking = False
        self.timestamp_ms = 0.0
        self.transcript_buffer = []
        self.last_transcript = ""
        # Buffer size (in samples) when the last partial transcript was
        # requested; used to throttle partial transcriptions.
        self._samples_at_last_partial = 0
        logger.info(f"Created STT session for user {user_id}")

    @staticmethod
    def _decode_pcm(audio_data: bytes) -> np.ndarray:
        """Convert raw int16 PCM bytes to a numpy array.

        A trailing odd byte (an incomplete sample) is dropped instead of
        letting ``np.frombuffer`` raise, and empty input yields an empty
        array.
        """
        leftover = len(audio_data) % 2  # int16 samples are 2 bytes wide
        if leftover:
            audio_data = audio_data[:-leftover]
        if not audio_data:
            return np.empty(0, dtype=np.int16)
        return np.frombuffer(audio_data, dtype=np.int16)

    async def process_audio_chunk(self, audio_data: bytes):
        """
        Process incoming audio chunk.

        Runs VAD on the chunk, forwards any VAD event to the client and
        drives buffering / transcription based on the event type. Chunks
        that contain no complete sample are ignored.

        Args:
            audio_data: Raw PCM audio (int16, 16kHz mono)
        """
        audio_np = self._decode_pcm(audio_data)
        if audio_np.size == 0:
            return

        # Advance the session clock by the duration of this chunk.
        chunk_duration_ms = (len(audio_np) / self.SAMPLE_RATE) * 1000
        self.timestamp_ms += chunk_duration_ms

        # Run VAD on chunk
        vad_event = vad_processor.detect_speech_segment(audio_np, self.timestamp_ms)
        if not vad_event:
            # No event for this chunk: keep accumulating if mid-utterance.
            if self.is_speaking:
                self.audio_buffer.append(audio_np)
            return

        event_type = vad_event["event"]
        probability = vad_event["probability"]

        # Send VAD event to client
        await self.websocket.send_json({
            "type": "vad",
            "event": event_type,
            "speaking": event_type in ("speech_start", "speaking"),
            "probability": probability,
            "timestamp": self.timestamp_ms
        })

        # Handle speech events
        if event_type == "speech_start":
            self.is_speaking = True
            self.audio_buffer = [audio_np]
            self._samples_at_last_partial = 0
            logger.debug(f"User {self.user_id} started speaking")
        elif event_type == "speaking":
            if self.is_speaking:
                self.audio_buffer.append(audio_np)
                await self._maybe_transcribe_partial()
        elif event_type == "speech_end":
            self.is_speaking = False
            # Transcribe final, then reset for the next utterance.
            await self._transcribe_final()
            self.audio_buffer = []
            self._samples_at_last_partial = 0
            logger.debug(f"User {self.user_id} stopped speaking")

    async def _maybe_transcribe_partial(self):
        """Request a partial transcript once enough new audio has arrived.

        The first partial fires when the buffer reaches PARTIAL_INTERVAL_S;
        later ones fire only after another PARTIAL_INTERVAL_S of audio has
        been buffered, rather than on every chunk past the threshold.
        """
        total_samples = sum(len(chunk) for chunk in self.audio_buffer)
        new_samples = total_samples - self._samples_at_last_partial
        if new_samples / self.SAMPLE_RATE >= self.PARTIAL_INTERVAL_S:
            self._samples_at_last_partial = total_samples
            await self._transcribe_partial()

    async def _transcribe_and_send(self, message_type: str, label: str,
                                   skip_if_unchanged: bool):
        """Transcribe the buffered audio and send the result to the client.

        Args:
            message_type: Value of the "type" field in the outgoing message.
            label: Prefix used in log messages ("Partial" or "Final").
            skip_if_unchanged: When True, nothing is sent if the text equals
                ``last_transcript``.

        Errors from transcription or sending are logged, not raised.
        """
        if not self.audio_buffer:
            return

        # Concatenate audio
        audio_full = np.concatenate(self.audio_buffer)

        try:
            result = await parakeet_transcriber.transcribe_async(
                audio_full,
                sample_rate=self.SAMPLE_RATE,
                return_timestamps=True
            )
            text = result.get("text") if result else None
            if not text:
                return
            if skip_if_unchanged and text == self.last_transcript:
                return

            self.last_transcript = text
            await self.websocket.send_json({
                "type": message_type,
                "text": text,
                "words": result.get("words", []),  # Word-level tokens
                "user_id": self.user_id,
                "timestamp": self.timestamp_ms
            })
            logger.info(f"{label} [{self.user_id}]: {text}")
        except Exception as e:
            logger.error(f"{label} transcription failed: {e}", exc_info=True)

    async def _transcribe_partial(self):
        """Transcribe accumulated audio and send partial result with word tokens."""
        await self._transcribe_and_send("partial", "Partial", skip_if_unchanged=True)

    async def _transcribe_final(self):
        """Transcribe final accumulated audio with word tokens."""
        await self._transcribe_and_send("final", "Final", skip_if_unchanged=False)

    async def check_interruption(self, audio_data: bytes) -> bool:
        """
        Check if user is interrupting (for use during Miku's speech).

        The decision is based on this single chunk only: an "interruption"
        message is sent when the VAD speech probability for the chunk
        exceeds INTERRUPTION_THRESHOLD.

        Args:
            audio_data: Raw PCM audio chunk
        Returns:
            True if interruption detected
        """
        audio_np = self._decode_pcm(audio_data)
        if audio_np.size == 0:
            return False

        speech_prob, _ = vad_processor.process_chunk(audio_np)

        # Higher threshold than normal VAD so only confident speech counts.
        if speech_prob > self.INTERRUPTION_THRESHOLD:
            await self.websocket.send_json({
                "type": "interruption",
                "probability": speech_prob,
                "timestamp": self.timestamp_ms
            })
            return True
        return False
@app.on_event("startup")
async def startup_event():
    """Create the VAD and Parakeet instances and store them in the module globals."""
    global vad_processor, parakeet_transcriber

    banner = "=" * 50

    logger.info(banner)
    logger.info("Initializing Miku STT Server")
    logger.info(banner)

    # Voice activity detection (CPU)
    logger.info("Loading Silero VAD model (CPU)...")
    vad_settings = {
        "sample_rate": 16000,
        "threshold": 0.5,
        "min_speech_duration_ms": 250,   # Conservative
        "min_silence_duration_ms": 500,  # Conservative
    }
    vad_processor = VADProcessor(**vad_settings)
    logger.info("✓ VAD ready")

    # Speech recognition (GPU)
    logger.info("Loading NVIDIA Parakeet TDT model (GPU)...")
    transcriber_settings = {
        "model_name": "nvidia/parakeet-tdt-0.6b-v3",
        "device": "cuda",
        "language": "en",
    }
    parakeet_transcriber = ParakeetTranscriber(**transcriber_settings)
    logger.info("✓ Parakeet ready")

    logger.info(banner)
    logger.info("STT Server ready to accept connections")
    logger.info(banner)
@app.on_event("shutdown")
async def shutdown_event():
    """Call cleanup() on the transcriber, if one was created, when the server stops."""
    logger.info("Shutting down STT server...")

    transcriber = parakeet_transcriber
    if transcriber:
        transcriber.cleanup()

    logger.info("STT server shutdown complete")
@app.get("/")
async def root():
    """Health check endpoint.

    Returns:
        A dict with the service name, a static "running" status, readiness
        flags for the VAD and the transcriber, and the number of active
        sessions.
    """
    transcriber_ready = parakeet_transcriber is not None
    return {
        "service": "Miku STT Server",
        "status": "running",
        "vad_ready": vad_processor is not None,
        "parakeet_ready": transcriber_ready,
        # Legacy key kept for existing clients. It previously referenced an
        # undefined `whisper_transcriber` name and raised NameError.
        "whisper_ready": transcriber_ready,
        "active_sessions": len(user_sessions),
    }
@app.get("/health")
async def health():
    """Detailed health check.

    Returns:
        A dict describing which models are loaded and which user sessions
        are currently active.
    """
    # Mirrors the arguments passed to ParakeetTranscriber in startup_event;
    # keep the two in sync.
    transcriber_info = {
        "loaded": parakeet_transcriber is not None,
        "model": "nvidia/parakeet-tdt-0.6b-v3",
        "device": "cuda",
    }
    return {
        "status": "healthy",
        "models": {
            "vad": {
                "loaded": vad_processor is not None,
                "device": "cpu",
            },
            "parakeet": transcriber_info,
            # Legacy key kept for existing clients. It previously referenced
            # an undefined `whisper_transcriber` name and raised NameError.
            "whisper": dict(transcriber_info),
        },
        "sessions": {
            "active": len(user_sessions),
            "users": list(user_sessions.keys()),
        },
    }
@app.websocket("/ws/stt/{user_id}")
async def websocket_stt(websocket: WebSocket, user_id: str):
    """
    WebSocket endpoint for real-time STT.

    Client sends: Raw PCM audio (int16, 16kHz mono, 20ms chunks)
    Server sends: JSON events:
    - {"type": "ready", "user_id": "...", "message": "..."}
    - {"type": "vad", "event": "speech_start|speaking|speech_end", ...}
    - {"type": "partial", "text": "...", ...}
    - {"type": "final", "text": "...", ...}

    "interruption" messages are only produced by
    UserSTTSession.check_interruption, which this handler does not call.

    A new connection for a user_id that already has a session replaces the
    registered session; the older connection keeps running until it closes.
    """
    await websocket.accept()
    logger.info(f"STT WebSocket connected: user {user_id}")

    # Create and register the session for this connection.
    session = UserSTTSession(user_id, websocket)
    user_sessions[user_id] = session

    try:
        # Send ready message
        await websocket.send_json({
            "type": "ready",
            "user_id": user_id,
            "message": "STT session started"
        })

        # Main loop: receive binary audio chunks until the socket closes.
        while True:
            data = await websocket.receive_bytes()
            await session.process_audio_chunk(data)
    except WebSocketDisconnect:
        logger.info(f"User {user_id} disconnected")
    except Exception as e:
        logger.error(f"Error in STT WebSocket for user {user_id}: {e}", exc_info=True)
    finally:
        # Only unregister our own session. If the user reconnected in the
        # meantime, the registry entry belongs to the newer connection and
        # must not be removed by this (older) connection's teardown.
        if user_sessions.get(user_id) is session:
            del user_sessions[user_id]
        logger.info(f"STT session ended for user {user_id}")
@app.post("/interrupt/check")
async def check_interruption(user_id: str):
    """
    Report whether the VAD currently considers speech to be in progress.

    Query param:
        user_id: Discord user ID; must have an active STT session.
    Returns:
        {"interrupting": <vad state "speaking" value>, "user_id": str}
    Raises:
        HTTPException: 404 when no session is registered for user_id.

    NOTE(review): the state is read from the shared module-level
    vad_processor, not from the user's session -- confirm this is intended.
    """
    if not user_sessions.get(user_id):
        raise HTTPException(status_code=404, detail="User session not found")

    # Get current VAD state
    current_state = vad_processor.get_state()
    is_interrupting = current_state["speaking"]

    response = {"interrupting": is_interrupting, "user_id": user_id}
    return response
if __name__ == "__main__":
    # Run the app directly with uvicorn when executed as a script.
    import uvicorn

    server_options = {
        "host": "0.0.0.0",
        "port": 8000,
        "log_level": "info",
    }
    uvicorn.run(app, **server_options)