Decided on Parakeet with ONNX Runtime. Works pretty well. Realtime voice chat is possible now; UX still lacking.
stt-parakeet/server/ws_server.py (new file, 231 lines added)
@@ -0,0 +1,231 @@
"""
WebSocket server for streaming ASR using onnx-asr
"""
import asyncio
import websockets
import numpy as np
import json
import logging
from asr.asr_pipeline import ASRPipeline
from typing import Optional

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


class ASRWebSocketServer:
    """
    WebSocket server for real-time speech recognition.
    """

    def __init__(
        self,
        host: str = "0.0.0.0",
        port: int = 8766,
        model_name: str = "nemo-parakeet-tdt-0.6b-v3",
        model_path: Optional[str] = None,
        use_vad: bool = False,
        sample_rate: int = 16000,
    ):
        """
        Initialize WebSocket server.

        Args:
            host: Server host address
            port: Server port
            model_name: ASR model name
            model_path: Optional local model path
            use_vad: Whether to use VAD
            sample_rate: Expected audio sample rate
        """
        self.host = host
        self.port = port
        self.sample_rate = sample_rate

        logger.info("Initializing ASR Pipeline...")
        self.pipeline = ASRPipeline(
            model_name=model_name,
            model_path=model_path,
            use_vad=use_vad,
        )
        logger.info("ASR Pipeline ready")

        self.active_connections = set()

    async def handle_client(self, websocket):
        """
        Handle individual WebSocket client connection.

        Args:
            websocket: WebSocket connection
        """
        client_id = f"{websocket.remote_address[0]}:{websocket.remote_address[1]}"
        logger.info(f"Client connected: {client_id}")
        self.active_connections.add(websocket)

        # Audio buffer for accumulating ALL audio
        all_audio = []
        last_transcribed_samples = 0  # Track what we've already transcribed

        # For progressive transcription, we'll accumulate and transcribe the full buffer.
        # This gives better results than processing tiny chunks.
        min_chunk_duration = 2.0  # Minimum 2 seconds of new audio before transcribing
        min_chunk_samples = int(self.sample_rate * min_chunk_duration)

        try:
            # Send welcome message
            await websocket.send(json.dumps({
                "type": "info",
                "message": "Connected to ASR server",
                "sample_rate": self.sample_rate,
            }))

            async for message in websocket:
                try:
                    if isinstance(message, bytes):
                        # Binary audio data:
                        # convert bytes to a float32 numpy array,
                        # assuming int16 PCM data
                        audio_data = np.frombuffer(message, dtype=np.int16)
                        audio_data = audio_data.astype(np.float32) / 32768.0

                        # Accumulate all audio
                        all_audio.append(audio_data)
                        total_samples = sum(len(chunk) for chunk in all_audio)

                        # Transcribe periodically when we have enough NEW audio
                        samples_since_last = total_samples - last_transcribed_samples
                        if samples_since_last >= min_chunk_samples:
                            audio_chunk = np.concatenate(all_audio)
                            last_transcribed_samples = total_samples

                            # Transcribe the accumulated audio
                            try:
                                text = self.pipeline.transcribe(
                                    audio_chunk,
                                    sample_rate=self.sample_rate
                                )

                                if text and text.strip():
                                    response = {
                                        "type": "transcript",
                                        "text": text,
                                        "is_final": False,
                                    }
                                    await websocket.send(json.dumps(response))
                                    logger.info(f"Progressive transcription: {text}")
                            except Exception as e:
                                logger.error(f"Transcription error: {e}")
                                await websocket.send(json.dumps({
                                    "type": "error",
                                    "message": f"Transcription failed: {str(e)}"
                                }))

                    elif isinstance(message, str):
                        # JSON command
                        try:
                            command = json.loads(message)

                            if command.get("type") == "final":
                                # Process all accumulated audio (final transcription)
                                if all_audio:
                                    audio_chunk = np.concatenate(all_audio)

                                    text = self.pipeline.transcribe(
                                        audio_chunk,
                                        sample_rate=self.sample_rate
                                    )

                                    if text and text.strip():
                                        response = {
                                            "type": "transcript",
                                            "text": text,
                                            "is_final": True,
                                        }
                                        await websocket.send(json.dumps(response))
                                        logger.info(f"Final transcription: {text}")

                                # Clear buffer after final transcription
                                all_audio = []
                                last_transcribed_samples = 0

                            elif command.get("type") == "reset":
                                # Reset buffer
                                all_audio = []
                                last_transcribed_samples = 0
                                await websocket.send(json.dumps({
                                    "type": "info",
                                    "message": "Buffer reset"
                                }))

                        except json.JSONDecodeError:
                            logger.warning(f"Invalid JSON command: {message}")

                except Exception as e:
                    logger.error(f"Error processing message: {e}")
                    await websocket.send(json.dumps({
                        "type": "error",
                        "message": str(e)
                    }))

        except websockets.exceptions.ConnectionClosed:
            logger.info(f"Client disconnected: {client_id}")

        finally:
            self.active_connections.discard(websocket)
            logger.info(f"Connection closed: {client_id}")

    async def start(self):
        """
        Start the WebSocket server.
        """
        logger.info(f"Starting WebSocket server on {self.host}:{self.port}")

        async with websockets.serve(self.handle_client, self.host, self.port):
            logger.info(f"Server running on ws://{self.host}:{self.port}")
            logger.info(f"Active connections: {len(self.active_connections)}")
            await asyncio.Future()  # Run forever

    def run(self):
        """
        Run the server (blocking).
        """
        try:
            asyncio.run(self.start())
        except KeyboardInterrupt:
            logger.info("Server stopped by user")


def main():
    """
    Main entry point for the WebSocket server.
    """
    import argparse

    parser = argparse.ArgumentParser(description="ASR WebSocket Server")
    parser.add_argument("--host", default="0.0.0.0", help="Server host")
    parser.add_argument("--port", type=int, default=8766, help="Server port")
    parser.add_argument("--model", default="nemo-parakeet-tdt-0.6b-v3", help="Model name")
    parser.add_argument("--model-path", default=None, help="Local model path")
    parser.add_argument("--use-vad", action="store_true", help="Enable VAD")
    parser.add_argument("--sample-rate", type=int, default=16000, help="Audio sample rate")

    args = parser.parse_args()

    server = ASRWebSocketServer(
        host=args.host,
        port=args.port,
        model_name=args.model,
        model_path=args.model_path,
        use_vad=args.use_vad,
        sample_rate=args.sample_rate,
    )

    server.run()


if __name__ == "__main__":
    main()
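
Not part of this commit, but for reference: a minimal sketch of a test client for the protocol above (raw int16 PCM frames in, JSON "transcript" messages out). The URL, the test.wav path, and the 0.5-second chunk size are assumptions; the input file is assumed to be 16 kHz mono 16-bit PCM, matching the server's default sample rate. The server itself is started by running ws_server.py directly (see main() above).

# test_client.py (sketch, hypothetical helper, not in this commit)
"""
Minimal test client for the ASR WebSocket server above.
Assumes the server runs on localhost:8766 and that test.wav
(hypothetical path) is 16 kHz mono 16-bit PCM audio.
"""
import asyncio
import json
import wave

import websockets


async def stream_file(path: str = "test.wav", url: str = "ws://localhost:8766"):
    async with websockets.connect(url) as ws:
        # The server greets new clients with an "info" message
        print(await ws.recv())

        with wave.open(path, "rb") as wav:
            chunk_frames = 8000  # ~0.5 s of 16 kHz audio per message
            while True:
                frames = wav.readframes(chunk_frames)
                if not frames:
                    break
                # Raw int16 PCM bytes, exactly what handle_client expects
                await ws.send(frames)

        # Request the final transcription, then drain queued responses
        # (any progressive transcripts arrive first, the final one last)
        await ws.send(json.dumps({"type": "final"}))
        while True:
            msg = json.loads(await asyncio.wait_for(ws.recv(), timeout=30))
            print(msg)
            if msg.get("type") == "transcript" and msg.get("is_final"):
                break


if __name__ == "__main__":
    asyncio.run(stream_file())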