"""
|
|
STT Server
|
|
|
|
FastAPI WebSocket server for real-time speech-to-text.
|
|
Combines Silero VAD (CPU) and NVIDIA Parakeet (GPU) for efficient transcription.
|
|
|
|
Architecture:
|
|
- VAD runs continuously on every audio chunk (CPU)
|
|
- Parakeet transcribes only when VAD detects speech (GPU)
|
|
- Supports multiple concurrent users
|
|
- Sends partial and final transcripts via WebSocket with word-level tokens
|
|
"""

from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
import numpy as np
import logging
from typing import Dict, Optional

from vad_processor import VADProcessor
from parakeet_transcriber import ParakeetTranscriber

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='[%(levelname)s] [%(name)s] %(message)s'
)
logger = logging.getLogger('stt_server')

# Initialize FastAPI app
app = FastAPI(title="Miku STT Server", version="1.0.0")

# Global instances (initialized on startup)
vad_processor: Optional[VADProcessor] = None
parakeet_transcriber: Optional[ParakeetTranscriber] = None

# User session tracking (user_id -> UserSTTSession)
user_sessions: Dict[str, "UserSTTSession"] = {}


class UserSTTSession:
    """Manages STT state for a single user."""

    def __init__(self, user_id: str, websocket: WebSocket):
        self.user_id = user_id
        self.websocket = websocket
        self.audio_buffer = []
        self.is_speaking = False
        self.timestamp_ms = 0.0
        self.transcript_buffer = []
        self.last_transcript = ""
        self.last_partial_duration = 0.0  # Buffered duration at the last partial
        self.last_speech_timestamp = 0.0  # Last time we detected speech
        self.speech_timeout_ms = 3000  # Force finalization after 3s of no new speech

        logger.info(f"Created STT session for user {user_id}")

    async def process_audio_chunk(self, audio_data: bytes):
        """
        Process an incoming audio chunk.

        Args:
            audio_data: Raw PCM audio (int16, 16kHz mono)
        """
        # Convert bytes to numpy array (int16)
        audio_np = np.frombuffer(audio_data, dtype=np.int16)

        # Advance the session clock (at 16kHz, a 20ms chunk is 320 samples)
        chunk_duration_ms = (len(audio_np) / 16000) * 1000
        self.timestamp_ms += chunk_duration_ms

        # Run VAD on the chunk
        vad_event = vad_processor.detect_speech_segment(audio_np, self.timestamp_ms)

        if vad_event:
            event_type = vad_event["event"]
            probability = vad_event["probability"]

            logger.debug(f"VAD event for user {self.user_id}: {event_type} (prob={probability:.3f})")

            # Send VAD event to client
            await self.websocket.send_json({
                "type": "vad",
                "event": event_type,
                "speaking": event_type in ["speech_start", "speaking"],
                "probability": probability,
                "timestamp": self.timestamp_ms
            })

            # Handle speech events
            if event_type == "speech_start":
                self.is_speaking = True
                self.audio_buffer = [audio_np]
                self.last_partial_duration = 0.0
                self.last_speech_timestamp = self.timestamp_ms
                logger.info(f"[STT] User {self.user_id} SPEECH START")

            elif event_type == "speaking":
                if self.is_speaking:
                    self.audio_buffer.append(audio_np)
                    self.last_speech_timestamp = self.timestamp_ms  # Update speech timestamp

                    # Transcribe a partial for every additional ~1s of buffered speech
                    total_samples = sum(len(chunk) for chunk in self.audio_buffer)
                    duration_s = total_samples / 16000

                    # Throttle partials to roughly one per second of new audio
                    if duration_s - self.last_partial_duration >= 1.0:
                        logger.debug(f"Triggering partial transcription at {duration_s:.1f}s")
                        await self._transcribe_partial()
                        # Keep buffer for final transcription, but mark progress
                        self.last_partial_duration = duration_s

            elif event_type == "speech_end":
                self.is_speaking = False

                logger.info(f"[STT] User {self.user_id} SPEECH END (VAD detected) - transcribing final")

                # Transcribe final
                await self._transcribe_final()

                # Clear buffer
                self.audio_buffer = []
                self.last_partial_duration = 0.0
                logger.debug(f"User {self.user_id} stopped speaking")

        else:
            # No VAD event - keep accumulating audio while speaking
            if self.is_speaking:
                self.audio_buffer.append(audio_np)

                # Check for timeout
                time_since_speech = self.timestamp_ms - self.last_speech_timestamp

                if time_since_speech >= self.speech_timeout_ms:
                    # Timeout - the user probably stopped but VAD missed it
                    logger.warning(f"[STT] User {self.user_id} SPEECH TIMEOUT after {time_since_speech:.0f}ms - forcing finalization")
                    self.is_speaking = False

                    # Force final transcription
                    await self._transcribe_final()

                    # Clear buffer
                    self.audio_buffer = []
                    self.last_partial_duration = 0.0
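
    # Illustrative event timeline for a ~2s utterance with the VAD settings from
    # startup_event (250ms min speech, 300ms min silence); times are approximate:
    #   t~250ms    speech_start  -> buffering begins
    #   t~1250ms   partial       -> ~1s of buffered speech, text only
    #   t~2250ms   partial       -> ~2s buffered
    #   t~2300ms+  speech_end    -> final transcript with word-level tokens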

    async def _transcribe_partial(self):
        """Transcribe accumulated audio and send a partial result (no timestamps, to save VRAM)."""
        if not self.audio_buffer:
            return

        # Concatenate audio
        audio_full = np.concatenate(self.audio_buffer)

        # Transcribe asynchronously WITHOUT timestamps for partials (saves 1-2GB VRAM)
        try:
            result = await parakeet_transcriber.transcribe_async(
                audio_full,
                sample_rate=16000,
                return_timestamps=False  # Disable timestamps for partials to reduce VRAM usage
            )

            # Result is just a string when timestamps are disabled
            text = result if isinstance(result, str) else result.get("text", "")

            if text and text != self.last_transcript:
                self.last_transcript = text

                # Send partial transcript without word tokens (saves memory)
                await self.websocket.send_json({
                    "type": "partial",
                    "text": text,
                    "words": [],  # No word tokens for partials
                    "user_id": self.user_id,
                    "timestamp": self.timestamp_ms
                })

                logger.info(f"Partial [{self.user_id}]: {text}")

        except Exception as e:
            logger.error(f"Partial transcription failed: {e}", exc_info=True)

    async def _transcribe_final(self):
        """Transcribe the final accumulated audio with word tokens."""
        if not self.audio_buffer:
            return

        # Concatenate all audio
        audio_full = np.concatenate(self.audio_buffer)

        try:
            result = await parakeet_transcriber.transcribe_async(
                audio_full,
                sample_rate=16000,
                return_timestamps=True
            )

            if result and result.get("text"):
                self.last_transcript = result["text"]

                # Send final transcript with word tokens
                await self.websocket.send_json({
                    "type": "final",
                    "text": result["text"],
                    "words": result.get("words", []),  # Word-level tokens for the LLM
                    "user_id": self.user_id,
                    "timestamp": self.timestamp_ms
                })

                logger.info(f"Final [{self.user_id}]: {result['text']}")

        except Exception as e:
            logger.error(f"Final transcription failed: {e}", exc_info=True)

    async def check_interruption(self, audio_data: bytes) -> bool:
        """
        Check if the user is interrupting (for use while Miku is speaking).

        Args:
            audio_data: Raw PCM audio chunk

        Returns:
            True if an interruption was detected
        """
        audio_np = np.frombuffer(audio_data, dtype=np.int16)
        speech_prob, is_speaking = vad_processor.process_chunk(audio_np)

        # A single chunk above the stricter 0.7 threshold counts as an interruption
        if speech_prob > 0.7:
            await self.websocket.send_json({
                "type": "interruption",
                "probability": speech_prob,
                "timestamp": self.timestamp_ms
            })
            return True

        return False


@app.on_event("startup")
async def startup_event():
    """Initialize models on server startup."""
    global vad_processor, parakeet_transcriber

    logger.info("=" * 50)
    logger.info("Initializing Miku STT Server")
    logger.info("=" * 50)

    # Initialize VAD (CPU)
    logger.info("Loading Silero VAD model (CPU)...")
    vad_processor = VADProcessor(
        sample_rate=16000,
        threshold=0.5,
        min_speech_duration_ms=250,  # Conservative - wait 250ms before starting
        min_silence_duration_ms=300  # Reduced from 500ms - detect silence faster
    )
    logger.info("✓ VAD ready")

    # Initialize Parakeet (GPU)
    logger.info("Loading NVIDIA Parakeet TDT model (GPU)...")
    parakeet_transcriber = ParakeetTranscriber(
        model_name="nvidia/parakeet-tdt-0.6b-v3",
        device="cuda",
        language="en"
    )
    logger.info("✓ Parakeet ready")

    logger.info("=" * 50)
    logger.info("STT Server ready to accept connections")
    logger.info("=" * 50)


@app.on_event("shutdown")
async def shutdown_event():
    """Cleanup on server shutdown."""
    logger.info("Shutting down STT server...")

    if parakeet_transcriber:
        parakeet_transcriber.cleanup()

    logger.info("STT server shutdown complete")


@app.get("/")
async def root():
    """Health check endpoint."""
    return {
        "service": "Miku STT Server",
        "status": "running",
        "vad_ready": vad_processor is not None,
        "parakeet_ready": parakeet_transcriber is not None,
        "active_sessions": len(user_sessions)
    }


@app.get("/health")
async def health():
    """Detailed health check."""
    return {
        "status": "healthy",
        "models": {
            "vad": {
                "loaded": vad_processor is not None,
                "device": "cpu"
            },
            "parakeet": {
                "loaded": parakeet_transcriber is not None,
                "model": "nvidia/parakeet-tdt-0.6b-v3",
                "device": "cuda"
            }
        },
        "sessions": {
            "active": len(user_sessions),
            "users": list(user_sessions.keys())
        }
    }
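
# Illustrative response (field values depend on runtime state):
#   GET /health -> {"status": "healthy",
#                   "models": {"vad": {"loaded": true, "device": "cpu"},
#                              "parakeet": {"loaded": true, ...}},
#                   "sessions": {"active": 1, "users": ["1234"]}}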


@app.websocket("/ws/stt/{user_id}")
async def websocket_stt(websocket: WebSocket, user_id: str):
    """
    WebSocket endpoint for real-time STT.

    Client sends: Raw PCM audio (int16, 16kHz mono, 20ms chunks)
    Server sends: JSON events:
        - {"type": "vad", "event": "speech_start|speaking|speech_end", ...}
        - {"type": "partial", "text": "...", ...}
        - {"type": "final", "text": "...", ...}
        - {"type": "interruption", "probability": 0.xx}
    """
    await websocket.accept()
    logger.info(f"STT WebSocket connected: user {user_id}")

    # Create session
    session = UserSTTSession(user_id, websocket)
    user_sessions[user_id] = session

    try:
        # Send ready message
        await websocket.send_json({
            "type": "ready",
            "user_id": user_id,
            "message": "STT session started"
        })

        # Main loop: receive audio chunks
        while True:
            # Receive binary audio data
            data = await websocket.receive_bytes()

            # Process audio chunk
            await session.process_audio_chunk(data)

    except WebSocketDisconnect:
        logger.info(f"User {user_id} disconnected")

    except Exception as e:
        logger.error(f"Error in STT WebSocket for user {user_id}: {e}", exc_info=True)

    finally:
        # Cleanup session
        if user_id in user_sessions:
            del user_sessions[user_id]
        logger.info(f"STT session ended for user {user_id}")


@app.post("/interrupt/check")
async def check_interruption(user_id: str):
    """
    Check if a user is interrupting (for use during Miku's speech).

    Query param:
        user_id: Discord user ID

    Returns:
        {"interrupting": bool, "user_id": str}
    """
    session = user_sessions.get(user_id)

    if not session:
        raise HTTPException(status_code=404, detail="User session not found")

    # Get current VAD state
    vad_state = vad_processor.get_state()

    return {
        "interrupting": vad_state["speaking"],
        "user_id": user_id
    }
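
# Illustrative call (values depend on runtime state):
#   POST /interrupt/check?user_id=1234 -> {"interrupting": false, "user_id": "1234"}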


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")