"""
Microphone streaming client for an ASR (automatic speech recognition)
WebSocket server.
"""
|
|
import asyncio
|
|
import websockets
|
|
import sounddevice as sd
|
|
import numpy as np
|
|
import json
|
|
import logging
|
|
import queue
|
|
from typing import Optional
|
|
|
|
# Message layout shared by every record emitted through the root handler.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

# Root logging is configured when this module is imported. Per the stdlib,
# basicConfig() does nothing if the root logger already has handlers.
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger used throughout this client.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class MicrophoneStreamClient:
    """
    Client for streaming microphone audio to an ASR WebSocket server.

    Audio captured through ``sounddevice`` is converted to 16-bit PCM and sent
    to the server as binary WebSocket messages. JSON messages coming back from
    the server (``transcript``, ``info`` and ``error`` types) are logged.
    """

    def __init__(
        self,
        server_url: str = "ws://localhost:8766",
        sample_rate: int = 16000,
        channels: int = 1,
        chunk_duration: float = 0.1,  # seconds
        device: Optional[int] = None,
    ):
        """
        Initialize microphone streaming client.

        Args:
            server_url: WebSocket server URL
            sample_rate: Audio sample rate (16000 Hz recommended)
            channels: Number of channels to open on the input device. Only
                the first channel is sent to the server.
            chunk_duration: Duration of each audio chunk in seconds
            device: Optional audio input device index (None lets sounddevice
                pick its default input device)
        """
        self.server_url = server_url
        self.sample_rate = sample_rate
        self.channels = channels
        self.chunk_duration = chunk_duration
        self.chunk_samples = int(sample_rate * chunk_duration)
        self.device = device

        # int16 PCM byte chunks produced by audio_callback (which sounddevice
        # invokes from its own audio thread) and consumed by send_audio on the
        # event loop; queue.Queue is safe to share between threads.
        self.audio_queue = queue.Queue()
        # Shared stop flag for the send/receive loops.
        self.is_recording = False
        # Active connection while stream_audio() runs, otherwise None.
        self.websocket = None

        logger.info("Microphone client initialized")
        logger.info(f"Server URL: {server_url}")
        logger.info(f"Sample rate: {sample_rate} Hz")
        logger.info(f"Chunk duration: {chunk_duration}s ({self.chunk_samples} samples)")

    def audio_callback(self, indata, frames, time_info, status):
        """
        Callback for the sounddevice input stream.

        Takes the first channel of ``indata`` (float samples, nominally in
        [-1.0, 1.0]), converts it to native-endian int16 PCM bytes and puts
        the result on ``audio_queue``.

        Args:
            indata: Input audio data, indexed as (frame, channel)
            frames: Number of frames
            time_info: Timing information
            status: Status flags (logged as a warning when non-empty)
        """
        if status:
            logger.warning(f"Audio callback status: {status}")

        # Clip before scaling: samples slightly outside [-1.0, 1.0] would
        # otherwise overflow int16 on the cast and wrap to the opposite sign.
        mono = np.clip(indata[:, 0], -1.0, 1.0)
        audio_data = (mono * 32767).astype(np.int16)
        self.audio_queue.put(audio_data.tobytes())

    async def send_audio(self):
        """
        Coroutine that forwards queued audio chunks to the WebSocket.

        Runs while ``is_recording`` is true. The queue is polled without
        blocking so the event loop stays responsive, and the coroutine sleeps
        for 10 ms when the queue is empty. A send failure clears
        ``is_recording``, so receive_transcripts() stops as well.
        """
        while self.is_recording:
            try:
                audio_bytes = self.audio_queue.get_nowait()
            except queue.Empty:
                await asyncio.sleep(0.01)
                continue

            if self.websocket is None:
                continue  # not connected: the chunk is dropped

            try:
                await self.websocket.send(audio_bytes)
            except Exception as e:
                logger.error(f"Error sending audio: {e}")
                self.is_recording = False
                break

    async def receive_transcripts(self):
        """
        Coroutine that receives and logs server messages while recording.

        ``recv()`` is bounded by a 0.1 s timeout so the loop regularly
        re-checks ``is_recording``. A receive failure (for example, the
        connection closing) clears ``is_recording``, so send_audio() stops
        as well.
        """
        while self.is_recording:
            if self.websocket is None:
                # Nothing to read from yet; yield instead of spinning.
                await asyncio.sleep(0.01)
                continue

            try:
                message = await asyncio.wait_for(
                    self.websocket.recv(),
                    timeout=0.1
                )
            except asyncio.TimeoutError:
                continue
            except Exception as e:
                logger.error(f"Error receiving transcript: {e}")
                self.is_recording = False
                break

            self._handle_message(message)

    def _handle_message(self, message):
        """
        Log one server message and return it decoded.

        Expects a JSON object with a ``type`` key. Returns the decoded dict,
        or None if the message is not valid JSON or not a JSON object. Those
        cases are logged as warnings instead of raising, so one malformed
        reply cannot stop the receive loop.
        """
        try:
            data = json.loads(message)
        except ValueError:  # JSONDecodeError, or UnicodeDecodeError for bytes
            logger.warning(f"Invalid JSON response: {message}")
            return None

        if not isinstance(data, dict):
            logger.warning(f"Unexpected response (not a JSON object): {message}")
            return None

        msg_type = data.get("type")
        if msg_type == "transcript":
            text = data.get("text", "")
            if data.get("is_final", False):
                logger.info(f"[FINAL] {text}")
            else:
                logger.info(f"[PARTIAL] {text}")
        elif msg_type == "info":
            logger.info(f"Server: {data.get('message')}")
        elif msg_type == "error":
            logger.error(f"Server error: {data.get('message')}")
        return data

    async def _drain_responses(self, websocket, timeout: float):
        """Receive and log server messages until ``timeout`` seconds elapse."""
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        while True:
            remaining = deadline - loop.time()
            if remaining <= 0:
                return
            try:
                message = await asyncio.wait_for(websocket.recv(), timeout=remaining)
            except asyncio.TimeoutError:
                return
            self._handle_message(message)

    async def _finish_session(self, response_timeout: float = 0.5):
        """
        Flush queued audio, send the ``final`` command and log the replies.

        This is a best-effort step: the connection may already be closed (for
        example, if the server hung up). Failures are therefore logged at
        debug level rather than raised.
        """
        websocket = self.websocket
        if websocket is None:
            return
        try:
            # Deliver audio captured but not yet sent before requesting the
            # final result.
            while True:
                try:
                    audio_bytes = self.audio_queue.get_nowait()
                except queue.Empty:
                    break
                await websocket.send(audio_bytes)

            await websocket.send(json.dumps({"type": "final"}))
            await self._drain_responses(websocket, response_timeout)
        except Exception as e:
            logger.debug(f"Could not finalize session cleanly: {e}")

    async def stream_audio(self):
        """
        Main coroutine: connect, capture microphone audio and stream it.

        Sending and receiving run concurrently until either one stops the
        session or the task is interrupted. Before the connection closes, any
        still-queued audio is flushed and ``{"type": "final"}`` is sent. Server
        responses are then logged for up to 0.5 s.

        WebSocket errors are logged rather than raised.
        """
        try:
            async with websockets.connect(self.server_url) as websocket:
                self.websocket = websocket
                logger.info(f"Connected to server: {self.server_url}")

                self.is_recording = True
                try:
                    with sd.InputStream(
                        samplerate=self.sample_rate,
                        channels=self.channels,
                        dtype=np.float32,
                        blocksize=self.chunk_samples,
                        device=self.device,
                        callback=self.audio_callback,
                    ):
                        logger.info("Recording started. Press Ctrl+C to stop.")

                        # Run send and receive coroutines concurrently
                        await asyncio.gather(
                            self.send_audio(),
                            self.receive_transcripts(),
                        )
                finally:
                    # Must run inside `async with`: once that context exits
                    # the connection is closed, and "final" could never be
                    # delivered. The input stream is already closed here, so
                    # no new chunks can be queued.
                    self.is_recording = False
                    await self._finish_session()

        except websockets.exceptions.WebSocketException as e:
            logger.error(f"WebSocket error: {e}")
        except KeyboardInterrupt:
            logger.info("Stopped by user")
        finally:
            self.is_recording = False
            self.websocket = None
            logger.info("Disconnected from server")

    def run(self):
        """
        Run the client (blocking) until the session ends or Ctrl+C.
        """
        try:
            asyncio.run(self.stream_audio())
        except KeyboardInterrupt:
            logger.info("Client stopped by user")
|
|
|
|
|
|
def list_audio_devices():
    """
    Print every audio device that offers at least one input channel.

    Each entry shows the device's position in ``sd.query_devices()`` (the
    index to pass as ``--device``), its name, its input channel count and
    its default sample rate.
    """
    rule = "-" * 80
    print("\nAvailable audio input devices:")
    print(rule)

    input_devices = (
        (idx, info)
        for idx, info in enumerate(sd.query_devices())
        if info['max_input_channels'] > 0
    )
    for idx, info in input_devices:
        print(f"[{idx}] {info['name']}")
        print(f" Channels: {info['max_input_channels']}")
        print(f" Sample rate: {info['default_samplerate']} Hz")

    print(rule)
|
|
|
|
|
|
def main(argv=None):
    """
    Main entry point for the microphone client.

    Args:
        argv: Optional list of command-line arguments, excluding the program
            name. When None, argparse reads them from ``sys.argv``.

    Exits with status 2 (via ``parser.error``) when ``--sample-rate`` or
    ``--chunk-duration`` is not positive.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Microphone Streaming Client")
    parser.add_argument("--url", default="ws://localhost:8766", help="WebSocket server URL")
    parser.add_argument("--sample-rate", type=int, default=16000, help="Audio sample rate")
    parser.add_argument("--device", type=int, default=None, help="Audio input device index")
    parser.add_argument("--list-devices", action="store_true", help="List audio devices and exit")
    parser.add_argument("--chunk-duration", type=float, default=0.1, help="Audio chunk duration (seconds)")

    args = parser.parse_args(argv)

    if args.list_devices:
        list_audio_devices()
        return

    # Reject non-physical values up front. The client derives its stream
    # block size as int(sample_rate * chunk_duration), so a non-positive
    # value here yields a block size of zero or less.
    if args.sample_rate <= 0:
        parser.error("--sample-rate must be a positive integer")
    if args.chunk_duration <= 0:
        parser.error("--chunk-duration must be a positive number of seconds")

    client = MicrophoneStreamClient(
        server_url=args.url,
        sample_rate=args.sample_rate,
        device=args.device,
        chunk_duration=args.chunk_duration,
    )

    client.run()
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse CLI arguments, then list devices or stream.
    main()
|