"""
|
|
STT Server
|
|
|
|
FastAPI WebSocket server for real-time speech-to-text.
|
|
Combines Silero VAD (CPU) and Faster-Whisper (GPU) for efficient transcription.
|
|
|
|
Architecture:
|
|
- VAD runs continuously on every audio chunk (CPU)
|
|
- Whisper transcribes only when VAD detects speech (GPU)
|
|
- Supports multiple concurrent users
|
|
- Sends partial and final transcripts via WebSocket
|
|
"""
|
|
|
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
|
|
from fastapi.responses import JSONResponse
|
|
import numpy as np
|
|
import asyncio
|
|
import logging
|
|
from typing import Dict, Optional
|
|
from datetime import datetime
|
|
|
|
from vad_processor import VADProcessor
|
|
from whisper_transcriber import WhisperTranscriber
|
|
|
|
# Configure logging
|
|
logging.basicConfig(
|
|
level=logging.INFO,
|
|
format='[%(levelname)s] [%(name)s] %(message)s'
|
|
)
|
|
logger = logging.getLogger('stt_server')
|
|
|
|
# Initialize FastAPI app
|
|
app = FastAPI(title="Miku STT Server", version="1.0.0")
|
|
|
|
# Global instances (initialized on startup)
|
|
vad_processor: Optional[VADProcessor] = None
|
|
whisper_transcriber: Optional[WhisperTranscriber] = None
|
|
|
|
# User session tracking
|
|
user_sessions: Dict[str, dict] = {}
|
|
|
|
|
|
class UserSTTSession:
    """Manages STT state for a single user."""

    def __init__(self, user_id: str, websocket: WebSocket):
        self.user_id = user_id
        self.websocket = websocket
        self.audio_buffer = []
        self.is_speaking = False
        self.timestamp_ms = 0.0
        self.transcript_buffer = []
        self.last_transcript = ""

        logger.info(f"Created STT session for user {user_id}")

    async def process_audio_chunk(self, audio_data: bytes):
        """
        Process an incoming audio chunk.

        Args:
            audio_data: Raw PCM audio (int16, 16kHz mono)
        """
        # Convert bytes to numpy array (int16)
        audio_np = np.frombuffer(audio_data, dtype=np.int16)

        # Calculate timestamp (at 16kHz, a 20ms chunk = 320 samples)
        chunk_duration_ms = (len(audio_np) / 16000) * 1000
        self.timestamp_ms += chunk_duration_ms

        # Run VAD on chunk
        vad_event = vad_processor.detect_speech_segment(audio_np, self.timestamp_ms)

        if vad_event:
            event_type = vad_event["event"]
            probability = vad_event["probability"]

            # Send VAD event to client
            await self.websocket.send_json({
                "type": "vad",
                "event": event_type,
                "speaking": event_type in ["speech_start", "speaking"],
                "probability": probability,
                "timestamp": self.timestamp_ms
            })

            # Handle speech events
            if event_type == "speech_start":
                self.is_speaking = True
                self.audio_buffer = [audio_np]
                logger.debug(f"User {self.user_id} started speaking")

            elif event_type == "speaking":
                if self.is_speaking:
                    self.audio_buffer.append(audio_np)

                    # Transcribe a partial result every ~2 seconds for streaming
                    total_samples = sum(len(chunk) for chunk in self.audio_buffer)
                    duration_s = total_samples / 16000

                    if duration_s >= 2.0:
                        await self._transcribe_partial()

            elif event_type == "speech_end":
                self.is_speaking = False

                # Transcribe final
                await self._transcribe_final()

                # Clear buffer
                self.audio_buffer = []
                logger.debug(f"User {self.user_id} stopped speaking")

        else:
            # Still accumulate audio if speaking
            if self.is_speaking:
                self.audio_buffer.append(audio_np)

    async def _transcribe_partial(self):
        """Transcribe accumulated audio and send a partial result."""
        if not self.audio_buffer:
            return

        # Concatenate audio
        audio_full = np.concatenate(self.audio_buffer)

        # Transcribe asynchronously
        try:
            text = await whisper_transcriber.transcribe_async(
                audio_full,
                sample_rate=16000,
                initial_prompt=self.last_transcript  # Use previous for context
            )

            if text and text != self.last_transcript:
                self.last_transcript = text

                # Send partial transcript
                await self.websocket.send_json({
                    "type": "partial",
                    "text": text,
                    "user_id": self.user_id,
                    "timestamp": self.timestamp_ms
                })

                logger.info(f"Partial [{self.user_id}]: {text}")

        except Exception as e:
            logger.error(f"Partial transcription failed: {e}", exc_info=True)

    async def _transcribe_final(self):
        """Transcribe the final accumulated audio."""
        if not self.audio_buffer:
            return

        # Concatenate all audio
        audio_full = np.concatenate(self.audio_buffer)

        try:
            text = await whisper_transcriber.transcribe_async(
                audio_full,
                sample_rate=16000
            )

            if text:
                self.last_transcript = text

                # Send final transcript
                await self.websocket.send_json({
                    "type": "final",
                    "text": text,
                    "user_id": self.user_id,
                    "timestamp": self.timestamp_ms
                })

                logger.info(f"Final [{self.user_id}]: {text}")

        except Exception as e:
            logger.error(f"Final transcription failed: {e}", exc_info=True)

    async def check_interruption(self, audio_data: bytes) -> bool:
        """
        Check whether the user is interrupting (for use while Miku is speaking).

        Args:
            audio_data: Raw PCM audio chunk

        Returns:
            True if an interruption is detected
        """
        audio_np = np.frombuffer(audio_data, dtype=np.int16)
        speech_prob, is_speaking = vad_processor.process_chunk(audio_np)

        # Interruption: speech probability on this chunk exceeds a higher threshold
        if speech_prob > 0.7:  # Higher threshold for interruption
            await self.websocket.send_json({
                "type": "interruption",
                "probability": speech_prob,
                "timestamp": self.timestamp_ms
            })
            return True

        return False


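# ---------------------------------------------------------------------------
# Illustrative helper (a sketch; nothing in this server calls it). The session
# above expects int16 mono 16 kHz PCM. If the capture side produces, say,
# 48 kHz stereo int16 PCM (a common decoded Discord voice format), it could
# convert each frame like this before sending. The source format and the plain
# take-every-3rd-sample decimation are assumptions kept dependency-free; a
# proper resampler with low-pass filtering would be preferable in practice.
# ---------------------------------------------------------------------------
def downmix_48k_stereo_to_16k_mono(frame: bytes) -> bytes:
    """Convert interleaved 48 kHz stereo int16 PCM to 16 kHz mono int16 PCM."""
    samples = np.frombuffer(frame, dtype=np.int16).reshape(-1, 2)
    mono = samples.mean(axis=1)   # average left/right channels (float64)
    mono_16k = mono[::3]          # 48 kHz -> 16 kHz by keeping every 3rd sample
    return mono_16k.astype(np.int16).tobytes()

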
@app.on_event("startup")
async def startup_event():
    """Initialize models on server startup."""
    global vad_processor, whisper_transcriber

    logger.info("=" * 50)
    logger.info("Initializing Miku STT Server")
    logger.info("=" * 50)

    # Initialize VAD (CPU)
    logger.info("Loading Silero VAD model (CPU)...")
    vad_processor = VADProcessor(
        sample_rate=16000,
        threshold=0.5,
        min_speech_duration_ms=250,   # Conservative
        min_silence_duration_ms=500   # Conservative
    )
    logger.info("✓ VAD ready")

    # Initialize Whisper (GPU with cuDNN)
    logger.info("Loading Faster-Whisper model (GPU)...")
    whisper_transcriber = WhisperTranscriber(
        model_size="small",
        device="cuda",
        compute_type="float16",
        language="en"
    )
    logger.info("✓ Whisper ready")

    logger.info("=" * 50)
    logger.info("STT Server ready to accept connections")
    logger.info("=" * 50)


@app.on_event("shutdown")
async def shutdown_event():
    """Cleanup on server shutdown."""
    logger.info("Shutting down STT server...")

    if whisper_transcriber:
        whisper_transcriber.cleanup()

    logger.info("STT server shutdown complete")


@app.get("/")
async def root():
    """Health check endpoint."""
    return {
        "service": "Miku STT Server",
        "status": "running",
        "vad_ready": vad_processor is not None,
        "whisper_ready": whisper_transcriber is not None,
        "active_sessions": len(user_sessions)
    }


@app.get("/health")
async def health():
    """Detailed health check."""
    return {
        "status": "healthy",
        "models": {
            "vad": {
                "loaded": vad_processor is not None,
                "device": "cpu"
            },
            "whisper": {
                "loaded": whisper_transcriber is not None,
                "model": "small",
                "device": "cuda"
            }
        },
        "sessions": {
            "active": len(user_sessions),
            "users": list(user_sessions.keys())
        }
    }


@app.websocket("/ws/stt/{user_id}")
async def websocket_stt(websocket: WebSocket, user_id: str):
    """
    WebSocket endpoint for real-time STT.

    Client sends: Raw PCM audio (int16, 16kHz mono, 20ms chunks)
    Server sends: JSON events:
        - {"type": "vad", "event": "speech_start|speaking|speech_end", ...}
        - {"type": "partial", "text": "...", ...}
        - {"type": "final", "text": "...", ...}
        - {"type": "interruption", "probability": 0.xx}

    See example_stt_client() near the end of this module for an illustrative
    client-side sketch of this protocol.
    """
    await websocket.accept()
    logger.info(f"STT WebSocket connected: user {user_id}")

    # Create session
    session = UserSTTSession(user_id, websocket)
    user_sessions[user_id] = session

    try:
        # Send ready message
        await websocket.send_json({
            "type": "ready",
            "user_id": user_id,
            "message": "STT session started"
        })

        # Main loop: receive audio chunks
        while True:
            # Receive binary audio data
            data = await websocket.receive_bytes()

            # Process audio chunk
            await session.process_audio_chunk(data)

    except WebSocketDisconnect:
        logger.info(f"User {user_id} disconnected")

    except Exception as e:
        logger.error(f"Error in STT WebSocket for user {user_id}: {e}", exc_info=True)

    finally:
        # Cleanup session
        if user_id in user_sessions:
            del user_sessions[user_id]
            logger.info(f"STT session ended for user {user_id}")


@app.post("/interrupt/check")
async def check_interruption(user_id: str):
    """
    Check whether a user is interrupting (for use while Miku is speaking).

    Query param:
        user_id: Discord user ID

    Returns:
        {"interrupting": bool, "user_id": str}
    """
    session = user_sessions.get(user_id)

    if not session:
        raise HTTPException(status_code=404, detail="User session not found")

    # Get current VAD state (the VAD processor is a single shared instance, not per-user)
    vad_state = vad_processor.get_state()

    return {
        "interrupting": vad_state["speaking"],
        "user_id": user_id
    }
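

# ---------------------------------------------------------------------------
# Example client (an illustrative sketch; nothing in this server calls it).
# Shows one way to stream 20 ms PCM chunks to /ws/stt/{user_id} and read the
# JSON events documented on websocket_stt(). The third-party `websockets`
# package and the `pcm_chunks` async iterator (int16 mono 16 kHz, 320 samples
# = 640 bytes per chunk) are assumptions, not part of this server.
# ---------------------------------------------------------------------------
async def example_stt_client(user_id: str, pcm_chunks, host: str = "localhost:8000"):
    import json
    import websockets  # assumed client-side dependency, not required by the server

    async with websockets.connect(f"ws://{host}/ws/stt/{user_id}") as ws:
        # First message is the {"type": "ready", ...} handshake
        logger.info(await ws.recv())

        async def receive_events():
            # Log vad / partial / final / interruption events as they arrive
            async for message in ws:
                event = json.loads(message)
                logger.info(f"[{event['type']}] {event}")

        receiver = asyncio.create_task(receive_events())
        try:
            async for chunk in pcm_chunks:  # raw bytes, 640 per 20 ms chunk
                await ws.send(chunk)
        finally:
            receiver.cancel()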


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")