Files
miku-discord/stt-parakeet/server/display_server.py

293 lines
11 KiB
Python

#!/usr/bin/env python3
"""
ASR WebSocket Server with Live Transcription Display
This version displays transcriptions in real-time on the server console
while clients stream audio from remote machines.
"""
import asyncio
import websockets
import numpy as np
import json
import logging
import sys
from datetime import datetime
from pathlib import Path
# Add project root to path
sys.path.insert(0, str(Path(__file__).parent.parent))
from asr.asr_pipeline import ASRPipeline
# Logging setup: records at INFO level and above are written both to the
# file display_server.log and to a default StreamHandler.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[logging.FileHandler('display_server.log'), logging.StreamHandler()],
)
# Module-level logger shared by everything in this file.
logger = logging.getLogger(__name__)
class DisplayServer:
    """
    WebSocket server with live transcription display.

    Protocol implemented by ``handle_client``:

    * Binary frames are treated as 16-bit integer PCM samples and appended
      to a per-connection buffer. Whenever at least ``MIN_CHUNK_DURATION``
      seconds of new audio arrived since the last attempt, the whole buffer
      is transcribed again and sent back as a non-final transcript.
    * Text frames are JSON objects. ``{"type": "final"}`` transcribes the
      buffered audio, sends a final transcript and clears the buffer;
      ``{"type": "reset"}`` only clears the buffer. Other types are ignored.

    Every non-empty transcription is also printed to the server console.
    """

    # ANSI terminal control codes used by the console display. They are class
    # attributes so they are available on any instance (``server.YELLOW``).
    CLEAR_LINE = '\033[2K'
    CURSOR_UP = '\033[1A'
    BOLD = '\033[1m'
    GREEN = '\033[92m'
    YELLOW = '\033[93m'
    BLUE = '\033[94m'
    RESET = '\033[0m'

    # Seconds of new audio required before a progressive transcription runs.
    MIN_CHUNK_DURATION = 2.0

    def __init__(
        self,
        host: str = "0.0.0.0",
        port: int = 8766,
        model_path: str = "models/parakeet",
        sample_rate: int = 16000,
    ):
        """
        Initialize server.

        Args:
            host: Host address to bind to
            port: Port to bind to
            model_path: Directory containing model files
            sample_rate: Audio sample rate
        """
        self.host = host
        self.port = port
        self.sample_rate = sample_rate
        self.active_connections = set()

        # Initialize ASR pipeline
        logger.info("Loading ASR model...")
        self.pipeline = ASRPipeline(model_path=model_path)
        logger.info("ASR Pipeline ready")

        # Client sessions
        self.sessions = {}

        # Created lazily inside the running event loop (see _transcribe).
        self._transcribe_lock = None

    def print_header(self):
        """Print server header."""
        rule = "=" * 80
        print("\n" + rule)
        print(f"{self.BOLD}{self.BLUE}ASR Server - Live Transcription Display{self.RESET}")
        print(rule)
        print(f"Server: ws://{self.host}:{self.port}")
        print(f"Sample Rate: {self.sample_rate} Hz")
        print("Model: Parakeet TDT 0.6B V3")
        print(rule + "\n")

    def display_transcription(self, client_id: str, text: str, is_final: bool, is_progressive: bool = False):
        """
        Display transcription in the terminal.

        Args:
            client_id: Client identifier
            text: Transcribed text
            is_final: Whether this is the final transcription
            is_progressive: Whether this is a progressive update
        """
        timestamp = datetime.now().strftime("%H:%M:%S")
        if is_final:
            # Final transcription - bold green
            print(f"{self.GREEN}{self.BOLD}[{timestamp}] {client_id}{self.RESET}")
            print(f"{self.GREEN} ✓ FINAL: {text}{self.RESET}\n")
        elif is_progressive:
            # Progressive update - yellow
            print(f"{self.YELLOW}[{timestamp}] {client_id}{self.RESET}")
            print(f"{self.YELLOW}{text}{self.RESET}\n")
        else:
            # Regular transcription
            print(f"{self.BLUE}[{timestamp}] {client_id}{self.RESET}")
            print(f" {text}\n")
        # Flush to ensure immediate display
        sys.stdout.flush()

    @staticmethod
    def _decode_audio(message: bytes) -> np.ndarray:
        """
        Convert a binary frame of 16-bit integer samples to float32.

        Samples are scaled by 1/32768, so the result lies in [-1.0, 1.0).

        Raises:
            ValueError: If the frame length is not a whole number of
                16-bit samples.
        """
        if len(message) % 2:
            raise ValueError(
                f"audio frame of {len(message)} bytes is not a whole number of 16-bit samples"
            )
        pcm = np.frombuffer(message, dtype=np.int16)
        return pcm.astype(np.float32) / 32768.0

    @staticmethod
    def _parse_command(message: str):
        """
        Parse a text frame into a command dict.

        Returns:
            The decoded dict, or None when the text is not valid JSON or the
            decoded value is not a JSON object (so ``.get`` would not exist).
        """
        try:
            command = json.loads(message)
        except json.JSONDecodeError:
            return None
        return command if isinstance(command, dict) else None

    @staticmethod
    async def _send_json(websocket, payload: dict):
        """Serialize ``payload`` as JSON and send it as one text frame."""
        await websocket.send(json.dumps(payload))

    async def _transcribe(self, audio: np.ndarray):
        """
        Run the ASR pipeline on ``audio`` without blocking the event loop.

        The synchronous ``pipeline.transcribe`` call is executed in the
        loop's default executor. A lock keeps calls strictly one at a time,
        as they were when the call ran inline on the event loop.
        """
        lock = getattr(self, "_transcribe_lock", None)
        if lock is None:
            # Created here, inside a coroutine, so the lock belongs to the
            # loop that is actually running the server.
            lock = self._transcribe_lock = asyncio.Lock()
        async with lock:
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(
                None,
                lambda: self.pipeline.transcribe(audio, sample_rate=self.sample_rate),
            )

    async def _transcribe_and_report(self, websocket, client_id: str, audio: np.ndarray, is_final: bool):
        """
        Transcribe ``audio``, show the result on the console and send it.

        A transcription failure is logged and reported to the client as an
        ``error`` message; it does not end the connection. Empty or
        whitespace-only results are neither displayed nor sent.
        """
        try:
            text = await self._transcribe(audio)
        except Exception as e:
            logger.error("Transcription error: %s", e)
            await self._send_json(websocket, {
                "type": "error",
                "message": f"Transcription failed: {e}",
            })
            return

        if text and text.strip():
            # Display on server
            self.display_transcription(
                client_id, text, is_final=is_final, is_progressive=not is_final
            )
            # Send to client
            await self._send_json(websocket, {
                "type": "transcript",
                "text": text,
                "is_final": is_final,
            })

    async def handle_client(self, websocket):
        """
        Handle individual WebSocket client connection.

        Args:
            websocket: WebSocket connection
        """
        client_id = f"{websocket.remote_address[0]}:{websocket.remote_address[1]}"
        logger.info("Client connected: %s", client_id)
        self.active_connections.add(websocket)

        # Display connection
        print(f"\n{self.BOLD}{'='*80}{self.RESET}")
        print(f"{self.GREEN}✓ Client connected: {client_id}{self.RESET}")
        print(f"{self.BOLD}{'='*80}{self.RESET}\n")
        sys.stdout.flush()

        # Audio received since the last final/reset, plus a running sample
        # count so the total is not recomputed over every chunk per frame.
        all_audio = []
        total_samples = 0
        last_transcribed_samples = 0
        min_chunk_samples = int(self.sample_rate * self.MIN_CHUNK_DURATION)

        try:
            # Send welcome message
            await self._send_json(websocket, {
                "type": "info",
                "message": "Connected to ASR server with live display",
                "sample_rate": self.sample_rate,
            })

            async for message in websocket:
                try:
                    if isinstance(message, bytes):
                        # Binary audio data
                        try:
                            audio_data = self._decode_audio(message)
                        except ValueError as e:
                            # A malformed frame is reported, not fatal.
                            logger.warning("Bad audio frame from %s: %s", client_id, e)
                            await self._send_json(websocket, {
                                "type": "error",
                                "message": f"Invalid audio frame: {e}",
                            })
                            continue

                        all_audio.append(audio_data)
                        total_samples += len(audio_data)

                        # Transcribe periodically when we have enough NEW audio
                        if total_samples - last_transcribed_samples >= min_chunk_samples:
                            audio_chunk = np.concatenate(all_audio)
                            # Keep the merged array so the next concatenation
                            # does not walk all the small frames again.
                            all_audio = [audio_chunk]
                            last_transcribed_samples = total_samples
                            await self._transcribe_and_report(
                                websocket, client_id, audio_chunk, is_final=False
                            )

                    elif isinstance(message, str):
                        # JSON command
                        command = self._parse_command(message)
                        if command is None:
                            logger.warning("Invalid JSON from %s: %s", client_id, message)
                            continue

                        command_type = command.get("type")
                        if command_type == "final":
                            # Process all accumulated audio (final transcription)
                            if all_audio:
                                await self._transcribe_and_report(
                                    websocket, client_id, np.concatenate(all_audio), is_final=True
                                )
                            # The utterance is over either way: clear the buffer
                            # even if the transcription failed.
                            all_audio = []
                            total_samples = 0
                            last_transcribed_samples = 0
                        elif command_type == "reset":
                            # Reset buffer
                            all_audio = []
                            total_samples = 0
                            last_transcribed_samples = 0
                            await self._send_json(websocket, {
                                "type": "info",
                                "message": "Buffer reset",
                            })
                            print(f"{self.YELLOW}[{client_id}] Buffer reset{self.RESET}\n")
                            sys.stdout.flush()

                except websockets.exceptions.ConnectionClosed:
                    # A peer that went away mid-send is a normal close, not a
                    # processing error; let the outer handler deal with it.
                    raise
                except Exception as e:
                    logger.error("Error processing message from %s: %s", client_id, e)
                    break

        except websockets.exceptions.ConnectionClosed:
            # The close itself is logged once, in the finally block below.
            logger.debug("Connection closed by peer: %s", client_id)
        except Exception as e:
            logger.error("Unexpected error with %s: %s", client_id, e)
        finally:
            self.active_connections.discard(websocket)
            print(f"\n{self.BOLD}{'='*80}{self.RESET}")
            print(f"{self.YELLOW}✗ Client disconnected: {client_id}{self.RESET}")
            print(f"{self.BOLD}{'='*80}{self.RESET}\n")
            sys.stdout.flush()
            logger.info("Connection closed: %s", client_id)

    async def start(self):
        """Start the WebSocket server and serve until cancelled."""
        self.print_header()
        async with websockets.serve(self.handle_client, self.host, self.port):
            logger.info("Starting WebSocket server on %s:%s", self.host, self.port)
            print(f"{self.GREEN}{self.BOLD}Server is running and ready for connections!{self.RESET}")
            print(f"{self.BOLD}Waiting for clients...{self.RESET}\n")
            sys.stdout.flush()
            # Keep server running: this future is never resolved.
            await asyncio.Future()
def main(argv=None):
    """
    Main entry point.

    Parses the command line, builds a DisplayServer and runs it until the
    process is interrupted with Ctrl-C.

    Args:
        argv: Optional list of command-line arguments. When None (the
            default) argparse reads them from sys.argv, as before.
    """
    import argparse

    parser = argparse.ArgumentParser(description="ASR Server with Live Display")
    parser.add_argument("--host", default="0.0.0.0", help="Host address")
    parser.add_argument("--port", type=int, default=8766, help="Port number")
    parser.add_argument("--model-path", default="models/parakeet", help="Model directory")
    parser.add_argument("--sample-rate", type=int, default=16000, help="Sample rate")
    args = parser.parse_args(argv)

    server = DisplayServer(
        host=args.host,
        port=args.port,
        model_path=args.model_path,
        sample_rate=args.sample_rate,
    )
    try:
        asyncio.run(server.start())
    except KeyboardInterrupt:
        print(f"\n\n{server.YELLOW}Server stopped by user{server.RESET}")
        logger.info("Server stopped by user")


if __name__ == "__main__":
    main()