Settled on Parakeet with ONNX Runtime, and it works well: real-time voice chat is now possible, though the UX still needs work.
This commit is contained in:
6
stt-parakeet/server/__init__.py
Normal file
6
stt-parakeet/server/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""
WebSocket server module for streaming ASR.

Re-exports :class:`ASRWebSocketServer` as the package's public entry point.
"""
from .ws_server import ASRWebSocketServer

# Explicit public API of the server package.
__all__ = ["ASRWebSocketServer"]
|
||||
292
stt-parakeet/server/display_server.py
Normal file
292
stt-parakeet/server/display_server.py
Normal file
@@ -0,0 +1,292 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
ASR WebSocket Server with Live Transcription Display
|
||||
|
||||
This version displays transcriptions in real-time on the server console
|
||||
while clients stream audio from remote machines.
|
||||
"""
|
||||
import asyncio
|
||||
import websockets
|
||||
import numpy as np
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from asr.asr_pipeline import ASRPipeline
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        # Persist a copy of the log to disk so sessions can be reviewed
        # after the server exits, in addition to the console stream.
        logging.FileHandler('display_server.log'),
        logging.StreamHandler()
    ]
)
# Module-level logger shared by everything in this file.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DisplayServer:
    """
    WebSocket server with live transcription display.

    Clients stream int16 PCM audio as binary WebSocket frames; JSON text
    frames carry control commands ("final" to flush the buffered utterance
    through a final transcription, "reset" to discard buffered audio).
    Progressive transcriptions are echoed both to the client and to the
    server console.
    """

    def __init__(
        self,
        host: str = "0.0.0.0",
        port: int = 8766,
        model_path: str = "models/parakeet",
        sample_rate: int = 16000,
    ):
        """
        Initialize server.

        Args:
            host: Host address to bind to
            port: Port to bind to
            model_path: Directory containing model files
            sample_rate: Audio sample rate in Hz (clients must send matching audio)
        """
        self.host = host
        self.port = port
        self.sample_rate = sample_rate
        self.active_connections = set()

        # ANSI terminal control codes used for the live console display.
        self.CLEAR_LINE = '\033[2K'
        self.CURSOR_UP = '\033[1A'
        self.BOLD = '\033[1m'
        self.GREEN = '\033[92m'
        self.YELLOW = '\033[93m'
        self.BLUE = '\033[94m'
        self.RESET = '\033[0m'

        # Load the ASR model once; the pipeline is shared by all clients.
        logger.info("Loading ASR model...")
        self.pipeline = ASRPipeline(model_path=model_path)
        logger.info("ASR Pipeline ready")

        # Client sessions (not populated anywhere in this file yet;
        # kept for future per-client state).
        self.sessions = {}

    def print_header(self):
        """Print the server banner with connection details."""
        print("\n" + "=" * 80)
        print(f"{self.BOLD}{self.BLUE}ASR Server - Live Transcription Display{self.RESET}")
        print("=" * 80)
        print(f"Server: ws://{self.host}:{self.port}")
        print(f"Sample Rate: {self.sample_rate} Hz")
        # Plain string: the previous f-string had no placeholders (F541).
        print("Model: Parakeet TDT 0.6B V3")
        print("=" * 80 + "\n")

    def display_transcription(self, client_id: str, text: str, is_final: bool, is_progressive: bool = False):
        """
        Display transcription in the terminal.

        Args:
            client_id: Client identifier
            text: Transcribed text
            is_final: Whether this is the final transcription
            is_progressive: Whether this is a progressive update
        """
        timestamp = datetime.now().strftime("%H:%M:%S")

        if is_final:
            # Final transcription - bold green
            print(f"{self.GREEN}{self.BOLD}[{timestamp}] {client_id}{self.RESET}")
            print(f"{self.GREEN} ✓ FINAL: {text}{self.RESET}\n")
        elif is_progressive:
            # Progressive update - yellow
            print(f"{self.YELLOW}[{timestamp}] {client_id}{self.RESET}")
            print(f"{self.YELLOW} → {text}{self.RESET}\n")
        else:
            # Regular transcription
            print(f"{self.BLUE}[{timestamp}] {client_id}{self.RESET}")
            print(f" {text}\n")

        # Flush to ensure immediate display
        sys.stdout.flush()

    async def handle_client(self, websocket):
        """
        Handle individual WebSocket client connection.

        Binary frames are int16 PCM audio; text frames are JSON commands.
        Audio is accumulated for the whole utterance and the FULL buffer is
        re-transcribed whenever enough new audio has arrived — this trades
        latency for accuracy on short chunks.

        Args:
            websocket: WebSocket connection
        """
        client_id = f"{websocket.remote_address[0]}:{websocket.remote_address[1]}"
        logger.info(f"Client connected: {client_id}")
        self.active_connections.add(websocket)

        # Display connection
        print(f"\n{self.BOLD}{'='*80}{self.RESET}")
        print(f"{self.GREEN}✓ Client connected: {client_id}{self.RESET}")
        print(f"{self.BOLD}{'='*80}{self.RESET}\n")
        sys.stdout.flush()

        # Audio buffer for accumulating ALL audio of the current utterance.
        all_audio = []
        last_transcribed_samples = 0

        # Progressive transcription: wait for a minimum amount of NEW audio
        # before re-running the (relatively expensive) pipeline.
        min_chunk_duration = 2.0  # Minimum 2 seconds before transcribing
        min_chunk_samples = int(self.sample_rate * min_chunk_duration)

        try:
            # Send welcome message
            await websocket.send(json.dumps({
                "type": "info",
                "message": "Connected to ASR server with live display",
                "sample_rate": self.sample_rate,
            }))

            async for message in websocket:
                try:
                    if isinstance(message, bytes):
                        # Binary audio data: int16 PCM -> float32 in [-1, 1).
                        audio_data = np.frombuffer(message, dtype=np.int16)
                        audio_data = audio_data.astype(np.float32) / 32768.0

                        # Accumulate all audio
                        all_audio.append(audio_data)
                        total_samples = sum(len(chunk) for chunk in all_audio)

                        # Transcribe periodically when we have enough NEW audio
                        samples_since_last = total_samples - last_transcribed_samples
                        if samples_since_last >= min_chunk_samples:
                            audio_chunk = np.concatenate(all_audio)
                            last_transcribed_samples = total_samples

                            # Transcribe the accumulated audio
                            try:
                                text = self.pipeline.transcribe(
                                    audio_chunk,
                                    sample_rate=self.sample_rate
                                )

                                if text and text.strip():
                                    # Display on server
                                    self.display_transcription(client_id, text, is_final=False, is_progressive=True)

                                    # Send to client
                                    response = {
                                        "type": "transcript",
                                        "text": text,
                                        "is_final": False,
                                    }
                                    await websocket.send(json.dumps(response))
                            except Exception as e:
                                logger.error(f"Transcription error: {e}")
                                await websocket.send(json.dumps({
                                    "type": "error",
                                    "message": f"Transcription failed: {str(e)}"
                                }))

                    elif isinstance(message, str):
                        # JSON command
                        try:
                            command = json.loads(message)

                            if command.get("type") == "final":
                                # Process all accumulated audio (final transcription)
                                if all_audio:
                                    audio_chunk = np.concatenate(all_audio)

                                    text = self.pipeline.transcribe(
                                        audio_chunk,
                                        sample_rate=self.sample_rate
                                    )

                                    if text and text.strip():
                                        # Display on server
                                        self.display_transcription(client_id, text, is_final=True)

                                        # Send to client
                                        response = {
                                            "type": "transcript",
                                            "text": text,
                                            "is_final": True,
                                        }
                                        await websocket.send(json.dumps(response))

                                    # Clear the utterance buffer after a "final"
                                    # command even when the transcription came
                                    # back empty, so stale audio never leaks
                                    # into the next utterance.
                                    all_audio = []
                                    last_transcribed_samples = 0

                            elif command.get("type") == "reset":
                                # Reset buffer
                                all_audio = []
                                last_transcribed_samples = 0
                                await websocket.send(json.dumps({
                                    "type": "info",
                                    "message": "Buffer reset"
                                }))
                                print(f"{self.YELLOW}[{client_id}] Buffer reset{self.RESET}\n")
                                sys.stdout.flush()

                        except json.JSONDecodeError:
                            logger.warning(f"Invalid JSON from {client_id}: {message}")

                except Exception as e:
                    # Unexpected per-message failure ends the session; cleanup
                    # happens in the finally block below.
                    logger.error(f"Error processing message from {client_id}: {e}")
                    break

        except websockets.exceptions.ConnectionClosed:
            logger.info(f"Connection closed: {client_id}")
        except Exception as e:
            logger.error(f"Unexpected error with {client_id}: {e}")
        finally:
            self.active_connections.discard(websocket)
            print(f"\n{self.BOLD}{'='*80}{self.RESET}")
            print(f"{self.YELLOW}✗ Client disconnected: {client_id}{self.RESET}")
            print(f"{self.BOLD}{'='*80}{self.RESET}\n")
            sys.stdout.flush()
            logger.info(f"Connection closed: {client_id}")

    async def start(self):
        """Start the WebSocket server and serve until cancelled."""
        self.print_header()

        async with websockets.serve(self.handle_client, self.host, self.port):
            logger.info(f"Starting WebSocket server on {self.host}:{self.port}")
            print(f"{self.GREEN}{self.BOLD}Server is running and ready for connections!{self.RESET}")
            print(f"{self.BOLD}Waiting for clients...{self.RESET}\n")
            sys.stdout.flush()

            # Block forever; cancellation (e.g. Ctrl-C) tears the server down.
            await asyncio.Future()
|
||||
|
||||
|
||||
def main():
    """Command-line entry point for the live-display ASR server."""
    import argparse

    parser = argparse.ArgumentParser(description="ASR Server with Live Display")
    parser.add_argument("--host", default="0.0.0.0", help="Host address")
    parser.add_argument("--port", type=int, default=8766, help="Port number")
    parser.add_argument("--model-path", default="models/parakeet", help="Model directory")
    parser.add_argument("--sample-rate", type=int, default=16000, help="Sample rate")

    cli = parser.parse_args()

    # The argparse destination names (host, port, model_path, sample_rate)
    # match DisplayServer's constructor parameters one-to-one.
    server = DisplayServer(**vars(cli))

    try:
        asyncio.run(server.start())
    except KeyboardInterrupt:
        print(f"\n\n{server.YELLOW}Server stopped by user{server.RESET}")
        logger.info("Server stopped by user")


if __name__ == "__main__":
    main()
|
||||
416
stt-parakeet/server/vad_server.py
Normal file
416
stt-parakeet/server/vad_server.py
Normal file
@@ -0,0 +1,416 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
ASR WebSocket Server with VAD - Optimized for Discord Bots
|
||||
|
||||
This server uses Voice Activity Detection (VAD) to:
|
||||
- Detect speech start and end automatically
|
||||
- Only transcribe speech segments (ignore silence)
|
||||
- Provide clean boundaries for Discord message formatting
|
||||
- Minimize processing of silence/noise
|
||||
"""
|
||||
import asyncio
|
||||
import websockets
|
||||
import numpy as np
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
# Add project root to path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from asr.asr_pipeline import ASRPipeline
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        # Persist a copy of the log to disk so sessions can be reviewed
        # after the server exits, in addition to the console stream.
        logging.FileHandler('vad_server.log'),
        logging.StreamHandler()
    ]
)
# Module-level logger shared by everything in this file.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class SpeechSegment:
    """Represents a segment of detected speech."""
    # Audio samples of the segment.  NOTE(review): the rest of this module
    # works with float32 mono PCM, so that is presumably the expected
    # content — this class is never constructed in this file; confirm
    # against callers.
    audio: np.ndarray
    # Segment boundaries; presumably seconds — TODO confirm units with callers.
    start_time: float
    end_time: Optional[float] = None
    # True once the segment's end has been detected.
    is_complete: bool = False
|
||||
|
||||
|
||||
class VADState:
    """Manages VAD state for speech detection.

    A simple energy-based VAD: each incoming chunk is classified as speech
    when its RMS energy exceeds ``energy_threshold``.  Hysteresis counters
    (``min_speech_frames`` / ``min_silence_frames``) debounce transitions,
    and a short pre-buffer keeps the audio that immediately precedes a
    detected speech onset so the first syllable is not clipped.
    """

    def __init__(self, sample_rate: int = 16000, speech_threshold: float = 0.5):
        """
        Args:
            sample_rate: Audio sample rate in Hz.
            speech_threshold: Accepted for interface compatibility but
                currently UNUSED — detection relies on ``energy_threshold``
                below.  Callers adjusting sensitivity should set
                ``energy_threshold`` (see the "set_threshold" command).
        """
        self.sample_rate = sample_rate
        # Instance logger keeps this class self-contained.
        self._log = logging.getLogger(__name__)

        # Simple energy-based VAD parameters
        self.energy_threshold = 0.005  # Lower threshold for better detection
        self.speech_frames = 0
        self.silence_frames = 0
        self.min_speech_frames = 3  # 3 frames minimum (300ms with 100ms chunks)
        self.min_silence_frames = 5  # 5 frames of silence (500ms)

        self.is_speech = False
        self.speech_buffer = []

        # Pre-buffer to capture audio BEFORE speech detection.
        # This prevents cutting off the start of speech.
        self.pre_buffer_frames = 5  # Keep 5 frames (500ms) of pre-speech audio
        self.pre_buffer = deque(maxlen=self.pre_buffer_frames)

        # Progressive transcription tracking
        self.last_partial_samples = 0  # Samples already covered by a partial
        self.partial_interval_samples = int(sample_rate * 0.3)  # Partial every 0.3 seconds (near real-time)

        self._log.info(f"VAD initialized: energy_threshold={self.energy_threshold}, pre_buffer={self.pre_buffer_frames} frames")

    def calculate_energy(self, audio_chunk: np.ndarray) -> float:
        """Return the RMS energy of *audio_chunk* (0.0 for an empty chunk)."""
        if audio_chunk.size == 0:
            # np.mean of an empty array is nan (and emits a RuntimeWarning);
            # an empty chunk carries no energy.
            return 0.0
        return float(np.sqrt(np.mean(audio_chunk ** 2)))

    def process_audio(self, audio_chunk: np.ndarray) -> tuple[bool, Optional[np.ndarray], Optional[np.ndarray]]:
        """
        Process audio chunk and detect speech boundaries.

        Returns:
            (speech_detected, complete_segment, partial_segment)
            - speech_detected: True if currently in speech
            - complete_segment: Audio segment if speech ended, None otherwise
            - partial_segment: Audio for partial transcription, None otherwise
        """
        energy = self.calculate_energy(audio_chunk)
        chunk_is_speech = energy > self.energy_threshold

        self._log.debug(f"Energy: {energy:.6f}, Is speech: {chunk_is_speech}")

        partial_segment = None

        if chunk_is_speech:
            self.speech_frames += 1
            self.silence_frames = 0

            if not self.is_speech and self.speech_frames >= self.min_speech_frames:
                # Speech started - add pre-buffer to capture the beginning!
                self.is_speech = True
                self._log.info("🎤 Speech started (including pre-buffer)")

                # Add pre-buffered audio to speech buffer
                if self.pre_buffer:
                    self._log.debug(f"Adding {len(self.pre_buffer)} pre-buffered frames")
                    self.speech_buffer.extend(list(self.pre_buffer))
                    self.pre_buffer.clear()

            if self.is_speech:
                self.speech_buffer.append(audio_chunk)
            else:
                # Not in speech yet, keep in pre-buffer
                self.pre_buffer.append(audio_chunk)

            # Check if we should send a partial transcription
            current_samples = sum(len(chunk) for chunk in self.speech_buffer)
            samples_since_last_partial = current_samples - self.last_partial_samples

            # Send partial if enough NEW audio accumulated AND we have minimum duration
            min_duration_for_partial = int(self.sample_rate * 0.8)  # At least 0.8s of audio
            if samples_since_last_partial >= self.partial_interval_samples and current_samples >= min_duration_for_partial:
                # Time for a partial update
                partial_segment = np.concatenate(self.speech_buffer)
                self.last_partial_samples = current_samples
                self._log.debug(f"📝 Partial update: {current_samples/self.sample_rate:.2f}s")
        else:
            if self.is_speech:
                self.silence_frames += 1

                # Add some trailing silence (up to limit)
                if self.silence_frames < self.min_silence_frames:
                    self.speech_buffer.append(audio_chunk)
                else:
                    # Speech ended
                    self._log.info(f"🛑 Speech ended after {self.silence_frames} silence frames")
                    self.is_speech = False
                    self.speech_frames = 0
                    self.silence_frames = 0
                    self.last_partial_samples = 0  # Reset partial counter

                    if self.speech_buffer:
                        complete_segment = np.concatenate(self.speech_buffer)
                        segment_duration = len(complete_segment) / self.sample_rate
                        self.speech_buffer = []
                        self.pre_buffer.clear()  # Clear pre-buffer after speech ends
                        self._log.info(f"✅ Complete segment: {segment_duration:.2f}s")
                        return False, complete_segment, None
            else:
                self.speech_frames = 0
                # Keep adding to pre-buffer when not in speech
                self.pre_buffer.append(audio_chunk)

        return self.is_speech, None, partial_segment
|
||||
|
||||
|
||||
class VADServer:
    """
    WebSocket server with VAD for Discord bot integration.

    Each client gets its own :class:`VADState`.  Binary frames are int16
    PCM audio run through the VAD; JSON text frames carry control commands
    ("force_transcribe", "reset", "set_threshold").
    """

    def __init__(
        self,
        host: str = "0.0.0.0",
        port: int = 8766,
        model_path: str = "models/parakeet",
        sample_rate: int = 16000,
    ):
        """Initialize server.

        Args:
            host: Host address to bind to.
            port: Port to bind to.
            model_path: Directory containing model files.
            sample_rate: Audio sample rate in Hz (clients must send matching audio).
        """
        self.host = host
        self.port = port
        self.sample_rate = sample_rate
        self.active_connections = set()

        # ANSI terminal control codes used for the console display.
        self.BOLD = '\033[1m'
        self.GREEN = '\033[92m'
        self.YELLOW = '\033[93m'
        self.BLUE = '\033[94m'
        self.RED = '\033[91m'
        self.RESET = '\033[0m'

        # Load the ASR model once; the pipeline is shared by all clients.
        logger.info("Loading ASR model...")
        self.pipeline = ASRPipeline(model_path=model_path)
        logger.info("ASR Pipeline ready")

    def print_header(self):
        """Print the server banner with connection details."""
        print("\n" + "=" * 80)
        print(f"{self.BOLD}{self.BLUE}ASR Server with VAD - Discord Bot Ready{self.RESET}")
        print("=" * 80)
        print(f"Server: ws://{self.host}:{self.port}")
        print(f"Sample Rate: {self.sample_rate} Hz")
        # Plain strings: the previous f-strings had no placeholders (F541).
        print("Model: Parakeet TDT 0.6B V3")
        print("VAD: Energy-based speech detection")
        print("=" * 80 + "\n")

    def display_transcription(self, client_id: str, text: str, duration: float):
        """Display a final transcription and its duration in the terminal."""
        timestamp = datetime.now().strftime("%H:%M:%S")
        print(f"{self.GREEN}{self.BOLD}[{timestamp}] {client_id}{self.RESET}")
        print(f"{self.GREEN} 📝 {text}{self.RESET}")
        print(f"{self.BLUE} ⏱️ Duration: {duration:.2f}s{self.RESET}\n")
        sys.stdout.flush()

    async def handle_client(self, websocket):
        """Handle WebSocket client connection.

        Maintains a per-client VADState; emits "vad_status" events on
        speech/silence transitions, partial transcripts while speech is in
        progress, and a final transcript once the VAD detects end of speech.
        """
        client_id = f"{websocket.remote_address[0]}:{websocket.remote_address[1]}"
        logger.info(f"Client connected: {client_id}")
        self.active_connections.add(websocket)

        print(f"\n{self.BOLD}{'='*80}{self.RESET}")
        print(f"{self.GREEN}✓ Client connected: {client_id}{self.RESET}")
        print(f"{self.BOLD}{'='*80}{self.RESET}\n")
        sys.stdout.flush()

        # Initialize VAD state for this client
        vad_state = VADState(sample_rate=self.sample_rate)

        try:
            # Send welcome message
            await websocket.send(json.dumps({
                "type": "info",
                "message": "Connected to ASR server with VAD",
                "sample_rate": self.sample_rate,
                "vad_enabled": True,
            }))

            async for message in websocket:
                try:
                    if isinstance(message, bytes):
                        # Binary audio data: int16 PCM -> float32 in [-1, 1).
                        audio_data = np.frombuffer(message, dtype=np.int16)
                        audio_data = audio_data.astype(np.float32) / 32768.0

                        # Process through VAD
                        is_speech, complete_segment, partial_segment = vad_state.process_audio(audio_data)

                        # Send VAD status to client (only on state change).
                        # The previous state is stashed on the VADState object
                        # itself so a "reset" (fresh VADState) resets it too.
                        prev_speech_state = getattr(vad_state, '_prev_speech_state', False)
                        if is_speech != prev_speech_state:
                            vad_state._prev_speech_state = is_speech
                            await websocket.send(json.dumps({
                                "type": "vad_status",
                                "is_speech": is_speech,
                            }))

                        # Handle partial transcription (progressive updates while speaking)
                        if partial_segment is not None:
                            try:
                                text = self.pipeline.transcribe(
                                    partial_segment,
                                    sample_rate=self.sample_rate
                                )

                                if text and text.strip():
                                    duration = len(partial_segment) / self.sample_rate

                                    # Display on server
                                    timestamp = datetime.now().strftime("%H:%M:%S")
                                    print(f"{self.YELLOW}[{timestamp}] {client_id}{self.RESET}")
                                    print(f"{self.YELLOW} → PARTIAL: {text}{self.RESET}\n")
                                    sys.stdout.flush()

                                    # Send to client
                                    response = {
                                        "type": "transcript",
                                        "text": text,
                                        "is_final": False,
                                        "duration": duration,
                                    }
                                    await websocket.send(json.dumps(response))
                            except Exception as e:
                                # Partials are best-effort; log and keep going.
                                logger.error(f"Partial transcription error: {e}")

                        # If we have a complete speech segment, transcribe it
                        if complete_segment is not None:
                            try:
                                text = self.pipeline.transcribe(
                                    complete_segment,
                                    sample_rate=self.sample_rate
                                )

                                if text and text.strip():
                                    duration = len(complete_segment) / self.sample_rate

                                    # Display on server
                                    self.display_transcription(client_id, text, duration)

                                    # Send to client
                                    response = {
                                        "type": "transcript",
                                        "text": text,
                                        "is_final": True,
                                        "duration": duration,
                                    }
                                    await websocket.send(json.dumps(response))
                            except Exception as e:
                                logger.error(f"Transcription error: {e}")
                                await websocket.send(json.dumps({
                                    "type": "error",
                                    "message": f"Transcription failed: {str(e)}"
                                }))

                    elif isinstance(message, str):
                        # JSON command
                        try:
                            command = json.loads(message)

                            if command.get("type") == "force_transcribe":
                                # Force transcribe current buffer
                                if vad_state.speech_buffer:
                                    audio_chunk = np.concatenate(vad_state.speech_buffer)
                                    vad_state.speech_buffer = []
                                    vad_state.is_speech = False

                                    text = self.pipeline.transcribe(
                                        audio_chunk,
                                        sample_rate=self.sample_rate
                                    )

                                    if text and text.strip():
                                        duration = len(audio_chunk) / self.sample_rate
                                        self.display_transcription(client_id, text, duration)

                                        response = {
                                            "type": "transcript",
                                            "text": text,
                                            "is_final": True,
                                            "duration": duration,
                                        }
                                        await websocket.send(json.dumps(response))

                            elif command.get("type") == "reset":
                                # Reset VAD state
                                vad_state = VADState(sample_rate=self.sample_rate)
                                await websocket.send(json.dumps({
                                    "type": "info",
                                    "message": "VAD state reset"
                                }))
                                print(f"{self.YELLOW}[{client_id}] VAD reset{self.RESET}\n")
                                sys.stdout.flush()

                            elif command.get("type") == "set_threshold":
                                # Adjust VAD threshold
                                threshold = command.get("threshold", 0.01)
                                vad_state.energy_threshold = threshold
                                await websocket.send(json.dumps({
                                    "type": "info",
                                    "message": f"VAD threshold set to {threshold}"
                                }))

                        except json.JSONDecodeError:
                            logger.warning(f"Invalid JSON from {client_id}: {message}")

                except Exception as e:
                    # Unexpected per-message failure ends the session; cleanup
                    # happens in the finally block below.
                    logger.error(f"Error processing message from {client_id}: {e}")
                    break

        except websockets.exceptions.ConnectionClosed:
            logger.info(f"Connection closed: {client_id}")
        except Exception as e:
            logger.error(f"Unexpected error with {client_id}: {e}")
        finally:
            self.active_connections.discard(websocket)
            print(f"\n{self.BOLD}{'='*80}{self.RESET}")
            print(f"{self.YELLOW}✗ Client disconnected: {client_id}{self.RESET}")
            print(f"{self.BOLD}{'='*80}{self.RESET}\n")
            sys.stdout.flush()
            logger.info(f"Connection closed: {client_id}")

    async def start(self):
        """Start the WebSocket server and serve until cancelled."""
        self.print_header()

        async with websockets.serve(self.handle_client, self.host, self.port):
            logger.info(f"Starting WebSocket server on {self.host}:{self.port}")
            print(f"{self.GREEN}{self.BOLD}Server is running with VAD enabled!{self.RESET}")
            print(f"{self.BOLD}Ready for Discord bot connections...{self.RESET}\n")
            sys.stdout.flush()

            # Block forever; cancellation (e.g. Ctrl-C) tears the server down.
            await asyncio.Future()
|
||||
|
||||
|
||||
def main():
    """Command-line entry point for the VAD-enabled ASR server."""
    import argparse

    parser = argparse.ArgumentParser(description="ASR Server with VAD for Discord")
    parser.add_argument("--host", default="0.0.0.0", help="Host address")
    parser.add_argument("--port", type=int, default=8766, help="Port number")
    parser.add_argument("--model-path", default="models/parakeet", help="Model directory")
    parser.add_argument("--sample-rate", type=int, default=16000, help="Sample rate")

    cli = parser.parse_args()

    # The argparse destination names (host, port, model_path, sample_rate)
    # match VADServer's constructor parameters one-to-one.
    server = VADServer(**vars(cli))

    try:
        asyncio.run(server.start())
    except KeyboardInterrupt:
        print(f"\n\n{server.YELLOW}Server stopped by user{server.RESET}")
        logger.info("Server stopped by user")


if __name__ == "__main__":
    main()
|
||||
231
stt-parakeet/server/ws_server.py
Normal file
231
stt-parakeet/server/ws_server.py
Normal file
@@ -0,0 +1,231 @@
|
||||
"""
|
||||
WebSocket server for streaming ASR using onnx-asr
|
||||
"""
|
||||
import asyncio
|
||||
import websockets
|
||||
import numpy as np
|
||||
import json
|
||||
import logging
|
||||
from asr.asr_pipeline import ASRPipeline
|
||||
from typing import Optional
|
||||
|
||||
# Console-only logging for this module (the display/VAD servers also log
# to a file; this one does not).
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Module-level logger shared by everything in this file.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ASRWebSocketServer:
    """
    WebSocket server for real-time speech recognition.

    Binary frames are int16 PCM audio; text frames are JSON commands
    ("final" transcribes the whole buffered utterance, "reset" discards it).
    Progressive transcriptions of the growing buffer are pushed to the
    client as non-final transcripts.
    """

    def __init__(
        self,
        host: str = "0.0.0.0",
        port: int = 8766,
        model_name: str = "nemo-parakeet-tdt-0.6b-v3",
        model_path: Optional[str] = None,
        use_vad: bool = False,
        sample_rate: int = 16000,
    ):
        """
        Initialize WebSocket server.

        Args:
            host: Server host address
            port: Server port
            model_name: ASR model name
            model_path: Optional local model path
            use_vad: Whether to use VAD
            sample_rate: Expected audio sample rate
        """
        self.host = host
        self.port = port
        self.sample_rate = sample_rate

        # Load the ASR model once; the pipeline is shared by all clients.
        logger.info("Initializing ASR Pipeline...")
        self.pipeline = ASRPipeline(
            model_name=model_name,
            model_path=model_path,
            use_vad=use_vad,
        )
        logger.info("ASR Pipeline ready")

        self.active_connections = set()

    async def handle_client(self, websocket):
        """
        Handle individual WebSocket client connection.

        Args:
            websocket: WebSocket connection
        """
        client_id = f"{websocket.remote_address[0]}:{websocket.remote_address[1]}"
        logger.info(f"Client connected: {client_id}")
        self.active_connections.add(websocket)

        # Audio buffer for accumulating ALL audio of the current utterance.
        all_audio = []
        last_transcribed_samples = 0  # Track what we've already transcribed

        # For progressive transcription, we accumulate and transcribe the full
        # buffer — this gives better results than processing tiny chunks.
        min_chunk_duration = 2.0  # Minimum 2 seconds before transcribing
        min_chunk_samples = int(self.sample_rate * min_chunk_duration)

        try:
            # Send welcome message
            await websocket.send(json.dumps({
                "type": "info",
                "message": "Connected to ASR server",
                "sample_rate": self.sample_rate,
            }))

            async for message in websocket:
                try:
                    if isinstance(message, bytes):
                        # Binary audio data: int16 PCM -> float32 in [-1, 1).
                        audio_data = np.frombuffer(message, dtype=np.int16)
                        audio_data = audio_data.astype(np.float32) / 32768.0

                        # Accumulate all audio
                        all_audio.append(audio_data)
                        total_samples = sum(len(chunk) for chunk in all_audio)

                        # Transcribe periodically when we have enough NEW audio
                        samples_since_last = total_samples - last_transcribed_samples
                        if samples_since_last >= min_chunk_samples:
                            audio_chunk = np.concatenate(all_audio)
                            last_transcribed_samples = total_samples

                            # Transcribe the accumulated audio
                            try:
                                text = self.pipeline.transcribe(
                                    audio_chunk,
                                    sample_rate=self.sample_rate
                                )

                                if text and text.strip():
                                    response = {
                                        "type": "transcript",
                                        "text": text,
                                        "is_final": False,
                                    }
                                    await websocket.send(json.dumps(response))
                                    logger.info(f"Progressive transcription: {text}")
                            except Exception as e:
                                logger.error(f"Transcription error: {e}")
                                await websocket.send(json.dumps({
                                    "type": "error",
                                    "message": f"Transcription failed: {str(e)}"
                                }))

                    elif isinstance(message, str):
                        # JSON command
                        try:
                            command = json.loads(message)

                            if command.get("type") == "final":
                                # Process all accumulated audio (final transcription)
                                if all_audio:
                                    audio_chunk = np.concatenate(all_audio)

                                    text = self.pipeline.transcribe(
                                        audio_chunk,
                                        sample_rate=self.sample_rate
                                    )

                                    if text and text.strip():
                                        response = {
                                            "type": "transcript",
                                            "text": text,
                                            "is_final": True,
                                        }
                                        await websocket.send(json.dumps(response))
                                        logger.info(f"Final transcription: {text}")

                                    # Clear the utterance buffer after a "final"
                                    # command even when the transcription came
                                    # back empty, so stale audio never leaks
                                    # into the next utterance.
                                    all_audio = []
                                    last_transcribed_samples = 0

                            elif command.get("type") == "reset":
                                # Reset buffer
                                all_audio = []
                                last_transcribed_samples = 0
                                await websocket.send(json.dumps({
                                    "type": "info",
                                    "message": "Buffer reset"
                                }))

                        except json.JSONDecodeError:
                            logger.warning(f"Invalid JSON command: {message}")

                except Exception as e:
                    # Best-effort: report the failure to the client and keep
                    # the session alive for subsequent messages.
                    logger.error(f"Error processing message: {e}")
                    await websocket.send(json.dumps({
                        "type": "error",
                        "message": str(e)
                    }))

        except websockets.exceptions.ConnectionClosed:
            logger.info(f"Client disconnected: {client_id}")
        except Exception as e:
            # Consistent with the other servers in this package: log any
            # unexpected failure instead of letting it escape the handler.
            logger.error(f"Unexpected error with {client_id}: {e}")
        finally:
            self.active_connections.discard(websocket)
            logger.info(f"Connection closed: {client_id}")

    async def start(self):
        """
        Start the WebSocket server and serve until cancelled.
        """
        logger.info(f"Starting WebSocket server on {self.host}:{self.port}")

        async with websockets.serve(self.handle_client, self.host, self.port):
            logger.info(f"Server running on ws://{self.host}:{self.port}")
            logger.info(f"Active connections: {len(self.active_connections)}")
            await asyncio.Future()  # Run forever

    def run(self):
        """
        Run the server (blocking) until Ctrl-C.
        """
        try:
            asyncio.run(self.start())
        except KeyboardInterrupt:
            logger.info("Server stopped by user")
|
||||
|
||||
|
||||
def main():
    """
    Main entry point for the WebSocket server.
    """
    import argparse

    parser = argparse.ArgumentParser(description="ASR WebSocket Server")
    parser.add_argument("--host", default="0.0.0.0", help="Server host")
    parser.add_argument("--port", type=int, default=8766, help="Server port")
    parser.add_argument("--model", default="nemo-parakeet-tdt-0.6b-v3", help="Model name")
    parser.add_argument("--model-path", default=None, help="Local model path")
    parser.add_argument("--use-vad", action="store_true", help="Enable VAD")
    parser.add_argument("--sample-rate", type=int, default=16000, help="Audio sample rate")

    opts = parser.parse_args()

    # Map CLI options onto the constructor's keyword arguments
    # (note: --model feeds model_name) and run until interrupted.
    ASRWebSocketServer(
        host=opts.host,
        port=opts.port,
        model_name=opts.model,
        model_path=opts.model_path,
        use_vad=opts.use_vad,
        sample_rate=opts.sample_rate,
    ).run()


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user