Phase 4 STT pipeline implemented — Silero VAD + faster-whisper — still not working well at all
stt/stt_server.py (new file, 361 lines)
@@ -0,0 +1,361 @@
"""
STT Server

FastAPI WebSocket server for real-time speech-to-text.
Combines Silero VAD (CPU) and Faster-Whisper (GPU) for efficient transcription.

Architecture:
- VAD runs continuously on every audio chunk (CPU)
- Whisper transcribes only when VAD detects speech (GPU)
- Supports multiple concurrent users
- Sends partial and final transcripts via WebSocket
"""

from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
from fastapi.responses import JSONResponse
import numpy as np
import asyncio
import logging
from typing import Dict, Optional
from datetime import datetime

from vad_processor import VADProcessor
from whisper_transcriber import WhisperTranscriber

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='[%(levelname)s] [%(name)s] %(message)s'
)
logger = logging.getLogger('stt_server')

# Initialize FastAPI app
app = FastAPI(title="Miku STT Server", version="1.0.0")

# Global instances (initialized on startup)
vad_processor: Optional[VADProcessor] = None
whisper_transcriber: Optional[WhisperTranscriber] = None

# User session tracking
user_sessions: Dict[str, dict] = {}


class UserSTTSession:
    """Manages STT state for a single user."""

    def __init__(self, user_id: str, websocket: WebSocket):
        self.user_id = user_id
        self.websocket = websocket
        self.audio_buffer = []
        self.is_speaking = False
        self.timestamp_ms = 0.0
        self.transcript_buffer = []
        self.last_transcript = ""

        logger.info(f"Created STT session for user {user_id}")

    async def process_audio_chunk(self, audio_data: bytes):
        """
        Process incoming audio chunk.

        Args:
            audio_data: Raw PCM audio (int16, 16kHz mono)
        """
        # Convert bytes to numpy array (int16)
        audio_np = np.frombuffer(audio_data, dtype=np.int16)

        # Calculate timestamp (assuming 16kHz, 20ms chunks = 320 samples)
        chunk_duration_ms = (len(audio_np) / 16000) * 1000
        self.timestamp_ms += chunk_duration_ms

        # Run VAD on chunk
        vad_event = vad_processor.detect_speech_segment(audio_np, self.timestamp_ms)

        if vad_event:
            event_type = vad_event["event"]
            probability = vad_event["probability"]

            # Send VAD event to client
            await self.websocket.send_json({
                "type": "vad",
                "event": event_type,
                "speaking": event_type in ["speech_start", "speaking"],
                "probability": probability,
                "timestamp": self.timestamp_ms
            })

            # Handle speech events
            if event_type == "speech_start":
                self.is_speaking = True
                self.audio_buffer = [audio_np]
                logger.debug(f"User {self.user_id} started speaking")

            elif event_type == "speaking":
                if self.is_speaking:
                    self.audio_buffer.append(audio_np)

                    # Transcribe partial every ~2 seconds for streaming
                    total_samples = sum(len(chunk) for chunk in self.audio_buffer)
                    duration_s = total_samples / 16000

                    if duration_s >= 2.0:
                        await self._transcribe_partial()

            elif event_type == "speech_end":
                self.is_speaking = False

                # Transcribe final
                await self._transcribe_final()

                # Clear buffer
                self.audio_buffer = []
                logger.debug(f"User {self.user_id} stopped speaking")

        else:
            # Still accumulate audio if speaking
            if self.is_speaking:
                self.audio_buffer.append(audio_np)

    async def _transcribe_partial(self):
        """Transcribe accumulated audio and send partial result."""
        if not self.audio_buffer:
            return

        # Concatenate audio
        audio_full = np.concatenate(self.audio_buffer)

        # Transcribe asynchronously
        try:
            text = await whisper_transcriber.transcribe_async(
                audio_full,
                sample_rate=16000,
                initial_prompt=self.last_transcript  # Use previous for context
            )

            if text and text != self.last_transcript:
                self.last_transcript = text

                # Send partial transcript
                await self.websocket.send_json({
                    "type": "partial",
                    "text": text,
                    "user_id": self.user_id,
                    "timestamp": self.timestamp_ms
                })

                logger.info(f"Partial [{self.user_id}]: {text}")

        except Exception as e:
            logger.error(f"Partial transcription failed: {e}", exc_info=True)

    async def _transcribe_final(self):
        """Transcribe final accumulated audio."""
        if not self.audio_buffer:
            return

        # Concatenate all audio
        audio_full = np.concatenate(self.audio_buffer)

        try:
            text = await whisper_transcriber.transcribe_async(
                audio_full,
                sample_rate=16000
            )

            if text:
                self.last_transcript = text

                # Send final transcript
                await self.websocket.send_json({
                    "type": "final",
                    "text": text,
                    "user_id": self.user_id,
                    "timestamp": self.timestamp_ms
                })

                logger.info(f"Final [{self.user_id}]: {text}")

        except Exception as e:
            logger.error(f"Final transcription failed: {e}", exc_info=True)

    async def check_interruption(self, audio_data: bytes) -> bool:
        """
        Check if user is interrupting (for use during Miku's speech).

        Args:
            audio_data: Raw PCM audio chunk

        Returns:
            True if interruption detected
        """
        audio_np = np.frombuffer(audio_data, dtype=np.int16)
        speech_prob, is_speaking = vad_processor.process_chunk(audio_np)

        # Interruption: high probability sustained for threshold duration
        if speech_prob > 0.7:  # Higher threshold for interruption
            await self.websocket.send_json({
                "type": "interruption",
                "probability": speech_prob,
                "timestamp": self.timestamp_ms
            })
            return True

        return False


@app.on_event("startup")
|
||||
async def startup_event():
|
||||
"""Initialize models on server startup."""
|
||||
global vad_processor, whisper_transcriber
|
||||
|
||||
logger.info("=" * 50)
|
||||
logger.info("Initializing Miku STT Server")
|
||||
logger.info("=" * 50)
|
||||
|
||||
# Initialize VAD (CPU)
|
||||
logger.info("Loading Silero VAD model (CPU)...")
|
||||
vad_processor = VADProcessor(
|
||||
sample_rate=16000,
|
||||
threshold=0.5,
|
||||
min_speech_duration_ms=250, # Conservative
|
||||
min_silence_duration_ms=500 # Conservative
|
||||
)
|
||||
logger.info("✓ VAD ready")
|
||||
|
||||
# Initialize Whisper (GPU with cuDNN)
|
||||
logger.info("Loading Faster-Whisper model (GPU)...")
|
||||
whisper_transcriber = WhisperTranscriber(
|
||||
model_size="small",
|
||||
device="cuda",
|
||||
compute_type="float16",
|
||||
language="en"
|
||||
)
|
||||
logger.info("✓ Whisper ready")
|
||||
|
||||
logger.info("=" * 50)
|
||||
logger.info("STT Server ready to accept connections")
|
||||
logger.info("=" * 50)
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
"""Cleanup on server shutdown."""
|
||||
logger.info("Shutting down STT server...")
|
||||
|
||||
if whisper_transcriber:
|
||||
whisper_transcriber.cleanup()
|
||||
|
||||
logger.info("STT server shutdown complete")
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""Health check endpoint."""
|
||||
return {
|
||||
"service": "Miku STT Server",
|
||||
"status": "running",
|
||||
"vad_ready": vad_processor is not None,
|
||||
"whisper_ready": whisper_transcriber is not None,
|
||||
"active_sessions": len(user_sessions)
|
||||
}
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
"""Detailed health check."""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"models": {
|
||||
"vad": {
|
||||
"loaded": vad_processor is not None,
|
||||
"device": "cpu"
|
||||
},
|
||||
"whisper": {
|
||||
"loaded": whisper_transcriber is not None,
|
||||
"model": "small",
|
||||
"device": "cuda"
|
||||
}
|
||||
},
|
||||
"sessions": {
|
||||
"active": len(user_sessions),
|
||||
"users": list(user_sessions.keys())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@app.websocket("/ws/stt/{user_id}")
async def websocket_stt(websocket: WebSocket, user_id: str):
    """
    WebSocket endpoint for real-time STT.

    Client sends: Raw PCM audio (int16, 16kHz mono, 20ms chunks)
    Server sends: JSON events:
        - {"type": "vad", "event": "speech_start|speaking|speech_end", ...}
        - {"type": "partial", "text": "...", ...}
        - {"type": "final", "text": "...", ...}
        - {"type": "interruption", "probability": 0.xx}
    """
    await websocket.accept()
    logger.info(f"STT WebSocket connected: user {user_id}")

    # Create session
    session = UserSTTSession(user_id, websocket)
    user_sessions[user_id] = session

    try:
        # Send ready message
        await websocket.send_json({
            "type": "ready",
            "user_id": user_id,
            "message": "STT session started"
        })

        # Main loop: receive audio chunks
        while True:
            # Receive binary audio data
            data = await websocket.receive_bytes()

            # Process audio chunk
            await session.process_audio_chunk(data)

    except WebSocketDisconnect:
        logger.info(f"User {user_id} disconnected")

    except Exception as e:
        logger.error(f"Error in STT WebSocket for user {user_id}: {e}", exc_info=True)

    finally:
        # Cleanup session
        if user_id in user_sessions:
            del user_sessions[user_id]
        logger.info(f"STT session ended for user {user_id}")


@app.post("/interrupt/check")
async def check_interruption(user_id: str):
    """
    Check if user is interrupting (for use during Miku's speech).

    Query param:
        user_id: Discord user ID

    Returns:
        {"interrupting": bool, "user_id": str}
    """
    session = user_sessions.get(user_id)

    if not session:
        raise HTTPException(status_code=404, detail="User session not found")

    # Get current VAD state
    vad_state = vad_processor.get_state()

    return {
        "interrupting": vad_state["speaking"],
        "user_id": user_id
    }


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")
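
For quick manual testing, a minimal client sketch along these lines can exercise the WebSocket protocol described in the endpoint docstring. It is illustrative only and not part of this commit: it assumes the third-party websockets package, the server running locally on port 8000, and a 16 kHz mono int16 WAV file standing in for live Discord audio; sample.wav and the user ID are placeholders.

# example_client.py (illustrative sketch, not part of this commit)
import asyncio
import json
import wave

import websockets  # third-party: pip install websockets

CHUNK_SAMPLES = 320      # 20 ms at 16 kHz
USER_ID = "123456789"    # placeholder user ID


async def stream_file(uri: str, path: str):
    async with websockets.connect(uri) as ws:
        # Wait for the "ready" message before streaming audio.
        print(json.loads(await ws.recv()))

        async def receiver():
            # Print partial/final transcripts as they arrive.
            async for message in ws:
                event = json.loads(message)
                if event["type"] in ("partial", "final"):
                    print(event["type"], event["text"])

        recv_task = asyncio.create_task(receiver())

        # Stream the file as raw int16 PCM in 20 ms chunks, paced at real time.
        with wave.open(path, "rb") as wav:
            while True:
                frames = wav.readframes(CHUNK_SAMPLES)
                if not frames:
                    break
                await ws.send(frames)
                await asyncio.sleep(0.02)

        # A short tail of silence helps the server-side VAD emit speech_end.
        silence = b"\x00" * (CHUNK_SAMPLES * 2)
        for _ in range(50):  # ~1 second
            await ws.send(silence)
            await asyncio.sleep(0.02)

        await asyncio.sleep(2.0)  # give the final transcript time to arrive
        recv_task.cancel()


asyncio.run(stream_file(f"ws://localhost:8000/ws/stt/{USER_ID}", "sample.wav"))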