Phase 4 STT pipeline implemented — Silero VAD + faster-whisper — still not working well at all
This commit is contained in:
206
stt/test_stt.py
Normal file
206
stt/test_stt.py
Normal file
@@ -0,0 +1,206 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for STT WebSocket server.
|
||||
Sends test audio and receives VAD/transcription events.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import websockets
|
||||
import numpy as np
|
||||
import json
|
||||
import wave
|
||||
|
||||
|
||||
async def test_websocket():
|
||||
"""Test STT WebSocket with generated audio."""
|
||||
|
||||
uri = "ws://localhost:8001/ws/stt/test_user"
|
||||
|
||||
print("🔌 Connecting to STT WebSocket...")
|
||||
|
||||
async with websockets.connect(uri) as websocket:
|
||||
# Wait for ready message
|
||||
ready_msg = await websocket.recv()
|
||||
ready = json.loads(ready_msg)
|
||||
print(f"✅ {ready}")
|
||||
|
||||
# Generate test audio: 2 seconds of 440Hz tone (A note)
|
||||
# This simulates speech-like audio
|
||||
print("\n🎵 Generating test audio (2 seconds, 440Hz tone)...")
|
||||
sample_rate = 16000
|
||||
duration = 2.0
|
||||
frequency = 440 # A4 note
|
||||
|
||||
t = np.linspace(0, duration, int(sample_rate * duration), False)
|
||||
audio = np.sin(frequency * 2 * np.pi * t)
|
||||
|
||||
# Convert to int16
|
||||
audio_int16 = (audio * 32767).astype(np.int16)
|
||||
|
||||
# Send in 20ms chunks (320 samples at 16kHz)
|
||||
chunk_size = 320 # 20ms chunks
|
||||
total_chunks = len(audio_int16) // chunk_size
|
||||
|
||||
print(f"📤 Sending {total_chunks} audio chunks (20ms each)...\n")
|
||||
|
||||
# Send chunks and receive events
|
||||
for i in range(0, len(audio_int16), chunk_size):
|
||||
chunk = audio_int16[i:i+chunk_size]
|
||||
|
||||
# Send audio chunk
|
||||
await websocket.send(chunk.tobytes())
|
||||
|
||||
# Try to receive events (non-blocking)
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
websocket.recv(),
|
||||
timeout=0.01
|
||||
)
|
||||
event = json.loads(response)
|
||||
|
||||
# Print VAD events
|
||||
if event['type'] == 'vad':
|
||||
emoji = "🟢" if event['speaking'] else "⚪"
|
||||
print(f"{emoji} VAD: {event['event']} "
|
||||
f"(prob={event['probability']:.3f}, "
|
||||
f"t={event['timestamp']:.1f}ms)")
|
||||
|
||||
# Print transcription events
|
||||
elif event['type'] == 'partial':
|
||||
print(f"📝 Partial: \"{event['text']}\"")
|
||||
|
||||
elif event['type'] == 'final':
|
||||
print(f"✅ Final: \"{event['text']}\"")
|
||||
|
||||
elif event['type'] == 'interruption':
|
||||
print(f"⚠️ Interruption detected! (prob={event['probability']:.3f})")
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
pass # No event yet
|
||||
|
||||
# Small delay between chunks
|
||||
await asyncio.sleep(0.02)
|
||||
|
||||
print("\n✅ Test audio sent successfully!")
|
||||
|
||||
# Wait a bit for final transcription
|
||||
print("⏳ Waiting for final transcription...")
|
||||
|
||||
for _ in range(50): # Wait up to 1 second
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
websocket.recv(),
|
||||
timeout=0.02
|
||||
)
|
||||
event = json.loads(response)
|
||||
|
||||
if event['type'] == 'final':
|
||||
print(f"\n✅ FINAL TRANSCRIPT: \"{event['text']}\"")
|
||||
break
|
||||
elif event['type'] == 'vad':
|
||||
emoji = "🟢" if event['speaking'] else "⚪"
|
||||
print(f"{emoji} VAD: {event['event']} (prob={event['probability']:.3f})")
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
print("\n✅ WebSocket test complete!")
|
||||
|
||||
|
||||
async def test_with_sample_audio():
|
||||
"""Test with actual speech audio file (if available)."""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
if len(sys.argv) > 1 and os.path.exists(sys.argv[1]):
|
||||
audio_file = sys.argv[1]
|
||||
print(f"📂 Loading audio from: {audio_file}")
|
||||
|
||||
# Load WAV file
|
||||
with wave.open(audio_file, 'rb') as wav:
|
||||
sample_rate = wav.getframerate()
|
||||
n_channels = wav.getnchannels()
|
||||
audio_data = wav.readframes(wav.getnframes())
|
||||
|
||||
# Convert to numpy array
|
||||
audio_np = np.frombuffer(audio_data, dtype=np.int16)
|
||||
|
||||
# If stereo, convert to mono
|
||||
if n_channels == 2:
|
||||
audio_np = audio_np.reshape(-1, 2).mean(axis=1).astype(np.int16)
|
||||
|
||||
# Resample to 16kHz if needed
|
||||
if sample_rate != 16000:
|
||||
print(f"⚠️ Resampling from {sample_rate}Hz to 16000Hz...")
|
||||
import librosa
|
||||
audio_float = audio_np.astype(np.float32) / 32768.0
|
||||
audio_resampled = librosa.resample(
|
||||
audio_float,
|
||||
orig_sr=sample_rate,
|
||||
target_sr=16000
|
||||
)
|
||||
audio_np = (audio_resampled * 32767).astype(np.int16)
|
||||
|
||||
print(f"✅ Audio loaded: {len(audio_np)/16000:.2f} seconds")
|
||||
|
||||
# Send to STT
|
||||
uri = "ws://localhost:8001/ws/stt/test_user"
|
||||
|
||||
async with websockets.connect(uri) as websocket:
|
||||
ready_msg = await websocket.recv()
|
||||
print(f"✅ {json.loads(ready_msg)}")
|
||||
|
||||
# Send in chunks
|
||||
chunk_size = 320 # 20ms at 16kHz
|
||||
|
||||
for i in range(0, len(audio_np), chunk_size):
|
||||
chunk = audio_np[i:i+chunk_size]
|
||||
await websocket.send(chunk.tobytes())
|
||||
|
||||
# Receive events
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
websocket.recv(),
|
||||
timeout=0.01
|
||||
)
|
||||
event = json.loads(response)
|
||||
|
||||
if event['type'] == 'vad':
|
||||
emoji = "🟢" if event['speaking'] else "⚪"
|
||||
print(f"{emoji} VAD: {event['event']} (prob={event['probability']:.3f})")
|
||||
elif event['type'] in ['partial', 'final']:
|
||||
print(f"📝 {event['type'].title()}: \"{event['text']}\"")
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
await asyncio.sleep(0.02)
|
||||
|
||||
# Wait for final
|
||||
for _ in range(100):
|
||||
try:
|
||||
response = await asyncio.wait_for(websocket.recv(), timeout=0.02)
|
||||
event = json.loads(response)
|
||||
if event['type'] == 'final':
|
||||
print(f"\n✅ FINAL: \"{event['text']}\"")
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
print("=" * 60)
|
||||
print(" Miku STT WebSocket Test")
|
||||
print("=" * 60)
|
||||
print()
|
||||
|
||||
if len(sys.argv) > 1:
|
||||
print("📁 Testing with audio file...")
|
||||
asyncio.run(test_with_sample_audio())
|
||||
else:
|
||||
print("🎵 Testing with generated tone...")
|
||||
print(" (To test with audio file: python test_stt.py audio.wav)")
|
||||
print()
|
||||
asyncio.run(test_websocket())
|
||||
Reference in New Issue
Block a user