293 lines
11 KiB
Python
293 lines
11 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
ASR WebSocket Server with Live Transcription Display
|
|
|
|
This version displays transcriptions in real-time on the server console
|
|
while clients stream audio from remote machines.
|
|
"""
|
|
import asyncio
|
|
import websockets
|
|
import numpy as np
|
|
import json
|
|
import logging
|
|
import sys
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
|
|
# Add project root to path
|
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
|
|
from asr.asr_pipeline import ASRPipeline
|
|
|
|
# Configure logging: records go both to ``display_server.log`` (relative to the
# current working directory) and to a console StreamHandler.
# The log file is written as UTF-8 explicitly. Logged messages can contain
# client-supplied text (e.g. invalid JSON payloads, exception messages), and the
# platform's default locale encoding may not be able to represent it.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('display_server.log', encoding='utf-8'),
        logging.StreamHandler(),
    ],
)

logger = logging.getLogger(__name__)
|
|
|
|
|
|
class DisplayServer:
    """
    WebSocket server with live transcription display.

    Each client streams binary frames of 16-bit PCM audio and may send JSON
    text frames whose ``"type"`` is ``"final"`` or ``"reset"``. Transcripts are
    printed on the server console and sent back to the client as JSON.
    """

    def __init__(
        self,
        host: str = "0.0.0.0",
        port: int = 8766,
        model_path: str = "models/parakeet",
        sample_rate: int = 16000,
        progressive_interval: float = 2.0,
    ):
        """
        Initialize server.

        Args:
            host: Host address to bind to
            port: Port to bind to
            model_path: Directory containing model files
            sample_rate: Audio sample rate (Hz) clients stream at
            progressive_interval: Seconds of *new* audio that must accumulate
                before another progressive (non-final) transcription runs.
                Defaults to 2.0, the value that used to be hard-coded.
        """
        self.host = host
        self.port = port
        self.model_path = model_path
        self.sample_rate = sample_rate
        self.progressive_interval = progressive_interval
        self.active_connections = set()

        # Terminal control codes (ANSI escape sequences)
        self.CLEAR_LINE = '\033[2K'
        self.CURSOR_UP = '\033[1A'
        self.BOLD = '\033[1m'
        self.GREEN = '\033[92m'
        self.YELLOW = '\033[93m'
        self.BLUE = '\033[94m'
        self.RESET = '\033[0m'

        # Initialize ASR pipeline
        logger.info("Loading ASR model...")
        self.pipeline = ASRPipeline(model_path=model_path)
        logger.info("ASR Pipeline ready")

        # Inference is synchronous. Calling it directly inside a coroutine would
        # block the event loop and stall every other client and the websocket
        # keepalives. Calls are therefore run on ONE dedicated worker thread,
        # which keeps pipeline calls serialized just as the inline calls were.
        # NOTE(review): assumes ASRPipeline may be used from a thread other than
        # the one that constructed it -- confirm for the model backend.
        from concurrent.futures import ThreadPoolExecutor
        self._executor = ThreadPoolExecutor(
            max_workers=1, thread_name_prefix="asr-inference"
        )

        # Client sessions (not read anywhere in this class; kept for compatibility)
        self.sessions = {}

    def print_header(self):
        """Print server header."""
        print("\n" + "=" * 80)
        print(f"{self.BOLD}{self.BLUE}ASR Server - Live Transcription Display{self.RESET}")
        print("=" * 80)
        print(f"Server: ws://{self.host}:{self.port}")
        print(f"Sample Rate: {self.sample_rate} Hz")
        print("Model: Parakeet TDT 0.6B V3")
        # The label above is fixed text; show which model directory was actually loaded.
        print(f"Model Path: {self.model_path}")
        print("=" * 80 + "\n")

    def display_transcription(self, client_id: str, text: str, is_final: bool, is_progressive: bool = False):
        """
        Display transcription in the terminal.

        Args:
            client_id: Client identifier
            text: Transcribed text
            is_final: Whether this is the final transcription (takes precedence)
            is_progressive: Whether this is a progressive update
        """
        timestamp = datetime.now().strftime("%H:%M:%S")

        if is_final:
            # Final transcription - bold green
            print(f"{self.GREEN}{self.BOLD}[{timestamp}] {client_id}{self.RESET}")
            print(f"{self.GREEN}  ✓ FINAL: {text}{self.RESET}\n")
        elif is_progressive:
            # Progressive update - yellow
            print(f"{self.YELLOW}[{timestamp}] {client_id}{self.RESET}")
            print(f"{self.YELLOW}  → {text}{self.RESET}\n")
        else:
            # Regular transcription
            print(f"{self.BLUE}[{timestamp}] {client_id}{self.RESET}")
            print(f"  {text}\n")

        # Flush to ensure immediate display
        sys.stdout.flush()

    @staticmethod
    def _format_client_id(remote_address) -> str:
        """Return a ``host:port`` label for a peer address.

        Typically ``remote_address`` is a ``(host, port, ...)`` tuple; extra
        IPv6 fields are ignored. It can be empty or ``None`` (e.g. on non-TCP
        transports), in which case ``"unknown"`` is returned instead of raising.
        """
        if not remote_address:
            return "unknown"
        if isinstance(remote_address, (tuple, list)) and len(remote_address) >= 2:
            return f"{remote_address[0]}:{remote_address[1]}"
        return str(remote_address)

    @staticmethod
    def _decode_pcm16(payload: bytes) -> np.ndarray:
        """Convert a frame of 16-bit PCM samples to float32 in [-1.0, 1.0).

        Samples are read in the host's native byte order (``np.int16``).

        Raises:
            ValueError: If the payload length is not a whole number of samples.
        """
        if len(payload) % 2:
            raise ValueError(
                f"Audio frame has odd length ({len(payload)} bytes); "
                "expected 16-bit PCM samples"
            )
        return np.frombuffer(payload, dtype=np.int16).astype(np.float32) / 32768.0

    async def _transcribe(self, audio: np.ndarray):
        """Run the blocking pipeline on the inference thread and return its result."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            self._executor,
            lambda: self.pipeline.transcribe(audio, sample_rate=self.sample_rate),
        )

    async def _emit_transcript(self, websocket, client_id: str, text, is_final: bool):
        """Display a transcript on the console and send it to the client.

        Falsy or whitespace-only results are ignored: nothing is shown or sent.
        """
        if not (text and text.strip()):
            return
        self.display_transcription(
            client_id, text, is_final=is_final, is_progressive=not is_final
        )
        await websocket.send(json.dumps({
            "type": "transcript",
            "text": text,
            "is_final": is_final,
        }))

    async def _send_error(self, websocket, message: str):
        """Send a ``{"type": "error"}`` message to the client."""
        await websocket.send(json.dumps({
            "type": "error",
            "message": message,
        }))

    async def handle_client(self, websocket):
        """
        Handle individual WebSocket client connection.

        Binary frames are appended to a per-connection audio buffer. Whenever at
        least ``progressive_interval`` seconds of new audio have arrived, the
        whole buffer is transcribed and sent as a non-final transcript.

        Text frames are parsed as JSON commands:

        - ``{"type": "final"}`` transcribes the whole buffer, sends a final
          transcript and clears the buffer.
        - ``{"type": "reset"}`` clears the buffer.

        Malformed frames and transcription failures are reported to the client
        as ``{"type": "error"}`` messages rather than dropping the connection.

        Args:
            websocket: WebSocket connection
        """
        client_id = self._format_client_id(websocket.remote_address)
        logger.info(f"Client connected: {client_id}")
        self.active_connections.add(websocket)

        # Audio buffer for accumulating ALL audio since the last final/reset
        all_audio = []
        total_samples = 0  # running length of all_audio (avoids re-summing per frame)
        last_transcribed_samples = 0

        # For progressive transcription
        min_chunk_samples = int(self.sample_rate * self.progressive_interval)

        try:
            # Display connection (inside try so the finally always discards the socket)
            print(f"\n{self.BOLD}{'='*80}{self.RESET}")
            print(f"{self.GREEN}✓ Client connected: {client_id}{self.RESET}")
            print(f"{self.BOLD}{'='*80}{self.RESET}\n")
            sys.stdout.flush()

            # Send welcome message
            await websocket.send(json.dumps({
                "type": "info",
                "message": "Connected to ASR server with live display",
                "sample_rate": self.sample_rate,
            }))

            async for message in websocket:
                try:
                    if isinstance(message, bytes):
                        # Binary audio data
                        try:
                            audio_data = self._decode_pcm16(message)
                        except ValueError as e:
                            logger.warning(f"Bad audio frame from {client_id}: {e}")
                            await self._send_error(websocket, str(e))
                            continue

                        all_audio.append(audio_data)
                        total_samples += len(audio_data)

                        # Transcribe periodically when we have enough NEW audio.
                        # NOTE: each pass re-transcribes the whole buffer, so cost
                        # grows with the length of the current utterance.
                        if total_samples - last_transcribed_samples >= min_chunk_samples:
                            audio_chunk = np.concatenate(all_audio)
                            last_transcribed_samples = total_samples
                            try:
                                text = await self._transcribe(audio_chunk)
                            except Exception as e:
                                logger.exception(f"Transcription error: {e}")
                                await self._send_error(websocket, f"Transcription failed: {str(e)}")
                            else:
                                await self._emit_transcript(websocket, client_id, text, is_final=False)

                    elif isinstance(message, str):
                        # JSON command
                        try:
                            command = json.loads(message)
                        except json.JSONDecodeError:
                            logger.warning(f"Invalid JSON from {client_id}: {message}")
                            continue
                        if not isinstance(command, dict):
                            logger.warning(f"Ignoring non-object command from {client_id}: {message}")
                            continue

                        command_type = command.get("type")
                        if command_type == "final":
                            # Process all accumulated audio (final transcription)
                            if all_audio:
                                audio_chunk = np.concatenate(all_audio)
                                try:
                                    text = await self._transcribe(audio_chunk)
                                except Exception as e:
                                    logger.exception(f"Final transcription error for {client_id}: {e}")
                                    await self._send_error(websocket, f"Transcription failed: {str(e)}")
                                else:
                                    await self._emit_transcript(websocket, client_id, text, is_final=True)

                            # Clear the buffer after a final request, even if it
                            # failed, so a bad utterance is not re-transcribed forever
                            all_audio = []
                            total_samples = 0
                            last_transcribed_samples = 0

                        elif command_type == "reset":
                            # Reset buffer
                            all_audio = []
                            total_samples = 0
                            last_transcribed_samples = 0
                            await websocket.send(json.dumps({
                                "type": "info",
                                "message": "Buffer reset"
                            }))
                            print(f"{self.YELLOW}[{client_id}] Buffer reset{self.RESET}\n")
                            sys.stdout.flush()

                        else:
                            logger.warning(f"Unknown command from {client_id}: {command_type!r}")

                except websockets.exceptions.ConnectionClosed:
                    # Let the outer handler treat a closed socket as a normal disconnect
                    raise
                except Exception as e:
                    logger.exception(f"Error processing message from {client_id}: {e}")
                    break

        except websockets.exceptions.ConnectionClosed:
            # The finally block below logs the close at INFO; avoid a duplicate line
            logger.debug(f"Connection closed by peer: {client_id}")
        except Exception as e:
            logger.exception(f"Unexpected error with {client_id}: {e}")
        finally:
            self.active_connections.discard(websocket)
            print(f"\n{self.BOLD}{'='*80}{self.RESET}")
            print(f"{self.YELLOW}✗ Client disconnected: {client_id}{self.RESET}")
            print(f"{self.BOLD}{'='*80}{self.RESET}\n")
            sys.stdout.flush()
            logger.info(f"Connection closed: {client_id}")

    async def start(self):
        """Start the WebSocket server and serve until the task is cancelled."""
        self.print_header()

        try:
            async with websockets.serve(self.handle_client, self.host, self.port):
                logger.info(f"Starting WebSocket server on {self.host}:{self.port}")
                print(f"{self.GREEN}{self.BOLD}Server is running and ready for connections!{self.RESET}")
                print(f"{self.BOLD}Waiting for clients...{self.RESET}\n")
                sys.stdout.flush()

                # Keep server running
                await asyncio.Future()
        finally:
            # Release the inference worker thread once the server stops
            self._executor.shutdown(wait=False)
|
|
|
|
|
|
def main():
    """
    Command-line entry point.

    Reads ``--host``, ``--port``, ``--model-path`` and ``--sample-rate``,
    constructs a ``DisplayServer`` from them, and runs its ``start()``
    coroutine with ``asyncio.run`` until the user presses Ctrl+C.
    """
    import argparse

    cli = argparse.ArgumentParser(description="ASR Server with Live Display")
    cli.add_argument("--host", default="0.0.0.0", help="Host address")
    cli.add_argument("--port", type=int, default=8766, help="Port number")
    cli.add_argument("--model-path", default="models/parakeet", help="Model directory")
    cli.add_argument("--sample-rate", type=int, default=16000, help="Sample rate")
    opts = cli.parse_args()

    display_server = DisplayServer(
        host=opts.host,
        port=opts.port,
        model_path=opts.model_path,
        sample_rate=opts.sample_rate,
    )

    try:
        asyncio.run(display_server.start())
    except KeyboardInterrupt:
        # Ctrl+C ends the run with a short notice instead of a traceback.
        stop_notice = f"\n\n{display_server.YELLOW}Server stopped by user{display_server.RESET}"
        print(stop_notice)
        logger.info("Server stopped by user")


if __name__ == "__main__":
    # Launch the server only when executed as a script, not on import.
    main()
|