Cleanup. Moved prototype and testing STT/TTS to 1TB HDD
This commit is contained in:
42
stt-parakeet/.gitignore
vendored
42
stt-parakeet/.gitignore
vendored
@@ -1,42 +0,0 @@
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.so
|
||||
.Python
|
||||
venv/
|
||||
env/
|
||||
ENV/
|
||||
*.egg-info/
|
||||
dist/
|
||||
build/
|
||||
|
||||
# IDEs
|
||||
.vscode/
|
||||
.idea/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# Models
|
||||
models/
|
||||
*.onnx
|
||||
|
||||
# Audio files
|
||||
*.wav
|
||||
*.mp3
|
||||
*.flac
|
||||
*.ogg
|
||||
test_audio/
|
||||
|
||||
# Logs
|
||||
*.log
|
||||
log
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# Temporary files
|
||||
*.tmp
|
||||
*.temp
|
||||
@@ -1,303 +0,0 @@
|
||||
# Server & Client Usage Guide
|
||||
|
||||
## ✅ Server is Working!
|
||||
|
||||
The WebSocket server is running on port **8766** with GPU acceleration.
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Start the Server
|
||||
|
||||
```bash
|
||||
./run.sh server/ws_server.py
|
||||
```
|
||||
|
||||
Server will start on: `ws://localhost:8766`
|
||||
|
||||
### 2. Test with Simple Client
|
||||
|
||||
```bash
|
||||
./run.sh test_client.py test.wav
|
||||
```
|
||||
|
||||
### 3. Use Microphone Client
|
||||
|
||||
```bash
|
||||
# List audio devices first
|
||||
./run.sh client/mic_stream.py --list-devices
|
||||
|
||||
# Start streaming from microphone
|
||||
./run.sh client/mic_stream.py
|
||||
|
||||
# Or specify device
|
||||
./run.sh client/mic_stream.py --device 0
|
||||
```
|
||||
|
||||
## Available Clients
|
||||
|
||||
### 1. **test_client.py** - Simple File Testing
|
||||
```bash
|
||||
./run.sh test_client.py your_audio.wav
|
||||
```
|
||||
- Sends audio file to server
|
||||
- Shows real-time transcription
|
||||
- Good for testing
|
||||
|
||||
### 2. **client/mic_stream.py** - Live Microphone
|
||||
```bash
|
||||
./run.sh client/mic_stream.py
|
||||
```
|
||||
- Captures from microphone
|
||||
- Streams to server
|
||||
- Real-time transcription display
|
||||
|
||||
### 3. **Custom Client** - Your Own Script
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
import websockets
|
||||
import json
|
||||
|
||||
async def connect():
|
||||
async with websockets.connect("ws://localhost:8766") as ws:
|
||||
# Send audio as int16 PCM bytes
|
||||
audio_bytes = your_audio_data.astype('int16').tobytes()
|
||||
await ws.send(audio_bytes)
|
||||
|
||||
# Receive transcription
|
||||
response = await ws.recv()
|
||||
result = json.loads(response)
|
||||
print(result['text'])
|
||||
|
||||
asyncio.run(connect())
|
||||
```
|
||||
|
||||
## Server Options
|
||||
|
||||
```bash
|
||||
# Custom host/port
|
||||
./run.sh server/ws_server.py --host 0.0.0.0 --port 9000
|
||||
|
||||
# Enable VAD (for long audio)
|
||||
./run.sh server/ws_server.py --use-vad
|
||||
|
||||
# Different model
|
||||
./run.sh server/ws_server.py --model nemo-parakeet-tdt-0.6b-v3
|
||||
|
||||
# Change sample rate
|
||||
./run.sh server/ws_server.py --sample-rate 16000
|
||||
```
|
||||
|
||||
## Client Options
|
||||
|
||||
### Microphone Client
|
||||
```bash
|
||||
# List devices
|
||||
./run.sh client/mic_stream.py --list-devices
|
||||
|
||||
# Use specific device
|
||||
./run.sh client/mic_stream.py --device 2
|
||||
|
||||
# Custom server URL
|
||||
./run.sh client/mic_stream.py --url ws://192.168.1.100:8766
|
||||
|
||||
# Adjust chunk duration (lower = lower latency)
|
||||
./run.sh client/mic_stream.py --chunk-duration 0.05
|
||||
```
|
||||
|
||||
## Protocol
|
||||
|
||||
The server uses a simple JSON-based protocol:
|
||||
|
||||
### Server → Client Messages
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "info",
|
||||
"message": "Connected to ASR server",
|
||||
"sample_rate": 16000
|
||||
}
|
||||
```
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "transcript",
|
||||
"text": "transcribed text here",
|
||||
"is_final": false
|
||||
}
|
||||
```
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "error",
|
||||
"message": "error description"
|
||||
}
|
||||
```
|
||||
|
||||
### Client → Server Messages
|
||||
|
||||
**Send audio:**
|
||||
- Binary data (int16 PCM, little-endian)
|
||||
- Sample rate: 16000 Hz
|
||||
- Mono channel
|
||||
|
||||
**Send commands:**
|
||||
```json
|
||||
{"type": "final"} // Process remaining buffer
|
||||
{"type": "reset"} // Reset audio buffer
|
||||
```
|
||||
|
||||
## Audio Format Requirements
|
||||
|
||||
- **Format**: int16 PCM (bytes)
|
||||
- **Sample Rate**: 16000 Hz
|
||||
- **Channels**: Mono (1)
|
||||
- **Byte Order**: Little-endian
|
||||
|
||||
### Convert Audio in Python
|
||||
|
||||
```python
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
|
||||
# Load audio
|
||||
audio, sr = sf.read("file.wav", dtype='float32')
|
||||
|
||||
# Convert to mono
|
||||
if audio.ndim > 1:
|
||||
audio = audio[:, 0]
|
||||
|
||||
# Resample if needed (install resampy)
|
||||
if sr != 16000:
|
||||
import resampy
|
||||
audio = resampy.resample(audio, sr, 16000)
|
||||
|
||||
# Convert to int16 for sending
|
||||
audio_int16 = (audio * 32767).astype(np.int16)
|
||||
audio_bytes = audio_int16.tobytes()
|
||||
```
|
||||
|
||||
## Examples
|
||||
|
||||
### Browser Client (JavaScript)
|
||||
|
||||
```javascript
|
||||
const ws = new WebSocket('ws://localhost:8766');
|
||||
|
||||
ws.onopen = () => {
|
||||
console.log('Connected!');
|
||||
|
||||
// Capture from microphone
|
||||
navigator.mediaDevices.getUserMedia({ audio: true })
|
||||
.then(stream => {
|
||||
const audioContext = new AudioContext({ sampleRate: 16000 });
|
||||
const source = audioContext.createMediaStreamSource(stream);
|
||||
const processor = audioContext.createScriptProcessor(4096, 1, 1);
|
||||
|
||||
processor.onaudioprocess = (e) => {
|
||||
const audioData = e.inputBuffer.getChannelData(0);
|
||||
// Convert float32 to int16
|
||||
const int16Data = new Int16Array(audioData.length);
|
||||
for (let i = 0; i < audioData.length; i++) {
|
||||
int16Data[i] = Math.max(-32768, Math.min(32767, audioData[i] * 32768));
|
||||
}
|
||||
ws.send(int16Data.buffer);
|
||||
};
|
||||
|
||||
source.connect(processor);
|
||||
processor.connect(audioContext.destination);
|
||||
});
|
||||
};
|
||||
|
||||
ws.onmessage = (event) => {
|
||||
const data = JSON.parse(event.data);
|
||||
if (data.type === 'transcript') {
|
||||
console.log('Transcription:', data.text);
|
||||
}
|
||||
};
|
||||
```
|
||||
|
||||
### Python Script Client
|
||||
|
||||
```python
|
||||
#!/usr/bin/env python3
|
||||
import asyncio
|
||||
import websockets
|
||||
import sounddevice as sd
|
||||
import numpy as np
|
||||
import json
|
||||
|
||||
async def stream_microphone():
|
||||
uri = "ws://localhost:8766"
|
||||
|
||||
async with websockets.connect(uri) as ws:
|
||||
print("Connected!")
|
||||
|
||||
def audio_callback(indata, frames, time, status):
|
||||
# Convert to int16 and send
|
||||
audio = (indata[:, 0] * 32767).astype(np.int16)
|
||||
asyncio.create_task(ws.send(audio.tobytes()))
|
||||
|
||||
# Start recording
|
||||
with sd.InputStream(callback=audio_callback,
|
||||
channels=1,
|
||||
samplerate=16000,
|
||||
blocksize=1600): # 0.1 second chunks
|
||||
|
||||
while True:
|
||||
response = await ws.recv()
|
||||
data = json.loads(response)
|
||||
if data.get('type') == 'transcript':
|
||||
print(f"→ {data['text']}")
|
||||
|
||||
asyncio.run(stream_microphone())
|
||||
```
|
||||
|
||||
## Performance
|
||||
|
||||
With GPU (GTX 1660):
|
||||
- **Latency**: <100ms per chunk
|
||||
- **Throughput**: ~50-100x realtime
|
||||
- **GPU Memory**: ~1.3GB
|
||||
- **Languages**: 25+ (auto-detected)
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Server won't start
|
||||
```bash
|
||||
# Check if port is in use
|
||||
lsof -i:8766
|
||||
|
||||
# Kill existing server
|
||||
pkill -f ws_server.py
|
||||
|
||||
# Restart
|
||||
./run.sh server/ws_server.py
|
||||
```
|
||||
|
||||
### Client can't connect
|
||||
```bash
|
||||
# Check server is running
|
||||
ps aux | grep ws_server
|
||||
|
||||
# Check firewall
|
||||
sudo ufw allow 8766
|
||||
```
|
||||
|
||||
### No transcription output
|
||||
- Check audio format (must be int16 PCM, 16kHz, mono)
|
||||
- Check chunk size (not too small)
|
||||
- Check server logs for errors
|
||||
|
||||
### GPU not working
|
||||
- Server will fall back to CPU automatically
|
||||
- Check `nvidia-smi` for GPU status
|
||||
- Verify CUDA libraries are loaded (should be automatic with `./run.sh`)
|
||||
|
||||
## Next Steps
|
||||
|
||||
1. **Test the server**: `./run.sh test_client.py test.wav`
|
||||
2. **Try microphone**: `./run.sh client/mic_stream.py`
|
||||
3. **Build your own client** using the examples above
|
||||
|
||||
Happy transcribing! 🎤
|
||||
@@ -1,59 +0,0 @@
|
||||
# Parakeet ONNX ASR STT Container
|
||||
# Uses ONNX Runtime with CUDA for GPU-accelerated inference
|
||||
# Optimized for NVIDIA GTX 1660 and similar GPUs
|
||||
# Using CUDA 12.6 with cuDNN 9 for ONNX Runtime GPU support
|
||||
|
||||
FROM nvidia/cuda:12.6.2-cudnn-runtime-ubuntu22.04
|
||||
|
||||
# Prevent interactive prompts during build
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
|
||||
# Set working directory
|
||||
WORKDIR /app
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
python3.11 \
|
||||
python3.11-venv \
|
||||
python3.11-dev \
|
||||
python3-pip \
|
||||
build-essential \
|
||||
ffmpeg \
|
||||
libsndfile1 \
|
||||
libportaudio2 \
|
||||
portaudio19-dev \
|
||||
git \
|
||||
curl \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Upgrade pip to exact version used in requirements
|
||||
RUN python3.11 -m pip install --upgrade pip==25.3
|
||||
|
||||
# Copy requirements first (for Docker layer caching)
|
||||
COPY requirements-stt.txt .
|
||||
|
||||
# Install Python dependencies
|
||||
RUN python3.11 -m pip install --no-cache-dir -r requirements-stt.txt
|
||||
|
||||
# Copy application code
|
||||
COPY asr/ ./asr/
|
||||
COPY server/ ./server/
|
||||
COPY vad/ ./vad/
|
||||
COPY client/ ./client/
|
||||
|
||||
# Create models directory (models will be downloaded on first run)
|
||||
RUN mkdir -p models/parakeet
|
||||
|
||||
# Expose WebSocket port
|
||||
EXPOSE 8766
|
||||
|
||||
# Set GPU visibility (default to GPU 0)
|
||||
ENV CUDA_VISIBLE_DEVICES=0
|
||||
|
||||
# Health check
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
|
||||
CMD python3.11 -c "import onnxruntime as ort; assert 'CUDAExecutionProvider' in ort.get_available_providers()" || exit 1
|
||||
|
||||
# Run the WebSocket server
|
||||
CMD ["python3.11", "-m", "server.ws_server"]
|
||||
@@ -1,290 +0,0 @@
|
||||
# Quick Start Guide
|
||||
|
||||
## 🚀 Getting Started in 5 Minutes
|
||||
|
||||
### 1. Setup Environment
|
||||
|
||||
```bash
|
||||
# Make setup script executable and run it
|
||||
chmod +x setup_env.sh
|
||||
./setup_env.sh
|
||||
```
|
||||
|
||||
The setup script will:
|
||||
- Create a virtual environment
|
||||
- Install all dependencies including `onnx-asr`
|
||||
- Check CUDA/GPU availability
|
||||
- Run system diagnostics
|
||||
- Optionally download the Parakeet model
|
||||
|
||||
### 2. Activate Virtual Environment
|
||||
|
||||
```bash
|
||||
source venv/bin/activate
|
||||
```
|
||||
|
||||
### 3. Test Your Setup
|
||||
|
||||
Run diagnostics to verify everything is working:
|
||||
|
||||
```bash
|
||||
python3 tools/diagnose.py
|
||||
```
|
||||
|
||||
Expected output should show:
|
||||
- ✓ Python 3.10+
|
||||
- ✓ onnx-asr installed
|
||||
- ✓ CUDAExecutionProvider available
|
||||
- ✓ GPU detected
|
||||
|
||||
### 4. Test Offline Transcription
|
||||
|
||||
Create a test audio file or use an existing WAV file:
|
||||
|
||||
```bash
|
||||
python3 tools/test_offline.py test.wav
|
||||
```
|
||||
|
||||
### 5. Start Real-Time Streaming
|
||||
|
||||
**Terminal 1 - Start Server:**
|
||||
```bash
|
||||
python3 server/ws_server.py
|
||||
```
|
||||
|
||||
**Terminal 2 - Start Client:**
|
||||
```bash
|
||||
# List audio devices first
|
||||
python3 client/mic_stream.py --list-devices
|
||||
|
||||
# Start streaming with your microphone
|
||||
python3 client/mic_stream.py --device 0
|
||||
```
|
||||
|
||||
## 🎯 Common Commands
|
||||
|
||||
### Offline Transcription
|
||||
|
||||
```bash
|
||||
# Basic transcription
|
||||
python3 tools/test_offline.py audio.wav
|
||||
|
||||
# With Voice Activity Detection (for long files)
|
||||
python3 tools/test_offline.py audio.wav --use-vad
|
||||
|
||||
# With quantization (faster, uses less memory)
|
||||
python3 tools/test_offline.py audio.wav --quantization int8
|
||||
```
|
||||
|
||||
### WebSocket Server
|
||||
|
||||
```bash
|
||||
# Start server on default port (8765)
|
||||
python3 server/ws_server.py
|
||||
|
||||
# Custom host and port
|
||||
python3 server/ws_server.py --host 0.0.0.0 --port 9000
|
||||
|
||||
# With VAD enabled
|
||||
python3 server/ws_server.py --use-vad
|
||||
```
|
||||
|
||||
### Microphone Client
|
||||
|
||||
```bash
|
||||
# List available audio devices
|
||||
python3 client/mic_stream.py --list-devices
|
||||
|
||||
# Connect to server
|
||||
python3 client/mic_stream.py --url ws://localhost:8765
|
||||
|
||||
# Use specific device
|
||||
python3 client/mic_stream.py --device 2
|
||||
|
||||
# Custom sample rate
|
||||
python3 client/mic_stream.py --sample-rate 16000
|
||||
```
|
||||
|
||||
## 🔧 Troubleshooting
|
||||
|
||||
### GPU Not Detected
|
||||
|
||||
1. Check NVIDIA driver:
|
||||
```bash
|
||||
nvidia-smi
|
||||
```
|
||||
|
||||
2. Check CUDA version:
|
||||
```bash
|
||||
nvcc --version
|
||||
```
|
||||
|
||||
3. Verify ONNX Runtime can see GPU:
|
||||
```bash
|
||||
python3 -c "import onnxruntime as ort; print(ort.get_available_providers())"
|
||||
```
|
||||
|
||||
Should include `CUDAExecutionProvider`
|
||||
|
||||
### Out of Memory
|
||||
|
||||
If you get CUDA out of memory errors:
|
||||
|
||||
1. **Use quantization:**
|
||||
```bash
|
||||
python3 tools/test_offline.py audio.wav --quantization int8
|
||||
```
|
||||
|
||||
2. **Close other GPU applications**
|
||||
|
||||
3. **Reduce GPU memory limit** in `asr/asr_pipeline.py`:
|
||||
```python
|
||||
"gpu_mem_limit": 4 * 1024 * 1024 * 1024, # 4GB instead of 6GB
|
||||
```
|
||||
|
||||
### Microphone Not Working
|
||||
|
||||
1. Check permissions:
|
||||
```bash
|
||||
sudo usermod -a -G audio $USER
|
||||
# Then logout and login again
|
||||
```
|
||||
|
||||
2. Test with system audio recorder first
|
||||
|
||||
3. List and test devices:
|
||||
```bash
|
||||
python3 client/mic_stream.py --list-devices
|
||||
```
|
||||
|
||||
### Model Download Fails
|
||||
|
||||
If Hugging Face is slow or blocked:
|
||||
|
||||
1. **Set HF token** (optional, for faster downloads):
|
||||
```bash
|
||||
export HF_TOKEN="your_huggingface_token"
|
||||
```
|
||||
|
||||
2. **Manual download:**
|
||||
```bash
|
||||
# Download from: https://huggingface.co/istupakov/parakeet-tdt-0.6b-v3-onnx
|
||||
# Extract to: models/parakeet/
|
||||
```
|
||||
|
||||
## 📊 Performance Tips
|
||||
|
||||
### For Best GPU Performance
|
||||
|
||||
1. **Use TensorRT provider** (faster than CUDA):
|
||||
```bash
|
||||
pip install tensorrt tensorrt-cu12-libs
|
||||
```
|
||||
|
||||
Then edit `asr/asr_pipeline.py` to use TensorRT provider
|
||||
|
||||
2. **Use FP16 quantization** (on TensorRT):
|
||||
```python
|
||||
providers = [
|
||||
("TensorrtExecutionProvider", {
|
||||
"trt_fp16_enable": True,
|
||||
})
|
||||
]
|
||||
```
|
||||
|
||||
3. **Enable quantization:**
|
||||
```bash
|
||||
--quantization int8 # Good balance
|
||||
--quantization fp16 # Better quality
|
||||
```
|
||||
|
||||
### For Lower Latency Streaming
|
||||
|
||||
1. **Reduce chunk duration** in client:
|
||||
```bash
|
||||
python3 client/mic_stream.py --chunk-duration 0.05
|
||||
```
|
||||
|
||||
2. **Disable VAD** for shorter responses
|
||||
|
||||
3. **Use quantized model** for faster processing
|
||||
|
||||
## 🎤 Audio File Requirements
|
||||
|
||||
### Supported Formats
|
||||
- **Format**: WAV (PCM_16, PCM_24, PCM_32, PCM_U8)
|
||||
- **Sample Rate**: 16000 Hz (recommended)
|
||||
- **Channels**: Mono (stereo will be converted to mono)
|
||||
|
||||
### Convert Audio Files
|
||||
|
||||
```bash
|
||||
# Using ffmpeg
|
||||
ffmpeg -i input.mp3 -ar 16000 -ac 1 output.wav
|
||||
|
||||
# Using sox
|
||||
sox input.mp3 -r 16000 -c 1 output.wav
|
||||
```
|
||||
|
||||
## 📝 Example Workflow
|
||||
|
||||
Complete example for transcribing a meeting recording:
|
||||
|
||||
```bash
|
||||
# 1. Activate environment
|
||||
source venv/bin/activate
|
||||
|
||||
# 2. Convert audio to correct format
|
||||
ffmpeg -i meeting.mp3 -ar 16000 -ac 1 meeting.wav
|
||||
|
||||
# 3. Transcribe with VAD (for long recordings)
|
||||
python3 tools/test_offline.py meeting.wav --use-vad
|
||||
|
||||
# Output will show transcription with automatic segmentation
|
||||
```
|
||||
|
||||
## 🌐 Supported Languages
|
||||
|
||||
The Parakeet TDT 0.6B V3 model supports **25+ languages** including:
|
||||
- English
|
||||
- Spanish
|
||||
- French
|
||||
- German
|
||||
- Italian
|
||||
- Portuguese
|
||||
- Russian
|
||||
- Chinese
|
||||
- Japanese
|
||||
- Korean
|
||||
- And more...
|
||||
|
||||
The model automatically detects the language.
|
||||
|
||||
## 💡 Tips
|
||||
|
||||
1. **For short audio clips** (<30 seconds): Don't use VAD
|
||||
2. **For long audio files**: Use `--use-vad` flag
|
||||
3. **For real-time streaming**: Keep chunks small (0.1-0.5 seconds)
|
||||
4. **For best accuracy**: Use 16kHz mono WAV files
|
||||
5. **For faster inference**: Use `--quantization int8`
|
||||
|
||||
## 📚 More Information
|
||||
|
||||
- See `README.md` for detailed documentation
|
||||
- Run `python3 tools/diagnose.py` for system check
|
||||
- Check logs for debugging information
|
||||
|
||||
## 🆘 Getting Help
|
||||
|
||||
If you encounter issues:
|
||||
|
||||
1. Run diagnostics:
|
||||
```bash
|
||||
python3 tools/diagnose.py
|
||||
```
|
||||
|
||||
2. Check the logs in the terminal output
|
||||
|
||||
3. Verify your audio format and sample rate
|
||||
|
||||
4. Review the troubleshooting section above
|
||||
@@ -1,280 +0,0 @@
|
||||
# Parakeet ASR with ONNX Runtime
|
||||
|
||||
Real-time Automatic Speech Recognition (ASR) system using NVIDIA's Parakeet TDT 0.6B V3 model via the `onnx-asr` library, optimized for NVIDIA GPUs (GTX 1660 and better).
|
||||
|
||||
## Features
|
||||
|
||||
- ✅ **ONNX Runtime with GPU acceleration** (CUDA/TensorRT support)
|
||||
- ✅ **Parakeet TDT 0.6B V3** multilingual model from Hugging Face
|
||||
- ✅ **Real-time streaming** via WebSocket server
|
||||
- ✅ **Voice Activity Detection** (Silero VAD)
|
||||
- ✅ **Microphone client** for live transcription
|
||||
- ✅ **Offline transcription** from audio files
|
||||
- ✅ **Quantization support** (int8, fp16) for faster inference
|
||||
|
||||
## Model Information
|
||||
|
||||
This implementation uses:
|
||||
- **Model**: `nemo-parakeet-tdt-0.6b-v3` (Multilingual)
|
||||
- **Source**: https://huggingface.co/istupakov/parakeet-tdt-0.6b-v3-onnx
|
||||
- **Library**: https://github.com/istupakov/onnx-asr
|
||||
- **Original Model**: https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3
|
||||
|
||||
## System Requirements
|
||||
|
||||
- **GPU**: NVIDIA GPU with CUDA support (tested on GTX 1660)
|
||||
- **CUDA**: Version 11.8 or 12.x
|
||||
- **Python**: 3.10 or higher
|
||||
- **Memory**: At least 4GB GPU memory recommended
|
||||
|
||||
## Installation
|
||||
|
||||
### 1. Clone the repository
|
||||
|
||||
```bash
|
||||
cd /home/koko210Serve/parakeet-test
|
||||
```
|
||||
|
||||
### 2. Create virtual environment
|
||||
|
||||
```bash
|
||||
python3 -m venv venv
|
||||
source venv/bin/activate
|
||||
```
|
||||
|
||||
### 3. Install CUDA dependencies
|
||||
|
||||
Make sure you have CUDA installed. For Ubuntu:
|
||||
|
||||
```bash
|
||||
# Check CUDA version
|
||||
nvcc --version
|
||||
|
||||
# If you need to install CUDA, follow NVIDIA's instructions:
|
||||
# https://developer.nvidia.com/cuda-downloads
|
||||
```
|
||||
|
||||
### 4. Install Python dependencies
|
||||
|
||||
```bash
|
||||
pip install --upgrade pip
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
Or manually:
|
||||
|
||||
```bash
|
||||
# With GPU support (recommended)
|
||||
pip install onnx-asr[gpu,hub]
|
||||
|
||||
# Additional dependencies
|
||||
pip install numpy<2.0 websockets sounddevice soundfile
|
||||
```
|
||||
|
||||
### 5. Verify CUDA availability
|
||||
|
||||
```bash
|
||||
python3 -c "import onnxruntime as ort; print('Available providers:', ort.get_available_providers())"
|
||||
```
|
||||
|
||||
You should see `CUDAExecutionProvider` in the list.
|
||||
|
||||
## Usage
|
||||
|
||||
### Test Offline Transcription
|
||||
|
||||
Transcribe an audio file:
|
||||
|
||||
```bash
|
||||
python3 tools/test_offline.py test.wav
|
||||
```
|
||||
|
||||
With VAD (for long audio files):
|
||||
|
||||
```bash
|
||||
python3 tools/test_offline.py test.wav --use-vad
|
||||
```
|
||||
|
||||
With quantization (faster, less memory):
|
||||
|
||||
```bash
|
||||
python3 tools/test_offline.py test.wav --quantization int8
|
||||
```
|
||||
|
||||
### Start WebSocket Server
|
||||
|
||||
Start the ASR server:
|
||||
|
||||
```bash
|
||||
python3 server/ws_server.py
|
||||
```
|
||||
|
||||
With options:
|
||||
|
||||
```bash
|
||||
python3 server/ws_server.py --host 0.0.0.0 --port 8765 --use-vad
|
||||
```
|
||||
|
||||
### Start Microphone Client
|
||||
|
||||
In a separate terminal, start the microphone client:
|
||||
|
||||
```bash
|
||||
python3 client/mic_stream.py
|
||||
```
|
||||
|
||||
List available audio devices:
|
||||
|
||||
```bash
|
||||
python3 client/mic_stream.py --list-devices
|
||||
```
|
||||
|
||||
Connect to a specific device:
|
||||
|
||||
```bash
|
||||
python3 client/mic_stream.py --device 0
|
||||
```
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
parakeet-test/
|
||||
├── asr/
|
||||
│ ├── __init__.py
|
||||
│ └── asr_pipeline.py # Main ASR pipeline using onnx-asr
|
||||
├── client/
|
||||
│ ├── __init__.py
|
||||
│ └── mic_stream.py # Microphone streaming client
|
||||
├── server/
|
||||
│ ├── __init__.py
|
||||
│ └── ws_server.py # WebSocket server for streaming ASR
|
||||
├── vad/
|
||||
│ ├── __init__.py
|
||||
│ └── silero_vad.py # VAD wrapper using onnx-asr
|
||||
├── tools/
|
||||
│ ├── test_offline.py # Test offline transcription
|
||||
│ └── diagnose.py # System diagnostics
|
||||
├── models/
|
||||
│ └── parakeet/ # Model files (auto-downloaded)
|
||||
├── requirements.txt # Python dependencies
|
||||
└── README.md # This file
|
||||
```
|
||||
|
||||
## Model Files
|
||||
|
||||
The model files will be automatically downloaded from Hugging Face on first run to:
|
||||
```
|
||||
models/parakeet/
|
||||
├── config.json
|
||||
├── encoder-parakeet-tdt-0.6b-v3.onnx
|
||||
├── decoder_joint-parakeet-tdt-0.6b-v3.onnx
|
||||
└── vocab.txt
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
### GPU Settings
|
||||
|
||||
The ASR pipeline is configured to use CUDA by default. You can customize the execution providers in `asr/asr_pipeline.py`:
|
||||
|
||||
```python
|
||||
providers = [
|
||||
(
|
||||
"CUDAExecutionProvider",
|
||||
{
|
||||
"device_id": 0,
|
||||
"arena_extend_strategy": "kNextPowerOfTwo",
|
||||
"gpu_mem_limit": 6 * 1024 * 1024 * 1024, # 6GB
|
||||
"cudnn_conv_algo_search": "EXHAUSTIVE",
|
||||
"do_copy_in_default_stream": True,
|
||||
}
|
||||
),
|
||||
"CPUExecutionProvider",
|
||||
]
|
||||
```
|
||||
|
||||
### TensorRT (Optional - Faster Inference)
|
||||
|
||||
For even better performance, you can use TensorRT:
|
||||
|
||||
```bash
|
||||
pip install tensorrt tensorrt-cu12-libs
|
||||
```
|
||||
|
||||
Then modify the providers:
|
||||
|
||||
```python
|
||||
providers = [
|
||||
(
|
||||
"TensorrtExecutionProvider",
|
||||
{
|
||||
"trt_max_workspace_size": 6 * 1024**3,
|
||||
"trt_fp16_enable": True,
|
||||
},
|
||||
)
|
||||
]
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### CUDA Not Available
|
||||
|
||||
If CUDA is not detected:
|
||||
|
||||
1. Check CUDA installation: `nvcc --version`
|
||||
2. Verify GPU: `nvidia-smi`
|
||||
3. Reinstall onnxruntime-gpu:
|
||||
```bash
|
||||
pip uninstall onnxruntime onnxruntime-gpu
|
||||
pip install onnxruntime-gpu
|
||||
```
|
||||
|
||||
### Memory Issues
|
||||
|
||||
If you run out of GPU memory:
|
||||
|
||||
1. Use quantization: `--quantization int8`
|
||||
2. Reduce `gpu_mem_limit` in the configuration
|
||||
3. Close other GPU-using applications
|
||||
|
||||
### Audio Issues
|
||||
|
||||
If microphone is not working:
|
||||
|
||||
1. List devices: `python3 client/mic_stream.py --list-devices`
|
||||
2. Select the correct device: `--device <id>`
|
||||
3. Check permissions: `sudo usermod -a -G audio $USER` (then logout/login)
|
||||
|
||||
### Slow Performance
|
||||
|
||||
1. Ensure GPU is being used (check logs for "CUDAExecutionProvider")
|
||||
2. Try quantization for faster inference
|
||||
3. Consider using TensorRT provider
|
||||
4. Check GPU utilization: `nvidia-smi`
|
||||
|
||||
## Performance
|
||||
|
||||
Expected performance on GTX 1660 (6GB):
|
||||
|
||||
- **Offline transcription**: ~50-100x realtime (depending on audio length)
|
||||
- **Streaming**: <100ms latency
|
||||
- **Memory usage**: ~2-3GB GPU memory
|
||||
- **Quantized (int8)**: ~30% faster, ~50% less memory
|
||||
|
||||
## License
|
||||
|
||||
This project uses:
|
||||
- `onnx-asr`: MIT License
|
||||
- Parakeet model: CC-BY-4.0 License
|
||||
|
||||
## References
|
||||
|
||||
- [onnx-asr GitHub](https://github.com/istupakov/onnx-asr)
|
||||
- [Parakeet TDT 0.6B V3 ONNX](https://huggingface.co/istupakov/parakeet-tdt-0.6b-v3-onnx)
|
||||
- [NVIDIA Parakeet](https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3)
|
||||
- [ONNX Runtime](https://onnxruntime.ai/)
|
||||
|
||||
## Credits
|
||||
|
||||
- Model conversion by [istupakov](https://github.com/istupakov)
|
||||
- Original Parakeet model by NVIDIA
|
||||
@@ -1,244 +0,0 @@
|
||||
# Refactoring Summary
|
||||
|
||||
## Overview
|
||||
|
||||
Successfully refactored the Parakeet ASR codebase to use the `onnx-asr` library with ONNX Runtime GPU support for NVIDIA GTX 1660.
|
||||
|
||||
## Changes Made
|
||||
|
||||
### 1. Dependencies (`requirements.txt`)
|
||||
- **Removed**: `onnxruntime-gpu`, `silero-vad`
|
||||
- **Added**: `onnx-asr[gpu,hub]`, `soundfile`
|
||||
- **Kept**: `numpy<2.0`, `websockets`, `sounddevice`
|
||||
|
||||
### 2. ASR Pipeline (`asr/asr_pipeline.py`)
|
||||
- Completely refactored to use `onnx_asr.load_model()`
|
||||
- Added support for:
|
||||
- GPU acceleration via CUDA/TensorRT
|
||||
- Model quantization (int8, fp16)
|
||||
- Voice Activity Detection (VAD)
|
||||
- Batch processing
|
||||
- Streaming audio chunks
|
||||
- Configurable execution providers for GPU optimization
|
||||
- Automatic model download from Hugging Face
|
||||
|
||||
### 3. VAD Module (`vad/silero_vad.py`)
|
||||
- Refactored to use `onnx_asr.load_vad()`
|
||||
- Integrated Silero VAD via onnx-asr
|
||||
- Simplified API for VAD operations
|
||||
- Note: VAD is best used via `model.with_vad()` method
|
||||
|
||||
### 4. WebSocket Server (`server/ws_server.py`)
|
||||
- Created from scratch for streaming ASR
|
||||
- Features:
|
||||
- Real-time audio streaming
|
||||
- JSON-based protocol
|
||||
- Support for multiple concurrent connections
|
||||
- Buffer management for audio chunks
|
||||
- Error handling and logging
|
||||
|
||||
### 5. Microphone Client (`client/mic_stream.py`)
|
||||
- Created streaming client using `sounddevice`
|
||||
- Features:
|
||||
- Real-time microphone capture
|
||||
- WebSocket streaming to server
|
||||
- Audio device selection
|
||||
- Automatic format conversion (float32 to int16)
|
||||
- Async communication
|
||||
|
||||
### 6. Test Script (`tools/test_offline.py`)
|
||||
- Completely rewritten for onnx-asr
|
||||
- Features:
|
||||
- Command-line interface
|
||||
- Support for WAV files
|
||||
- Optional VAD and quantization
|
||||
- Audio statistics and diagnostics
|
||||
|
||||
### 7. Diagnostics Tool (`tools/diagnose.py`)
|
||||
- New comprehensive system check tool
|
||||
- Checks:
|
||||
- Python version
|
||||
- Installed packages
|
||||
- CUDA availability
|
||||
- ONNX Runtime providers
|
||||
- Audio devices
|
||||
- Model files
|
||||
|
||||
### 8. Setup Script (`setup_env.sh`)
|
||||
- Automated setup script
|
||||
- Features:
|
||||
- Virtual environment creation
|
||||
- Dependency installation
|
||||
- CUDA/GPU detection
|
||||
- System diagnostics
|
||||
- Optional model download
|
||||
|
||||
### 9. Documentation
|
||||
- **README.md**: Comprehensive documentation with:
|
||||
- Installation instructions
|
||||
- Usage examples
|
||||
- Configuration options
|
||||
- Troubleshooting guide
|
||||
- Performance tips
|
||||
|
||||
- **QUICKSTART.md**: Quick start guide with:
|
||||
- 5-minute setup
|
||||
- Common commands
|
||||
- Troubleshooting
|
||||
- Performance optimization
|
||||
|
||||
- **example.py**: Simple usage example
|
||||
|
||||
## Key Benefits
|
||||
|
||||
### 1. GPU Optimization
|
||||
- Native CUDA support via ONNX Runtime
|
||||
- Configurable GPU memory limits
|
||||
- Optional TensorRT for even faster inference
|
||||
- Automatic fallback to CPU if GPU unavailable
|
||||
|
||||
### 2. Simplified Model Management
|
||||
- Automatic model download from Hugging Face
|
||||
- No manual ONNX export needed
|
||||
- Pre-converted models ready to use
|
||||
- Support for quantized versions
|
||||
|
||||
### 3. Better Performance
|
||||
- Optimized ONNX inference
|
||||
- GPU acceleration on GTX 1660
|
||||
- ~50-100x realtime on GPU
|
||||
- Reduced memory usage with quantization
|
||||
|
||||
### 4. Improved Usability
|
||||
- Simpler API
|
||||
- Better error handling
|
||||
- Comprehensive logging
|
||||
- Easy configuration
|
||||
|
||||
### 5. Modern Features
|
||||
- WebSocket streaming
|
||||
- Real-time transcription
|
||||
- VAD integration
|
||||
- Batch processing
|
||||
|
||||
## Model Information
|
||||
|
||||
- **Model**: Parakeet TDT 0.6B V3 (Multilingual)
|
||||
- **Source**: https://huggingface.co/istupakov/parakeet-tdt-0.6b-v3-onnx
|
||||
- **Size**: ~600MB
|
||||
- **Languages**: 25+ languages
|
||||
- **Location**: `models/parakeet/` (auto-downloaded)
|
||||
|
||||
## File Structure
|
||||
|
||||
```
|
||||
parakeet-test/
|
||||
├── asr/
|
||||
│ ├── __init__.py ✓ Updated
|
||||
│ └── asr_pipeline.py ✓ Refactored
|
||||
├── client/
|
||||
│ ├── __init__.py ✓ Updated
|
||||
│ └── mic_stream.py ✓ New
|
||||
├── server/
|
||||
│ ├── __init__.py ✓ Updated
|
||||
│ └── ws_server.py ✓ New
|
||||
├── vad/
|
||||
│ ├── __init__.py ✓ Updated
|
||||
│ └── silero_vad.py ✓ Refactored
|
||||
├── tools/
|
||||
│ ├── diagnose.py ✓ New
|
||||
│ └── test_offline.py ✓ Refactored
|
||||
├── models/
|
||||
│ └── parakeet/ ✓ Auto-created
|
||||
├── requirements.txt ✓ Updated
|
||||
├── setup_env.sh ✓ New
|
||||
├── README.md ✓ New
|
||||
├── QUICKSTART.md ✓ New
|
||||
├── example.py ✓ New
|
||||
├── .gitignore ✓ New
|
||||
└── REFACTORING.md ✓ This file
|
||||
```
|
||||
|
||||
## Migration from Old Code
|
||||
|
||||
### Old Code Pattern:
|
||||
```python
|
||||
# Manual ONNX session creation
|
||||
import onnxruntime as ort
|
||||
session = ort.InferenceSession("encoder.onnx", providers=["CUDAExecutionProvider"])
|
||||
# Manual preprocessing and decoding
|
||||
```
|
||||
|
||||
### New Code Pattern:
|
||||
```python
|
||||
# Simple onnx-asr interface
|
||||
import onnx_asr
|
||||
model = onnx_asr.load_model("nemo-parakeet-tdt-0.6b-v3")
|
||||
text = model.recognize("audio.wav")
|
||||
```
|
||||
|
||||
## Testing Instructions
|
||||
|
||||
### 1. Setup
|
||||
```bash
|
||||
./setup_env.sh
|
||||
source venv/bin/activate
|
||||
```
|
||||
|
||||
### 2. Run Diagnostics
|
||||
```bash
|
||||
python3 tools/diagnose.py
|
||||
```
|
||||
|
||||
### 3. Test Offline
|
||||
```bash
|
||||
python3 tools/test_offline.py test.wav
|
||||
```
|
||||
|
||||
### 4. Test Streaming
|
||||
```bash
|
||||
# Terminal 1
|
||||
python3 server/ws_server.py
|
||||
|
||||
# Terminal 2
|
||||
python3 client/mic_stream.py
|
||||
```
|
||||
|
||||
## Known Limitations
|
||||
|
||||
1. **Audio Format**: Only WAV files with PCM encoding supported directly
|
||||
2. **Segment Length**: Models work best with <30 second segments
|
||||
3. **GPU Memory**: Requires at least 2-3GB GPU memory
|
||||
4. **Sample Rate**: 16kHz recommended for best results
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
Possible improvements:
|
||||
- [ ] Add support for other audio formats (MP3, FLAC, etc.)
|
||||
- [ ] Implement beam search decoding
|
||||
- [ ] Add language selection option
|
||||
- [ ] Support for speaker diarization
|
||||
- [ ] REST API in addition to WebSocket
|
||||
- [ ] Docker containerization
|
||||
- [ ] Batch file processing script
|
||||
- [ ] Real-time visualization of transcription
|
||||
|
||||
## References
|
||||
|
||||
- [onnx-asr GitHub](https://github.com/istupakov/onnx-asr)
|
||||
- [onnx-asr Documentation](https://istupakov.github.io/onnx-asr/)
|
||||
- [Parakeet ONNX Model](https://huggingface.co/istupakov/parakeet-tdt-0.6b-v3-onnx)
|
||||
- [Original Parakeet Model](https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3)
|
||||
- [ONNX Runtime](https://onnxruntime.ai/)
|
||||
|
||||
## Support
|
||||
|
||||
For issues related to:
|
||||
- **onnx-asr library**: https://github.com/istupakov/onnx-asr/issues
|
||||
- **This implementation**: Check logs and run diagnose.py
|
||||
- **GPU/CUDA issues**: Verify nvidia-smi and CUDA installation
|
||||
|
||||
---
|
||||
|
||||
**Refactoring completed on**: January 18, 2026
|
||||
**Primary changes**: Migration to onnx-asr library for simplified ONNX inference with GPU support
|
||||
@@ -1,337 +0,0 @@
|
||||
# Remote Microphone Streaming Setup
|
||||
|
||||
This guide shows how to use the ASR system with a client on one machine streaming audio to a server on another machine.
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
┌─────────────────┐ ┌─────────────────┐
|
||||
│ Client Machine │ │ Server Machine │
|
||||
│ │ │ │
|
||||
│ 🎤 Microphone │ ───WebSocket───▶ │ 🖥️ Display │
|
||||
│ │ (Audio) │ │
|
||||
│ client/ │ │ server/ │
|
||||
│ mic_stream.py │ │ display_server │
|
||||
└─────────────────┘ └─────────────────┘
|
||||
```
|
||||
|
||||
## Server Setup (Machine with GPU)
|
||||
|
||||
### 1. Start the server with live display
|
||||
|
||||
```bash
|
||||
cd /home/koko210Serve/parakeet-test
|
||||
source venv/bin/activate
|
||||
PYTHONPATH=/home/koko210Serve/parakeet-test python server/display_server.py
|
||||
```
|
||||
|
||||
**Options:**
|
||||
```bash
|
||||
python server/display_server.py --host 0.0.0.0 --port 8766
|
||||
```
|
||||
|
||||
The server will:
|
||||
- ✅ Bind to all network interfaces (0.0.0.0)
|
||||
- ✅ Display transcriptions in real-time with color coding
|
||||
- ✅ Show progressive updates as audio streams in
|
||||
- ✅ Highlight final transcriptions when complete
|
||||
|
||||
### 2. Configure firewall (if needed)
|
||||
|
||||
Allow incoming connections on port 8766:
|
||||
```bash
|
||||
# Ubuntu/Debian
|
||||
sudo ufw allow 8766/tcp
|
||||
|
||||
# CentOS/RHEL
|
||||
sudo firewall-cmd --permanent --add-port=8766/tcp
|
||||
sudo firewall-cmd --reload
|
||||
```
|
||||
|
||||
### 3. Get the server's IP address
|
||||
|
||||
```bash
|
||||
# Find your server's IP address
|
||||
ip addr show | grep "inet " | grep -v 127.0.0.1
|
||||
```
|
||||
|
||||
Example output: `192.168.1.100`
|
||||
|
||||
## Client Setup (Remote Machine)
|
||||
|
||||
### 1. Install dependencies on client machine
|
||||
|
||||
Create a minimal Python environment:
|
||||
|
||||
```bash
|
||||
# Create virtual environment
|
||||
python3 -m venv asr-client
|
||||
source asr-client/bin/activate
|
||||
|
||||
# Install only client dependencies
|
||||
pip install websockets sounddevice numpy
|
||||
```
|
||||
|
||||
### 2. Copy the client script
|
||||
|
||||
Copy `client/mic_stream.py` to your client machine:
|
||||
|
||||
```bash
|
||||
# On server machine
|
||||
scp client/mic_stream.py user@client-machine:~/
|
||||
|
||||
# Or download it via your preferred method
|
||||
```
|
||||
|
||||
### 3. List available microphones
|
||||
|
||||
```bash
|
||||
python mic_stream.py --list-devices
|
||||
```
|
||||
|
||||
Example output:
|
||||
```
|
||||
Available audio input devices:
|
||||
--------------------------------------------------------------------------------
|
||||
[0] Built-in Microphone
|
||||
Channels: 2
|
||||
Sample rate: 44100.0 Hz
|
||||
[1] USB Microphone
|
||||
Channels: 1
|
||||
Sample rate: 48000.0 Hz
|
||||
--------------------------------------------------------------------------------
|
||||
```
|
||||
|
||||
### 4. Start streaming
|
||||
|
||||
```bash
|
||||
python mic_stream.py --url ws://SERVER_IP:8766
|
||||
```
|
||||
|
||||
Replace `SERVER_IP` with your server's IP address (e.g., `ws://192.168.1.100:8766`)
|
||||
|
||||
**Options:**
|
||||
```bash
|
||||
# Use specific microphone device
|
||||
python mic_stream.py --url ws://192.168.1.100:8766 --device 1
|
||||
|
||||
# Change sample rate (if needed)
|
||||
python mic_stream.py --url ws://192.168.1.100:8766 --sample-rate 16000
|
||||
|
||||
# Adjust chunk size for network latency
|
||||
python mic_stream.py --url ws://192.168.1.100:8766 --chunk-duration 0.2
|
||||
```
|
||||
|
||||
## Usage Flow
|
||||
|
||||
### 1. Start Server
|
||||
On the server machine:
|
||||
```bash
|
||||
cd /home/koko210Serve/parakeet-test
|
||||
source venv/bin/activate
|
||||
PYTHONPATH=/home/koko210Serve/parakeet-test python server/display_server.py
|
||||
```
|
||||
|
||||
You'll see:
|
||||
```
|
||||
================================================================================
|
||||
ASR Server - Live Transcription Display
|
||||
================================================================================
|
||||
Server: ws://0.0.0.0:8766
|
||||
Sample Rate: 16000 Hz
|
||||
Model: Parakeet TDT 0.6B V3
|
||||
================================================================================
|
||||
|
||||
Server is running and ready for connections!
|
||||
Waiting for clients...
|
||||
```
|
||||
|
||||
### 2. Connect Client
|
||||
On the client machine:
|
||||
```bash
|
||||
python mic_stream.py --url ws://192.168.1.100:8766
|
||||
```
|
||||
|
||||
You'll see:
|
||||
```
|
||||
Connected to server: ws://192.168.1.100:8766
|
||||
Recording started. Press Ctrl+C to stop.
|
||||
```
|
||||
|
||||
### 3. Speak into Microphone
|
||||
- Speak naturally into your microphone
|
||||
- Watch the **server terminal** for real-time transcriptions
|
||||
- Progressive updates appear in yellow as you speak
|
||||
- Final transcriptions appear in green when you pause
|
||||
|
||||
### 4. Stop Streaming
|
||||
Press `Ctrl+C` on the client to stop recording and disconnect.
|
||||
|
||||
## Display Color Coding
|
||||
|
||||
On the server display:
|
||||
|
||||
- **🟢 GREEN** = Final transcription (complete, accurate)
|
||||
- **🟡 YELLOW** = Progressive update (in progress)
|
||||
- **🔵 BLUE** = Connection events
|
||||
- **⚪ WHITE** = Server status messages
|
||||
|
||||
## Example Session
|
||||
|
||||
### Server Display:
|
||||
```
|
||||
================================================================================
|
||||
✓ Client connected: 192.168.1.50:45232
|
||||
================================================================================
|
||||
|
||||
[14:23:15] 192.168.1.50:45232
|
||||
→ Hello this is
|
||||
|
||||
[14:23:17] 192.168.1.50:45232
|
||||
→ Hello this is a test of the remote
|
||||
|
||||
[14:23:19] 192.168.1.50:45232
|
||||
✓ FINAL: Hello this is a test of the remote microphone streaming system.
|
||||
|
||||
[14:23:25] 192.168.1.50:45232
|
||||
→ Can you hear me
|
||||
|
||||
[14:23:27] 192.168.1.50:45232
|
||||
✓ FINAL: Can you hear me clearly?
|
||||
|
||||
================================================================================
|
||||
✗ Client disconnected: 192.168.1.50:45232
|
||||
================================================================================
|
||||
```
|
||||
|
||||
### Client Display:
|
||||
```
|
||||
Connected to server: ws://192.168.1.100:8766
|
||||
Recording started. Press Ctrl+C to stop.
|
||||
|
||||
Server: Connected to ASR server with live display
|
||||
[PARTIAL] Hello this is
|
||||
[PARTIAL] Hello this is a test of the remote
|
||||
[FINAL] Hello this is a test of the remote microphone streaming system.
|
||||
[PARTIAL] Can you hear me
|
||||
[FINAL] Can you hear me clearly?
|
||||
|
||||
^C
|
||||
Stopped by user
|
||||
Disconnected from server
|
||||
Client stopped by user
|
||||
```
|
||||
|
||||
## Network Considerations
|
||||
|
||||
### Bandwidth Usage
|
||||
- Sample rate: 16000 Hz
|
||||
- Bit depth: 16-bit (int16)
|
||||
- Bandwidth: ~32 KB/s per client
|
||||
- Very low bandwidth - works well over WiFi or LAN
|
||||
|
||||
### Latency
|
||||
- Progressive updates: Every ~2 seconds
|
||||
- Final transcription: When audio stops or on demand
|
||||
- Total latency: ~2-3 seconds (network + processing)
|
||||
|
||||
### Multiple Clients
|
||||
The server supports multiple simultaneous clients:
|
||||
- Each client gets its own session
|
||||
- Transcriptions are tagged with client IP:port
|
||||
- No interference between clients
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Client Can't Connect
|
||||
```
|
||||
Error: [Errno 111] Connection refused
|
||||
```
|
||||
**Solution:**
|
||||
1. Check server is running
|
||||
2. Verify firewall allows port 8766
|
||||
3. Confirm server IP address is correct
|
||||
4. Test connectivity: `ping SERVER_IP`
|
||||
|
||||
### No Audio Being Captured
|
||||
```
|
||||
Recording started but no transcriptions appear
|
||||
```
|
||||
**Solution:**
|
||||
1. Check microphone permissions
|
||||
2. List devices: `python mic_stream.py --list-devices`
|
||||
3. Try different device: `--device N`
|
||||
4. Test microphone in other apps first
|
||||
|
||||
### Poor Transcription Quality
|
||||
**Solution:**
|
||||
1. Move closer to microphone
|
||||
2. Reduce background noise
|
||||
3. Speak clearly and at normal pace
|
||||
4. Check microphone quality/settings
|
||||
|
||||
### High Latency
|
||||
**Solution:**
|
||||
1. Use wired connection instead of WiFi
|
||||
2. Reduce chunk duration: `--chunk-duration 0.05`
|
||||
3. Check network latency: `ping SERVER_IP`
|
||||
|
||||
## Security Notes
|
||||
|
||||
⚠️ **Important:** This setup uses WebSocket without encryption (ws://)
|
||||
|
||||
For production use:
|
||||
- Use WSS (WebSocket Secure) with TLS certificates
|
||||
- Add authentication (API keys, tokens)
|
||||
- Restrict firewall rules to specific IP ranges
|
||||
- Consider using VPN for remote access
|
||||
|
||||
## Advanced: Auto-start Server
|
||||
|
||||
Create a systemd service (Linux):
|
||||
|
||||
```bash
|
||||
sudo nano /etc/systemd/system/asr-server.service
|
||||
```
|
||||
|
||||
```ini
|
||||
[Unit]
|
||||
Description=ASR WebSocket Server
|
||||
After=network.target
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
User=YOUR_USERNAME
|
||||
WorkingDirectory=/home/koko210Serve/parakeet-test
|
||||
Environment="PYTHONPATH=/home/koko210Serve/parakeet-test"
|
||||
ExecStart=/home/koko210Serve/parakeet-test/venv/bin/python server/display_server.py
|
||||
Restart=always
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
```
|
||||
|
||||
Enable and start:
|
||||
```bash
|
||||
sudo systemctl enable asr-server
|
||||
sudo systemctl start asr-server
|
||||
sudo systemctl status asr-server
|
||||
```
|
||||
|
||||
## Performance Tips
|
||||
|
||||
1. **Server:** Use GPU for best performance (~100ms latency)
|
||||
2. **Client:** Use low chunk duration for responsiveness (0.1s default)
|
||||
3. **Network:** Wired connection preferred, WiFi works fine
|
||||
4. **Audio Quality:** 16kHz sample rate is optimal for speech
|
||||
|
||||
## Summary
|
||||
|
||||
✅ **Server displays transcriptions in real-time**
|
||||
✅ **Client sends audio from remote microphone**
|
||||
✅ **Progressive updates show live transcription**
|
||||
✅ **Final results when speech pauses**
|
||||
✅ **Multiple clients supported**
|
||||
✅ **Low bandwidth, low latency**
|
||||
|
||||
Enjoy your remote ASR streaming system! 🎤 → 🌐 → 🖥️
|
||||
@@ -1,155 +0,0 @@
|
||||
# Parakeet ASR - Setup Complete! ✅
|
||||
|
||||
## Summary
|
||||
|
||||
Successfully set up Parakeet ASR with ONNX Runtime and GPU support on your GTX 1660!
|
||||
|
||||
## What Was Done
|
||||
|
||||
### 1. Fixed Python Version
|
||||
- Removed Python 3.14 virtual environment
|
||||
- Created new venv with Python 3.11.14 (compatible with onnxruntime-gpu)
|
||||
|
||||
### 2. Installed Dependencies
|
||||
- `onnx-asr[gpu,hub]` - Main ASR library
|
||||
- `onnxruntime-gpu` 1.23.2 - GPU-accelerated inference
|
||||
- `numpy<2.0` - Numerical computing
|
||||
- `websockets` - WebSocket support
|
||||
- `sounddevice` - Audio capture
|
||||
- `soundfile` - Audio file I/O
|
||||
- CUDA 12 libraries via pip (nvidia-cublas-cu12, nvidia-cudnn-cu12)
|
||||
|
||||
### 3. Downloaded Model Files
|
||||
All model files (~2.4GB) downloaded from HuggingFace:
|
||||
- `encoder-model.onnx` (40MB)
|
||||
- `encoder-model.onnx.data` (2.3GB)
|
||||
- `decoder_joint-model.onnx` (70MB)
|
||||
- `config.json`
|
||||
- `vocab.txt`
|
||||
- `nemo128.onnx`
|
||||
|
||||
### 4. Tested Successfully
|
||||
✅ Offline transcription working with GPU
|
||||
✅ Model: Parakeet TDT 0.6B V3 (Multilingual)
|
||||
✅ GPU Memory Usage: ~1.3GB
|
||||
✅ Tested on test.wav - Perfect transcription!
|
||||
|
||||
## How to Use
|
||||
|
||||
### Quick Test
|
||||
```bash
|
||||
./run.sh tools/test_offline.py test.wav
|
||||
```
|
||||
|
||||
### With VAD (for long files)
|
||||
```bash
|
||||
./run.sh tools/test_offline.py your_audio.wav --use-vad
|
||||
```
|
||||
|
||||
### With Quantization (faster)
|
||||
```bash
|
||||
./run.sh tools/test_offline.py your_audio.wav --quantization int8
|
||||
```
|
||||
|
||||
### Start Server
|
||||
```bash
|
||||
./run.sh server/ws_server.py
|
||||
```
|
||||
|
||||
### Start Microphone Client
|
||||
```bash
|
||||
./run.sh client/mic_stream.py
|
||||
```
|
||||
|
||||
### List Audio Devices
|
||||
```bash
|
||||
./run.sh client/mic_stream.py --list-devices
|
||||
```
|
||||
|
||||
## System Info
|
||||
|
||||
- **Python**: 3.11.14
|
||||
- **GPU**: NVIDIA GeForce GTX 1660 (6GB)
|
||||
- **CUDA**: 13.1 (using CUDA 12 compatibility libs)
|
||||
- **ONNX Runtime**: 1.23.2 with GPU support
|
||||
- **Model**: nemo-parakeet-tdt-0.6b-v3 (Multilingual, 25+ languages)
|
||||
|
||||
## GPU Status
|
||||
|
||||
The GPU is working! ONNX Runtime is using:
|
||||
- CUDAExecutionProvider ✅
|
||||
- TensorrtExecutionProvider ✅
|
||||
- CPUExecutionProvider (fallback)
|
||||
|
||||
Current GPU usage: ~1.3GB during inference
|
||||
|
||||
## Performance
|
||||
|
||||
With GPU acceleration on GTX 1660:
|
||||
- **Offline**: ~50-100x realtime
|
||||
- **Latency**: <100ms for streaming
|
||||
- **Memory**: 2-3GB GPU RAM
|
||||
|
||||
## Files Structure
|
||||
|
||||
```
|
||||
parakeet-test/
|
||||
├── run.sh ← Use this to run scripts!
|
||||
├── asr/ ← ASR pipeline
|
||||
├── client/ ← Microphone client
|
||||
├── server/ ← WebSocket server
|
||||
├── tools/ ← Testing tools
|
||||
├── venv/ ← Python 3.11 environment
|
||||
└── models/parakeet/ ← Downloaded model files
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
- Use `./run.sh` to run any Python script (sets up CUDA paths automatically)
|
||||
- Model supports 25+ languages (auto-detected)
|
||||
- For best performance, use 16kHz mono WAV files
|
||||
- GPU is working despite CUDA version difference (13.1 vs 12)
|
||||
|
||||
## Next Steps
|
||||
|
||||
Want to do more?
|
||||
|
||||
1. **Test streaming**:
|
||||
```bash
|
||||
# Terminal 1
|
||||
./run.sh server/ws_server.py
|
||||
|
||||
# Terminal 2
|
||||
./run.sh client/mic_stream.py
|
||||
```
|
||||
|
||||
2. **Try quantization** for 30% speed boost:
|
||||
```bash
|
||||
./run.sh tools/test_offline.py audio.wav --quantization int8
|
||||
```
|
||||
|
||||
3. **Process multiple files**:
|
||||
```bash
|
||||
for file in *.wav; do
|
||||
./run.sh tools/test_offline.py "$file"
|
||||
done
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
If GPU stops working:
|
||||
```bash
|
||||
# Check GPU
|
||||
nvidia-smi
|
||||
|
||||
# Verify ONNX providers
|
||||
./run.sh -c "import onnxruntime as ort; print(ort.get_available_providers())"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
**Status**: ✅ WORKING PERFECTLY
|
||||
**GPU**: ✅ ACTIVE
|
||||
**Performance**: ✅ EXCELLENT
|
||||
|
||||
Enjoy your GPU-accelerated speech recognition! 🚀
|
||||
@@ -1,6 +0,0 @@
|
||||
"""
|
||||
ASR module using onnx-asr library
|
||||
"""
|
||||
from .asr_pipeline import ASRPipeline, load_pipeline
|
||||
|
||||
__all__ = ["ASRPipeline", "load_pipeline"]
|
||||
@@ -1,162 +0,0 @@
|
||||
"""
|
||||
ASR Pipeline using onnx-asr library with Parakeet TDT 0.6B V3 model
|
||||
"""
|
||||
import numpy as np
|
||||
import onnx_asr
|
||||
from typing import Union, Optional
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ASRPipeline:
|
||||
"""
|
||||
ASR Pipeline wrapper for onnx-asr Parakeet TDT model.
|
||||
Supports GPU acceleration via ONNX Runtime with CUDA/TensorRT.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = "nemo-parakeet-tdt-0.6b-v3",
|
||||
model_path: Optional[str] = None,
|
||||
quantization: Optional[str] = None,
|
||||
providers: Optional[list] = None,
|
||||
use_vad: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize ASR Pipeline.
|
||||
|
||||
Args:
|
||||
model_name: Name of the model to load (default: "nemo-parakeet-tdt-0.6b-v3")
|
||||
model_path: Optional local path to model files (default uses models/parakeet)
|
||||
quantization: Optional quantization ("int8", "fp16", etc.)
|
||||
providers: Optional ONNX runtime providers list for GPU acceleration
|
||||
use_vad: Whether to use Voice Activity Detection
|
||||
"""
|
||||
self.model_name = model_name
|
||||
self.model_path = model_path or "models/parakeet"
|
||||
self.quantization = quantization
|
||||
self.use_vad = use_vad
|
||||
|
||||
# Configure providers for GPU acceleration
|
||||
if providers is None:
|
||||
# Default: try CUDA, then CPU
|
||||
providers = [
|
||||
(
|
||||
"CUDAExecutionProvider",
|
||||
{
|
||||
"device_id": 0,
|
||||
"arena_extend_strategy": "kNextPowerOfTwo",
|
||||
"gpu_mem_limit": 6 * 1024 * 1024 * 1024, # 6GB
|
||||
"cudnn_conv_algo_search": "EXHAUSTIVE",
|
||||
"do_copy_in_default_stream": True,
|
||||
}
|
||||
),
|
||||
"CPUExecutionProvider",
|
||||
]
|
||||
|
||||
self.providers = providers
|
||||
logger.info(f"Initializing ASR Pipeline with model: {model_name}")
|
||||
logger.info(f"Model path: {self.model_path}")
|
||||
logger.info(f"Quantization: {quantization}")
|
||||
logger.info(f"Providers: {providers}")
|
||||
|
||||
# Load the model
|
||||
try:
|
||||
self.model = onnx_asr.load_model(
|
||||
model_name,
|
||||
self.model_path,
|
||||
quantization=quantization,
|
||||
providers=providers,
|
||||
)
|
||||
logger.info("Model loaded successfully")
|
||||
|
||||
# Optionally add VAD
|
||||
if use_vad:
|
||||
logger.info("Loading VAD model...")
|
||||
vad = onnx_asr.load_vad("silero", providers=providers)
|
||||
self.model = self.model.with_vad(vad)
|
||||
logger.info("VAD enabled")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load model: {e}")
|
||||
raise
|
||||
|
||||
def transcribe(
|
||||
self,
|
||||
audio: Union[str, np.ndarray],
|
||||
sample_rate: int = 16000,
|
||||
) -> Union[str, list]:
|
||||
"""
|
||||
Transcribe audio to text.
|
||||
|
||||
Args:
|
||||
audio: Audio data as numpy array (float32) or path to WAV file
|
||||
sample_rate: Sample rate of audio (default: 16000 Hz)
|
||||
|
||||
Returns:
|
||||
Transcribed text string, or list of results if VAD is enabled
|
||||
"""
|
||||
try:
|
||||
if isinstance(audio, str):
|
||||
# Load from file
|
||||
result = self.model.recognize(audio)
|
||||
else:
|
||||
# Process numpy array
|
||||
if audio.dtype != np.float32:
|
||||
audio = audio.astype(np.float32)
|
||||
result = self.model.recognize(audio, sample_rate=sample_rate)
|
||||
|
||||
# If VAD is enabled, result is a generator
|
||||
if self.use_vad:
|
||||
return list(result)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Transcription failed: {e}")
|
||||
raise
|
||||
|
||||
def transcribe_batch(
|
||||
self,
|
||||
audio_files: list,
|
||||
) -> list:
|
||||
"""
|
||||
Transcribe multiple audio files in batch.
|
||||
|
||||
Args:
|
||||
audio_files: List of paths to WAV files
|
||||
|
||||
Returns:
|
||||
List of transcribed text strings
|
||||
"""
|
||||
try:
|
||||
results = self.model.recognize(audio_files)
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.error(f"Batch transcription failed: {e}")
|
||||
raise
|
||||
|
||||
def transcribe_stream(
|
||||
self,
|
||||
audio_chunk: np.ndarray,
|
||||
sample_rate: int = 16000,
|
||||
) -> str:
|
||||
"""
|
||||
Transcribe streaming audio chunk.
|
||||
|
||||
Args:
|
||||
audio_chunk: Audio chunk as numpy array (float32)
|
||||
sample_rate: Sample rate of audio
|
||||
|
||||
Returns:
|
||||
Transcribed text for the chunk
|
||||
"""
|
||||
return self.transcribe(audio_chunk, sample_rate=sample_rate)
|
||||
|
||||
|
||||
# Convenience function for backward compatibility
|
||||
def load_pipeline(**kwargs) -> ASRPipeline:
|
||||
"""Load and return ASR pipeline with given configuration."""
|
||||
return ASRPipeline(**kwargs)
|
||||
@@ -1,6 +0,0 @@
|
||||
"""
|
||||
Client module for microphone streaming
|
||||
"""
|
||||
from .mic_stream import MicrophoneStreamClient, list_audio_devices
|
||||
|
||||
__all__ = ["MicrophoneStreamClient", "list_audio_devices"]
|
||||
@@ -1,235 +0,0 @@
|
||||
"""
|
||||
Microphone streaming client for ASR WebSocket server
|
||||
"""
|
||||
import asyncio
|
||||
import websockets
|
||||
import sounddevice as sd
|
||||
import numpy as np
|
||||
import json
|
||||
import logging
|
||||
import queue
|
||||
from typing import Optional
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MicrophoneStreamClient:
|
||||
"""
|
||||
Client for streaming microphone audio to ASR WebSocket server.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_url: str = "ws://localhost:8766",
|
||||
sample_rate: int = 16000,
|
||||
channels: int = 1,
|
||||
chunk_duration: float = 0.1, # seconds
|
||||
device: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
Initialize microphone streaming client.
|
||||
|
||||
Args:
|
||||
server_url: WebSocket server URL
|
||||
sample_rate: Audio sample rate (16000 Hz recommended)
|
||||
channels: Number of audio channels (1 for mono)
|
||||
chunk_duration: Duration of each audio chunk in seconds
|
||||
device: Optional audio input device index
|
||||
"""
|
||||
self.server_url = server_url
|
||||
self.sample_rate = sample_rate
|
||||
self.channels = channels
|
||||
self.chunk_duration = chunk_duration
|
||||
self.chunk_samples = int(sample_rate * chunk_duration)
|
||||
self.device = device
|
||||
|
||||
self.audio_queue = queue.Queue()
|
||||
self.is_recording = False
|
||||
self.websocket = None
|
||||
|
||||
logger.info(f"Microphone client initialized")
|
||||
logger.info(f"Server URL: {server_url}")
|
||||
logger.info(f"Sample rate: {sample_rate} Hz")
|
||||
logger.info(f"Chunk duration: {chunk_duration}s ({self.chunk_samples} samples)")
|
||||
|
||||
def audio_callback(self, indata, frames, time_info, status):
|
||||
"""
|
||||
Callback for sounddevice stream.
|
||||
|
||||
Args:
|
||||
indata: Input audio data
|
||||
frames: Number of frames
|
||||
time_info: Timing information
|
||||
status: Status flags
|
||||
"""
|
||||
if status:
|
||||
logger.warning(f"Audio callback status: {status}")
|
||||
|
||||
# Convert to int16 and put in queue
|
||||
audio_data = (indata[:, 0] * 32767).astype(np.int16)
|
||||
self.audio_queue.put(audio_data.tobytes())
|
||||
|
||||
async def send_audio(self):
|
||||
"""
|
||||
Coroutine to send audio from queue to WebSocket.
|
||||
"""
|
||||
while self.is_recording:
|
||||
try:
|
||||
# Get audio data from queue (non-blocking)
|
||||
audio_bytes = self.audio_queue.get_nowait()
|
||||
|
||||
if self.websocket:
|
||||
await self.websocket.send(audio_bytes)
|
||||
|
||||
except queue.Empty:
|
||||
# No audio data available, wait a bit
|
||||
await asyncio.sleep(0.01)
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending audio: {e}")
|
||||
break
|
||||
|
||||
async def receive_transcripts(self):
|
||||
"""
|
||||
Coroutine to receive transcripts from WebSocket.
|
||||
"""
|
||||
while self.is_recording:
|
||||
try:
|
||||
if self.websocket:
|
||||
message = await asyncio.wait_for(
|
||||
self.websocket.recv(),
|
||||
timeout=0.1
|
||||
)
|
||||
|
||||
try:
|
||||
data = json.loads(message)
|
||||
|
||||
if data.get("type") == "transcript":
|
||||
text = data.get("text", "")
|
||||
is_final = data.get("is_final", False)
|
||||
|
||||
if is_final:
|
||||
logger.info(f"[FINAL] {text}")
|
||||
else:
|
||||
logger.info(f"[PARTIAL] {text}")
|
||||
|
||||
elif data.get("type") == "info":
|
||||
logger.info(f"Server: {data.get('message')}")
|
||||
|
||||
elif data.get("type") == "error":
|
||||
logger.error(f"Server error: {data.get('message')}")
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Invalid JSON response: {message}")
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"Error receiving transcript: {e}")
|
||||
break
|
||||
|
||||
async def stream_audio(self):
|
||||
"""
|
||||
Main coroutine to stream audio to server.
|
||||
"""
|
||||
try:
|
||||
async with websockets.connect(self.server_url) as websocket:
|
||||
self.websocket = websocket
|
||||
logger.info(f"Connected to server: {self.server_url}")
|
||||
|
||||
self.is_recording = True
|
||||
|
||||
# Start audio stream
|
||||
with sd.InputStream(
|
||||
samplerate=self.sample_rate,
|
||||
channels=self.channels,
|
||||
dtype=np.float32,
|
||||
blocksize=self.chunk_samples,
|
||||
device=self.device,
|
||||
callback=self.audio_callback,
|
||||
):
|
||||
logger.info("Recording started. Press Ctrl+C to stop.")
|
||||
|
||||
# Run send and receive coroutines concurrently
|
||||
await asyncio.gather(
|
||||
self.send_audio(),
|
||||
self.receive_transcripts(),
|
||||
)
|
||||
|
||||
except websockets.exceptions.WebSocketException as e:
|
||||
logger.error(f"WebSocket error: {e}")
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Stopped by user")
|
||||
finally:
|
||||
self.is_recording = False
|
||||
|
||||
# Send final command
|
||||
if self.websocket:
|
||||
try:
|
||||
await self.websocket.send(json.dumps({"type": "final"}))
|
||||
await asyncio.sleep(0.5) # Wait for final response
|
||||
except:
|
||||
pass
|
||||
|
||||
self.websocket = None
|
||||
logger.info("Disconnected from server")
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
Run the client (blocking).
|
||||
"""
|
||||
try:
|
||||
asyncio.run(self.stream_audio())
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Client stopped by user")
|
||||
|
||||
|
||||
def list_audio_devices():
|
||||
"""
|
||||
List available audio input devices.
|
||||
"""
|
||||
print("\nAvailable audio input devices:")
|
||||
print("-" * 80)
|
||||
devices = sd.query_devices()
|
||||
for i, device in enumerate(devices):
|
||||
if device['max_input_channels'] > 0:
|
||||
print(f"[{i}] {device['name']}")
|
||||
print(f" Channels: {device['max_input_channels']}")
|
||||
print(f" Sample rate: {device['default_samplerate']} Hz")
|
||||
print("-" * 80)
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Main entry point for the microphone client.
|
||||
"""
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Microphone Streaming Client")
|
||||
parser.add_argument("--url", default="ws://localhost:8766", help="WebSocket server URL")
|
||||
parser.add_argument("--sample-rate", type=int, default=16000, help="Audio sample rate")
|
||||
parser.add_argument("--device", type=int, default=None, help="Audio input device index")
|
||||
parser.add_argument("--list-devices", action="store_true", help="List audio devices and exit")
|
||||
parser.add_argument("--chunk-duration", type=float, default=0.1, help="Audio chunk duration (seconds)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.list_devices:
|
||||
list_audio_devices()
|
||||
return
|
||||
|
||||
client = MicrophoneStreamClient(
|
||||
server_url=args.url,
|
||||
sample_rate=args.sample_rate,
|
||||
device=args.device,
|
||||
chunk_duration=args.chunk_duration,
|
||||
)
|
||||
|
||||
client.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,15 +0,0 @@
|
||||
"""
|
||||
Simple example of using the ASR pipeline
|
||||
"""
|
||||
from asr.asr_pipeline import ASRPipeline
|
||||
|
||||
# Initialize pipeline (will download model on first run)
|
||||
print("Loading ASR model...")
|
||||
pipeline = ASRPipeline()
|
||||
|
||||
# Transcribe a WAV file
|
||||
print("\nTranscribing audio...")
|
||||
text = pipeline.transcribe("test.wav")
|
||||
|
||||
print("\nTranscription:")
|
||||
print(text)
|
||||
@@ -1,54 +0,0 @@
|
||||
# Parakeet ASR WebSocket Server - Strict Requirements
|
||||
# Python version: 3.11.14
|
||||
# pip version: 25.3
|
||||
#
|
||||
# Installation:
|
||||
# python3.11 -m venv venv
|
||||
# source venv/bin/activate
|
||||
# pip install --upgrade pip==25.3
|
||||
# pip install -r requirements-stt.txt
|
||||
#
|
||||
# System requirements:
|
||||
# - CUDA 12.x compatible GPU (optional, for GPU acceleration)
|
||||
# - Linux (tested on Arch Linux)
|
||||
# - ~6GB VRAM for GPU inference
|
||||
#
|
||||
# Generated: 2026-01-18
|
||||
|
||||
anyio==4.12.1
|
||||
certifi==2026.1.4
|
||||
cffi==2.0.0
|
||||
click==8.3.1
|
||||
coloredlogs==15.0.1
|
||||
filelock==3.20.3
|
||||
flatbuffers==25.12.19
|
||||
fsspec==2026.1.0
|
||||
h11==0.16.0
|
||||
hf-xet==1.2.0
|
||||
httpcore==1.0.9
|
||||
httpx==0.28.1
|
||||
huggingface_hub==1.3.2
|
||||
humanfriendly==10.0
|
||||
idna==3.11
|
||||
mpmath==1.3.0
|
||||
numpy==1.26.4
|
||||
nvidia-cublas-cu12==12.9.1.4
|
||||
nvidia-cuda-nvrtc-cu12==12.9.86
|
||||
nvidia-cuda-runtime-cu12==12.9.79
|
||||
nvidia-cudnn-cu12==9.18.0.77
|
||||
nvidia-cufft-cu12==11.4.1.4
|
||||
nvidia-nvjitlink-cu12==12.9.86
|
||||
onnx-asr==0.10.1
|
||||
onnxruntime-gpu==1.23.2
|
||||
packaging==25.0
|
||||
protobuf==6.33.4
|
||||
pycparser==2.23
|
||||
PyYAML==6.0.3
|
||||
shellingham==1.5.4
|
||||
sounddevice==0.5.3
|
||||
soundfile==0.13.1
|
||||
sympy==1.14.0
|
||||
tqdm==4.67.1
|
||||
typer-slim==0.21.1
|
||||
typing_extensions==4.15.0
|
||||
websockets==16.0
|
||||
@@ -1,12 +0,0 @@
|
||||
#!/bin/bash
|
||||
# Wrapper script to run Python with proper environment
|
||||
|
||||
# Set up library paths for CUDA
|
||||
VENV_DIR="/home/koko210Serve/parakeet-test/venv/lib/python3.11/site-packages"
|
||||
export LD_LIBRARY_PATH="${VENV_DIR}/nvidia/cublas/lib:${VENV_DIR}/nvidia/cudnn/lib:${VENV_DIR}/nvidia/cufft/lib:${VENV_DIR}/nvidia/cuda_nvrtc/lib:${VENV_DIR}/nvidia/cuda_runtime/lib:$LD_LIBRARY_PATH"
|
||||
|
||||
# Set Python path
|
||||
export PYTHONPATH="/home/koko210Serve/parakeet-test:$PYTHONPATH"
|
||||
|
||||
# Run Python with arguments
|
||||
exec /home/koko210Serve/parakeet-test/venv/bin/python "$@"
|
||||
@@ -1,6 +0,0 @@
|
||||
"""
|
||||
WebSocket server module for streaming ASR
|
||||
"""
|
||||
from .ws_server import ASRWebSocketServer
|
||||
|
||||
__all__ = ["ASRWebSocketServer"]
|
||||
@@ -1,292 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
ASR WebSocket Server with Live Transcription Display
|
||||
|
||||
This version displays transcriptions in real-time on the server console
|
||||
while clients stream audio from remote machines.
|
||||
"""
|
||||
import asyncio
|
||||
import websockets
|
||||
import numpy as np
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from asr.asr_pipeline import ASRPipeline
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.FileHandler('display_server.log'),
|
||||
logging.StreamHandler()
|
||||
]
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DisplayServer:
|
||||
"""
|
||||
WebSocket server with live transcription display.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str = "0.0.0.0",
|
||||
port: int = 8766,
|
||||
model_path: str = "models/parakeet",
|
||||
sample_rate: int = 16000,
|
||||
):
|
||||
"""
|
||||
Initialize server.
|
||||
|
||||
Args:
|
||||
host: Host address to bind to
|
||||
port: Port to bind to
|
||||
model_path: Directory containing model files
|
||||
sample_rate: Audio sample rate
|
||||
"""
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.sample_rate = sample_rate
|
||||
self.active_connections = set()
|
||||
|
||||
# Terminal control codes
|
||||
self.CLEAR_LINE = '\033[2K'
|
||||
self.CURSOR_UP = '\033[1A'
|
||||
self.BOLD = '\033[1m'
|
||||
self.GREEN = '\033[92m'
|
||||
self.YELLOW = '\033[93m'
|
||||
self.BLUE = '\033[94m'
|
||||
self.RESET = '\033[0m'
|
||||
|
||||
# Initialize ASR pipeline
|
||||
logger.info("Loading ASR model...")
|
||||
self.pipeline = ASRPipeline(model_path=model_path)
|
||||
logger.info("ASR Pipeline ready")
|
||||
|
||||
# Client sessions
|
||||
self.sessions = {}
|
||||
|
||||
def print_header(self):
|
||||
"""Print server header."""
|
||||
print("\n" + "=" * 80)
|
||||
print(f"{self.BOLD}{self.BLUE}ASR Server - Live Transcription Display{self.RESET}")
|
||||
print("=" * 80)
|
||||
print(f"Server: ws://{self.host}:{self.port}")
|
||||
print(f"Sample Rate: {self.sample_rate} Hz")
|
||||
print(f"Model: Parakeet TDT 0.6B V3")
|
||||
print("=" * 80 + "\n")
|
||||
|
||||
def display_transcription(self, client_id: str, text: str, is_final: bool, is_progressive: bool = False):
|
||||
"""
|
||||
Display transcription in the terminal.
|
||||
|
||||
Args:
|
||||
client_id: Client identifier
|
||||
text: Transcribed text
|
||||
is_final: Whether this is the final transcription
|
||||
is_progressive: Whether this is a progressive update
|
||||
"""
|
||||
timestamp = datetime.now().strftime("%H:%M:%S")
|
||||
|
||||
if is_final:
|
||||
# Final transcription - bold green
|
||||
print(f"{self.GREEN}{self.BOLD}[{timestamp}] {client_id}{self.RESET}")
|
||||
print(f"{self.GREEN} ✓ FINAL: {text}{self.RESET}\n")
|
||||
elif is_progressive:
|
||||
# Progressive update - yellow
|
||||
print(f"{self.YELLOW}[{timestamp}] {client_id}{self.RESET}")
|
||||
print(f"{self.YELLOW} → {text}{self.RESET}\n")
|
||||
else:
|
||||
# Regular transcription
|
||||
print(f"{self.BLUE}[{timestamp}] {client_id}{self.RESET}")
|
||||
print(f" {text}\n")
|
||||
|
||||
# Flush to ensure immediate display
|
||||
sys.stdout.flush()
|
||||
|
||||
async def handle_client(self, websocket):
|
||||
"""
|
||||
Handle individual WebSocket client connection.
|
||||
|
||||
Args:
|
||||
websocket: WebSocket connection
|
||||
"""
|
||||
client_id = f"{websocket.remote_address[0]}:{websocket.remote_address[1]}"
|
||||
logger.info(f"Client connected: {client_id}")
|
||||
self.active_connections.add(websocket)
|
||||
|
||||
# Display connection
|
||||
print(f"\n{self.BOLD}{'='*80}{self.RESET}")
|
||||
print(f"{self.GREEN}✓ Client connected: {client_id}{self.RESET}")
|
||||
print(f"{self.BOLD}{'='*80}{self.RESET}\n")
|
||||
sys.stdout.flush()
|
||||
|
||||
# Audio buffer for accumulating ALL audio
|
||||
all_audio = []
|
||||
last_transcribed_samples = 0
|
||||
|
||||
# For progressive transcription
|
||||
min_chunk_duration = 2.0 # Minimum 2 seconds before transcribing
|
||||
min_chunk_samples = int(self.sample_rate * min_chunk_duration)
|
||||
|
||||
try:
|
||||
# Send welcome message
|
||||
await websocket.send(json.dumps({
|
||||
"type": "info",
|
||||
"message": "Connected to ASR server with live display",
|
||||
"sample_rate": self.sample_rate,
|
||||
}))
|
||||
|
||||
async for message in websocket:
|
||||
try:
|
||||
if isinstance(message, bytes):
|
||||
# Binary audio data
|
||||
audio_data = np.frombuffer(message, dtype=np.int16)
|
||||
audio_data = audio_data.astype(np.float32) / 32768.0
|
||||
|
||||
# Accumulate all audio
|
||||
all_audio.append(audio_data)
|
||||
total_samples = sum(len(chunk) for chunk in all_audio)
|
||||
|
||||
# Transcribe periodically when we have enough NEW audio
|
||||
samples_since_last = total_samples - last_transcribed_samples
|
||||
if samples_since_last >= min_chunk_samples:
|
||||
audio_chunk = np.concatenate(all_audio)
|
||||
last_transcribed_samples = total_samples
|
||||
|
||||
# Transcribe the accumulated audio
|
||||
try:
|
||||
text = self.pipeline.transcribe(
|
||||
audio_chunk,
|
||||
sample_rate=self.sample_rate
|
||||
)
|
||||
|
||||
if text and text.strip():
|
||||
# Display on server
|
||||
self.display_transcription(client_id, text, is_final=False, is_progressive=True)
|
||||
|
||||
# Send to client
|
||||
response = {
|
||||
"type": "transcript",
|
||||
"text": text,
|
||||
"is_final": False,
|
||||
}
|
||||
await websocket.send(json.dumps(response))
|
||||
except Exception as e:
|
||||
logger.error(f"Transcription error: {e}")
|
||||
await websocket.send(json.dumps({
|
||||
"type": "error",
|
||||
"message": f"Transcription failed: {str(e)}"
|
||||
}))
|
||||
|
||||
elif isinstance(message, str):
|
||||
# JSON command
|
||||
try:
|
||||
command = json.loads(message)
|
||||
|
||||
if command.get("type") == "final":
|
||||
# Process all accumulated audio (final transcription)
|
||||
if all_audio:
|
||||
audio_chunk = np.concatenate(all_audio)
|
||||
|
||||
text = self.pipeline.transcribe(
|
||||
audio_chunk,
|
||||
sample_rate=self.sample_rate
|
||||
)
|
||||
|
||||
if text and text.strip():
|
||||
# Display on server
|
||||
self.display_transcription(client_id, text, is_final=True)
|
||||
|
||||
# Send to client
|
||||
response = {
|
||||
"type": "transcript",
|
||||
"text": text,
|
||||
"is_final": True,
|
||||
}
|
||||
await websocket.send(json.dumps(response))
|
||||
|
||||
# Clear buffer after final transcription
|
||||
all_audio = []
|
||||
last_transcribed_samples = 0
|
||||
|
||||
elif command.get("type") == "reset":
|
||||
# Reset buffer
|
||||
all_audio = []
|
||||
last_transcribed_samples = 0
|
||||
await websocket.send(json.dumps({
|
||||
"type": "info",
|
||||
"message": "Buffer reset"
|
||||
}))
|
||||
print(f"{self.YELLOW}[{client_id}] Buffer reset{self.RESET}\n")
|
||||
sys.stdout.flush()
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Invalid JSON from {client_id}: {message}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing message from {client_id}: {e}")
|
||||
break
|
||||
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
logger.info(f"Connection closed: {client_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error with {client_id}: {e}")
|
||||
finally:
|
||||
self.active_connections.discard(websocket)
|
||||
print(f"\n{self.BOLD}{'='*80}{self.RESET}")
|
||||
print(f"{self.YELLOW}✗ Client disconnected: {client_id}{self.RESET}")
|
||||
print(f"{self.BOLD}{'='*80}{self.RESET}\n")
|
||||
sys.stdout.flush()
|
||||
logger.info(f"Connection closed: {client_id}")
|
||||
|
||||
async def start(self):
|
||||
"""Start the WebSocket server."""
|
||||
self.print_header()
|
||||
|
||||
async with websockets.serve(self.handle_client, self.host, self.port):
|
||||
logger.info(f"Starting WebSocket server on {self.host}:{self.port}")
|
||||
print(f"{self.GREEN}{self.BOLD}Server is running and ready for connections!{self.RESET}")
|
||||
print(f"{self.BOLD}Waiting for clients...{self.RESET}\n")
|
||||
sys.stdout.flush()
|
||||
|
||||
# Keep server running
|
||||
await asyncio.Future()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point."""
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="ASR Server with Live Display")
|
||||
parser.add_argument("--host", default="0.0.0.0", help="Host address")
|
||||
parser.add_argument("--port", type=int, default=8766, help="Port number")
|
||||
parser.add_argument("--model-path", default="models/parakeet", help="Model directory")
|
||||
parser.add_argument("--sample-rate", type=int, default=16000, help="Sample rate")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
server = DisplayServer(
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
model_path=args.model_path,
|
||||
sample_rate=args.sample_rate,
|
||||
)
|
||||
|
||||
try:
|
||||
asyncio.run(server.start())
|
||||
except KeyboardInterrupt:
|
||||
print(f"\n\n{server.YELLOW}Server stopped by user{server.RESET}")
|
||||
logger.info("Server stopped by user")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,416 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
ASR WebSocket Server with VAD - Optimized for Discord Bots
|
||||
|
||||
This server uses Voice Activity Detection (VAD) to:
|
||||
- Detect speech start and end automatically
|
||||
- Only transcribe speech segments (ignore silence)
|
||||
- Provide clean boundaries for Discord message formatting
|
||||
- Minimize processing of silence/noise
|
||||
"""
|
||||
import asyncio
|
||||
import websockets
|
||||
import numpy as np
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
# Add project root to path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from asr.asr_pipeline import ASRPipeline
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.FileHandler('vad_server.log'),
|
||||
logging.StreamHandler()
|
||||
]
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpeechSegment:
|
||||
"""Represents a segment of detected speech."""
|
||||
audio: np.ndarray
|
||||
start_time: float
|
||||
end_time: Optional[float] = None
|
||||
is_complete: bool = False
|
||||
|
||||
|
||||
class VADState:
|
||||
"""Manages VAD state for speech detection."""
|
||||
|
||||
def __init__(self, sample_rate: int = 16000, speech_threshold: float = 0.5):
|
||||
self.sample_rate = sample_rate
|
||||
|
||||
# Simple energy-based VAD parameters
|
||||
self.energy_threshold = 0.005 # Lower threshold for better detection
|
||||
self.speech_frames = 0
|
||||
self.silence_frames = 0
|
||||
self.min_speech_frames = 3 # 3 frames minimum (300ms with 100ms chunks)
|
||||
self.min_silence_frames = 5 # 5 frames of silence (500ms)
|
||||
|
||||
self.is_speech = False
|
||||
self.speech_buffer = []
|
||||
|
||||
# Pre-buffer to capture audio BEFORE speech detection
|
||||
# This prevents cutting off the start of speech
|
||||
self.pre_buffer_frames = 5 # Keep 5 frames (500ms) of pre-speech audio
|
||||
self.pre_buffer = deque(maxlen=self.pre_buffer_frames)
|
||||
|
||||
# Progressive transcription tracking
|
||||
self.last_partial_samples = 0 # Track when we last sent a partial
|
||||
self.partial_interval_samples = int(sample_rate * 0.3) # Partial every 0.3 seconds (near real-time)
|
||||
|
||||
logger.info(f"VAD initialized: energy_threshold={self.energy_threshold}, pre_buffer={self.pre_buffer_frames} frames")
|
||||
|
||||
def calculate_energy(self, audio_chunk: np.ndarray) -> float:
|
||||
"""Calculate RMS energy of audio chunk."""
|
||||
return np.sqrt(np.mean(audio_chunk ** 2))
|
||||
|
||||
def process_audio(self, audio_chunk: np.ndarray) -> tuple[bool, Optional[np.ndarray], Optional[np.ndarray]]:
|
||||
"""
|
||||
Process audio chunk and detect speech boundaries.
|
||||
|
||||
Returns:
|
||||
(speech_detected, complete_segment, partial_segment)
|
||||
- speech_detected: True if currently in speech
|
||||
- complete_segment: Audio segment if speech ended, None otherwise
|
||||
- partial_segment: Audio for partial transcription, None otherwise
|
||||
"""
|
||||
energy = self.calculate_energy(audio_chunk)
|
||||
chunk_is_speech = energy > self.energy_threshold
|
||||
|
||||
logger.debug(f"Energy: {energy:.6f}, Is speech: {chunk_is_speech}")
|
||||
|
||||
partial_segment = None
|
||||
|
||||
if chunk_is_speech:
|
||||
self.speech_frames += 1
|
||||
self.silence_frames = 0
|
||||
|
||||
if not self.is_speech and self.speech_frames >= self.min_speech_frames:
|
||||
# Speech started - add pre-buffer to capture the beginning!
|
||||
self.is_speech = True
|
||||
logger.info("🎤 Speech started (including pre-buffer)")
|
||||
|
||||
# Add pre-buffered audio to speech buffer
|
||||
if self.pre_buffer:
|
||||
logger.debug(f"Adding {len(self.pre_buffer)} pre-buffered frames")
|
||||
self.speech_buffer.extend(list(self.pre_buffer))
|
||||
self.pre_buffer.clear()
|
||||
|
||||
if self.is_speech:
|
||||
self.speech_buffer.append(audio_chunk)
|
||||
else:
|
||||
# Not in speech yet, keep in pre-buffer
|
||||
self.pre_buffer.append(audio_chunk)
|
||||
|
||||
# Check if we should send a partial transcription
|
||||
current_samples = sum(len(chunk) for chunk in self.speech_buffer)
|
||||
samples_since_last_partial = current_samples - self.last_partial_samples
|
||||
|
||||
# Send partial if enough NEW audio accumulated AND we have minimum duration
|
||||
min_duration_for_partial = int(self.sample_rate * 0.8) # At least 0.8s of audio
|
||||
if samples_since_last_partial >= self.partial_interval_samples and current_samples >= min_duration_for_partial:
|
||||
# Time for a partial update
|
||||
partial_segment = np.concatenate(self.speech_buffer)
|
||||
self.last_partial_samples = current_samples
|
||||
logger.debug(f"📝 Partial update: {current_samples/self.sample_rate:.2f}s")
|
||||
else:
|
||||
if self.is_speech:
|
||||
self.silence_frames += 1
|
||||
|
||||
# Add some trailing silence (up to limit)
|
||||
if self.silence_frames < self.min_silence_frames:
|
||||
self.speech_buffer.append(audio_chunk)
|
||||
else:
|
||||
# Speech ended
|
||||
logger.info(f"🛑 Speech ended after {self.silence_frames} silence frames")
|
||||
self.is_speech = False
|
||||
self.speech_frames = 0
|
||||
self.silence_frames = 0
|
||||
self.last_partial_samples = 0 # Reset partial counter
|
||||
|
||||
if self.speech_buffer:
|
||||
complete_segment = np.concatenate(self.speech_buffer)
|
||||
segment_duration = len(complete_segment) / self.sample_rate
|
||||
self.speech_buffer = []
|
||||
self.pre_buffer.clear() # Clear pre-buffer after speech ends
|
||||
logger.info(f"✅ Complete segment: {segment_duration:.2f}s")
|
||||
return False, complete_segment, None
|
||||
else:
|
||||
self.speech_frames = 0
|
||||
# Keep adding to pre-buffer when not in speech
|
||||
self.pre_buffer.append(audio_chunk)
|
||||
|
||||
return self.is_speech, None, partial_segment
|
||||
|
||||
|
||||
class VADServer:
|
||||
"""
|
||||
WebSocket server with VAD for Discord bot integration.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str = "0.0.0.0",
|
||||
port: int = 8766,
|
||||
model_path: str = "models/parakeet",
|
||||
sample_rate: int = 16000,
|
||||
):
|
||||
"""Initialize server."""
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.sample_rate = sample_rate
|
||||
self.active_connections = set()
|
||||
|
||||
# Terminal control codes
|
||||
self.BOLD = '\033[1m'
|
||||
self.GREEN = '\033[92m'
|
||||
self.YELLOW = '\033[93m'
|
||||
self.BLUE = '\033[94m'
|
||||
self.RED = '\033[91m'
|
||||
self.RESET = '\033[0m'
|
||||
|
||||
# Initialize ASR pipeline
|
||||
logger.info("Loading ASR model...")
|
||||
self.pipeline = ASRPipeline(model_path=model_path)
|
||||
logger.info("ASR Pipeline ready")
|
||||
|
||||
def print_header(self):
|
||||
"""Print server header."""
|
||||
print("\n" + "=" * 80)
|
||||
print(f"{self.BOLD}{self.BLUE}ASR Server with VAD - Discord Bot Ready{self.RESET}")
|
||||
print("=" * 80)
|
||||
print(f"Server: ws://{self.host}:{self.port}")
|
||||
print(f"Sample Rate: {self.sample_rate} Hz")
|
||||
print(f"Model: Parakeet TDT 0.6B V3")
|
||||
print(f"VAD: Energy-based speech detection")
|
||||
print("=" * 80 + "\n")
|
||||
|
||||
def display_transcription(self, client_id: str, text: str, duration: float):
|
||||
"""Display transcription in the terminal."""
|
||||
timestamp = datetime.now().strftime("%H:%M:%S")
|
||||
print(f"{self.GREEN}{self.BOLD}[{timestamp}] {client_id}{self.RESET}")
|
||||
print(f"{self.GREEN} 📝 {text}{self.RESET}")
|
||||
print(f"{self.BLUE} ⏱️ Duration: {duration:.2f}s{self.RESET}\n")
|
||||
sys.stdout.flush()
|
||||
|
||||
async def handle_client(self, websocket):
|
||||
"""Handle WebSocket client connection."""
|
||||
client_id = f"{websocket.remote_address[0]}:{websocket.remote_address[1]}"
|
||||
logger.info(f"Client connected: {client_id}")
|
||||
self.active_connections.add(websocket)
|
||||
|
||||
print(f"\n{self.BOLD}{'='*80}{self.RESET}")
|
||||
print(f"{self.GREEN}✓ Client connected: {client_id}{self.RESET}")
|
||||
print(f"{self.BOLD}{'='*80}{self.RESET}\n")
|
||||
sys.stdout.flush()
|
||||
|
||||
# Initialize VAD state for this client
|
||||
vad_state = VADState(sample_rate=self.sample_rate)
|
||||
|
||||
try:
|
||||
# Send welcome message
|
||||
await websocket.send(json.dumps({
|
||||
"type": "info",
|
||||
"message": "Connected to ASR server with VAD",
|
||||
"sample_rate": self.sample_rate,
|
||||
"vad_enabled": True,
|
||||
}))
|
||||
|
||||
async for message in websocket:
|
||||
try:
|
||||
if isinstance(message, bytes):
|
||||
# Binary audio data
|
||||
audio_data = np.frombuffer(message, dtype=np.int16)
|
||||
audio_data = audio_data.astype(np.float32) / 32768.0
|
||||
|
||||
# Process through VAD
|
||||
is_speech, complete_segment, partial_segment = vad_state.process_audio(audio_data)
|
||||
|
||||
# Send VAD status to client (only on state change)
|
||||
prev_speech_state = getattr(vad_state, '_prev_speech_state', False)
|
||||
if is_speech != prev_speech_state:
|
||||
vad_state._prev_speech_state = is_speech
|
||||
await websocket.send(json.dumps({
|
||||
"type": "vad_status",
|
||||
"is_speech": is_speech,
|
||||
}))
|
||||
|
||||
# Handle partial transcription (progressive updates while speaking)
|
||||
if partial_segment is not None:
|
||||
try:
|
||||
text = self.pipeline.transcribe(
|
||||
partial_segment,
|
||||
sample_rate=self.sample_rate
|
||||
)
|
||||
|
||||
if text and text.strip():
|
||||
duration = len(partial_segment) / self.sample_rate
|
||||
|
||||
# Display on server
|
||||
timestamp = datetime.now().strftime("%H:%M:%S")
|
||||
print(f"{self.YELLOW}[{timestamp}] {client_id}{self.RESET}")
|
||||
print(f"{self.YELLOW} → PARTIAL: {text}{self.RESET}\n")
|
||||
sys.stdout.flush()
|
||||
|
||||
# Send to client
|
||||
response = {
|
||||
"type": "transcript",
|
||||
"text": text,
|
||||
"is_final": False,
|
||||
"duration": duration,
|
||||
}
|
||||
await websocket.send(json.dumps(response))
|
||||
except Exception as e:
|
||||
logger.error(f"Partial transcription error: {e}")
|
||||
|
||||
# If we have a complete speech segment, transcribe it
|
||||
if complete_segment is not None:
|
||||
try:
|
||||
text = self.pipeline.transcribe(
|
||||
complete_segment,
|
||||
sample_rate=self.sample_rate
|
||||
)
|
||||
|
||||
if text and text.strip():
|
||||
duration = len(complete_segment) / self.sample_rate
|
||||
|
||||
# Display on server
|
||||
self.display_transcription(client_id, text, duration)
|
||||
|
||||
# Send to client
|
||||
response = {
|
||||
"type": "transcript",
|
||||
"text": text,
|
||||
"is_final": True,
|
||||
"duration": duration,
|
||||
}
|
||||
await websocket.send(json.dumps(response))
|
||||
except Exception as e:
|
||||
logger.error(f"Transcription error: {e}")
|
||||
await websocket.send(json.dumps({
|
||||
"type": "error",
|
||||
"message": f"Transcription failed: {str(e)}"
|
||||
}))
|
||||
|
||||
elif isinstance(message, str):
|
||||
# JSON command
|
||||
try:
|
||||
command = json.loads(message)
|
||||
|
||||
if command.get("type") == "force_transcribe":
|
||||
# Force transcribe current buffer
|
||||
if vad_state.speech_buffer:
|
||||
audio_chunk = np.concatenate(vad_state.speech_buffer)
|
||||
vad_state.speech_buffer = []
|
||||
vad_state.is_speech = False
|
||||
|
||||
text = self.pipeline.transcribe(
|
||||
audio_chunk,
|
||||
sample_rate=self.sample_rate
|
||||
)
|
||||
|
||||
if text and text.strip():
|
||||
duration = len(audio_chunk) / self.sample_rate
|
||||
self.display_transcription(client_id, text, duration)
|
||||
|
||||
response = {
|
||||
"type": "transcript",
|
||||
"text": text,
|
||||
"is_final": True,
|
||||
"duration": duration,
|
||||
}
|
||||
await websocket.send(json.dumps(response))
|
||||
|
||||
elif command.get("type") == "reset":
|
||||
# Reset VAD state
|
||||
vad_state = VADState(sample_rate=self.sample_rate)
|
||||
await websocket.send(json.dumps({
|
||||
"type": "info",
|
||||
"message": "VAD state reset"
|
||||
}))
|
||||
print(f"{self.YELLOW}[{client_id}] VAD reset{self.RESET}\n")
|
||||
sys.stdout.flush()
|
||||
|
||||
elif command.get("type") == "set_threshold":
|
||||
# Adjust VAD threshold
|
||||
threshold = command.get("threshold", 0.01)
|
||||
vad_state.energy_threshold = threshold
|
||||
await websocket.send(json.dumps({
|
||||
"type": "info",
|
||||
"message": f"VAD threshold set to {threshold}"
|
||||
}))
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Invalid JSON from {client_id}: {message}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing message from {client_id}: {e}")
|
||||
break
|
||||
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
logger.info(f"Connection closed: {client_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error with {client_id}: {e}")
|
||||
finally:
|
||||
self.active_connections.discard(websocket)
|
||||
print(f"\n{self.BOLD}{'='*80}{self.RESET}")
|
||||
print(f"{self.YELLOW}✗ Client disconnected: {client_id}{self.RESET}")
|
||||
print(f"{self.BOLD}{'='*80}{self.RESET}\n")
|
||||
sys.stdout.flush()
|
||||
logger.info(f"Connection closed: {client_id}")
|
||||
|
||||
async def start(self):
|
||||
"""Start the WebSocket server."""
|
||||
self.print_header()
|
||||
|
||||
async with websockets.serve(self.handle_client, self.host, self.port):
|
||||
logger.info(f"Starting WebSocket server on {self.host}:{self.port}")
|
||||
print(f"{self.GREEN}{self.BOLD}Server is running with VAD enabled!{self.RESET}")
|
||||
print(f"{self.BOLD}Ready for Discord bot connections...{self.RESET}\n")
|
||||
sys.stdout.flush()
|
||||
|
||||
# Keep server running
|
||||
await asyncio.Future()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point."""
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="ASR Server with VAD for Discord")
|
||||
parser.add_argument("--host", default="0.0.0.0", help="Host address")
|
||||
parser.add_argument("--port", type=int, default=8766, help="Port number")
|
||||
parser.add_argument("--model-path", default="models/parakeet", help="Model directory")
|
||||
parser.add_argument("--sample-rate", type=int, default=16000, help="Sample rate")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
server = VADServer(
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
model_path=args.model_path,
|
||||
sample_rate=args.sample_rate,
|
||||
)
|
||||
|
||||
try:
|
||||
asyncio.run(server.start())
|
||||
except KeyboardInterrupt:
|
||||
print(f"\n\n{server.YELLOW}Server stopped by user{server.RESET}")
|
||||
logger.info("Server stopped by user")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,231 +0,0 @@
|
||||
"""
|
||||
WebSocket server for streaming ASR using onnx-asr
|
||||
"""
|
||||
import asyncio
|
||||
import websockets
|
||||
import numpy as np
|
||||
import json
|
||||
import logging
|
||||
from asr.asr_pipeline import ASRPipeline
|
||||
from typing import Optional
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ASRWebSocketServer:
|
||||
"""
|
||||
WebSocket server for real-time speech recognition.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str = "0.0.0.0",
|
||||
port: int = 8766,
|
||||
model_name: str = "nemo-parakeet-tdt-0.6b-v3",
|
||||
model_path: Optional[str] = None,
|
||||
use_vad: bool = False,
|
||||
sample_rate: int = 16000,
|
||||
):
|
||||
"""
|
||||
Initialize WebSocket server.
|
||||
|
||||
Args:
|
||||
host: Server host address
|
||||
port: Server port
|
||||
model_name: ASR model name
|
||||
model_path: Optional local model path
|
||||
use_vad: Whether to use VAD
|
||||
sample_rate: Expected audio sample rate
|
||||
"""
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.sample_rate = sample_rate
|
||||
|
||||
logger.info("Initializing ASR Pipeline...")
|
||||
self.pipeline = ASRPipeline(
|
||||
model_name=model_name,
|
||||
model_path=model_path,
|
||||
use_vad=use_vad,
|
||||
)
|
||||
logger.info("ASR Pipeline ready")
|
||||
|
||||
self.active_connections = set()
|
||||
|
||||
async def handle_client(self, websocket):
|
||||
"""
|
||||
Handle individual WebSocket client connection.
|
||||
|
||||
Args:
|
||||
websocket: WebSocket connection
|
||||
"""
|
||||
client_id = f"{websocket.remote_address[0]}:{websocket.remote_address[1]}"
|
||||
logger.info(f"Client connected: {client_id}")
|
||||
self.active_connections.add(websocket)
|
||||
|
||||
# Audio buffer for accumulating ALL audio
|
||||
all_audio = []
|
||||
last_transcribed_samples = 0 # Track what we've already transcribed
|
||||
|
||||
# For progressive transcription, we'll accumulate and transcribe the full buffer
|
||||
# This gives better results than processing tiny chunks
|
||||
min_chunk_duration = 2.0 # Minimum 2 seconds before transcribing
|
||||
min_chunk_samples = int(self.sample_rate * min_chunk_duration)
|
||||
|
||||
try:
|
||||
# Send welcome message
|
||||
await websocket.send(json.dumps({
|
||||
"type": "info",
|
||||
"message": "Connected to ASR server",
|
||||
"sample_rate": self.sample_rate,
|
||||
}))
|
||||
|
||||
async for message in websocket:
|
||||
try:
|
||||
if isinstance(message, bytes):
|
||||
# Binary audio data
|
||||
# Convert bytes to float32 numpy array
|
||||
# Assuming int16 PCM data
|
||||
audio_data = np.frombuffer(message, dtype=np.int16)
|
||||
audio_data = audio_data.astype(np.float32) / 32768.0
|
||||
|
||||
# Accumulate all audio
|
||||
all_audio.append(audio_data)
|
||||
total_samples = sum(len(chunk) for chunk in all_audio)
|
||||
|
||||
# Transcribe periodically when we have enough NEW audio
|
||||
samples_since_last = total_samples - last_transcribed_samples
|
||||
if samples_since_last >= min_chunk_samples:
|
||||
audio_chunk = np.concatenate(all_audio)
|
||||
last_transcribed_samples = total_samples
|
||||
|
||||
# Transcribe the accumulated audio
|
||||
try:
|
||||
text = self.pipeline.transcribe(
|
||||
audio_chunk,
|
||||
sample_rate=self.sample_rate
|
||||
)
|
||||
|
||||
if text and text.strip():
|
||||
response = {
|
||||
"type": "transcript",
|
||||
"text": text,
|
||||
"is_final": False,
|
||||
}
|
||||
await websocket.send(json.dumps(response))
|
||||
logger.info(f"Progressive transcription: {text}")
|
||||
except Exception as e:
|
||||
logger.error(f"Transcription error: {e}")
|
||||
await websocket.send(json.dumps({
|
||||
"type": "error",
|
||||
"message": f"Transcription failed: {str(e)}"
|
||||
}))
|
||||
|
||||
elif isinstance(message, str):
|
||||
# JSON command
|
||||
try:
|
||||
command = json.loads(message)
|
||||
|
||||
if command.get("type") == "final":
|
||||
# Process all accumulated audio (final transcription)
|
||||
if all_audio:
|
||||
audio_chunk = np.concatenate(all_audio)
|
||||
|
||||
text = self.pipeline.transcribe(
|
||||
audio_chunk,
|
||||
sample_rate=self.sample_rate
|
||||
)
|
||||
|
||||
if text and text.strip():
|
||||
response = {
|
||||
"type": "transcript",
|
||||
"text": text,
|
||||
"is_final": True,
|
||||
}
|
||||
await websocket.send(json.dumps(response))
|
||||
logger.info(f"Final transcription: {text}")
|
||||
|
||||
# Clear buffer after final transcription
|
||||
all_audio = []
|
||||
last_transcribed_samples = 0
|
||||
|
||||
elif command.get("type") == "reset":
|
||||
# Reset buffer
|
||||
all_audio = []
|
||||
last_transcribed_samples = 0
|
||||
await websocket.send(json.dumps({
|
||||
"type": "info",
|
||||
"message": "Buffer reset"
|
||||
}))
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Invalid JSON command: {message}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing message: {e}")
|
||||
await websocket.send(json.dumps({
|
||||
"type": "error",
|
||||
"message": str(e)
|
||||
}))
|
||||
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
logger.info(f"Client disconnected: {client_id}")
|
||||
|
||||
finally:
|
||||
self.active_connections.discard(websocket)
|
||||
logger.info(f"Connection closed: {client_id}")
|
||||
|
||||
async def start(self):
|
||||
"""
|
||||
Start the WebSocket server.
|
||||
"""
|
||||
logger.info(f"Starting WebSocket server on {self.host}:{self.port}")
|
||||
|
||||
async with websockets.serve(self.handle_client, self.host, self.port):
|
||||
logger.info(f"Server running on ws://{self.host}:{self.port}")
|
||||
logger.info(f"Active connections: {len(self.active_connections)}")
|
||||
await asyncio.Future() # Run forever
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
Run the server (blocking).
|
||||
"""
|
||||
try:
|
||||
asyncio.run(self.start())
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Server stopped by user")
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Main entry point for the WebSocket server.
|
||||
"""
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="ASR WebSocket Server")
|
||||
parser.add_argument("--host", default="0.0.0.0", help="Server host")
|
||||
parser.add_argument("--port", type=int, default=8766, help="Server port")
|
||||
parser.add_argument("--model", default="nemo-parakeet-tdt-0.6b-v3", help="Model name")
|
||||
parser.add_argument("--model-path", default=None, help="Local model path")
|
||||
parser.add_argument("--use-vad", action="store_true", help="Enable VAD")
|
||||
parser.add_argument("--sample-rate", type=int, default=16000, help="Audio sample rate")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
server = ASRWebSocketServer(
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
model_name=args.model,
|
||||
model_path=args.model_path,
|
||||
use_vad=args.use_vad,
|
||||
sample_rate=args.sample_rate,
|
||||
)
|
||||
|
||||
server.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,181 +0,0 @@
|
||||
#!/bin/bash
|
||||
# Setup environment for Parakeet ASR with ONNX Runtime
|
||||
|
||||
set -e
|
||||
|
||||
echo "=========================================="
|
||||
echo "Parakeet ASR Setup with onnx-asr"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
# Colors for output
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
# Detect best Python version (3.10-3.12 for GPU support)
|
||||
echo "Detecting Python version..."
|
||||
PYTHON_CMD=""
|
||||
|
||||
for py_ver in python3.12 python3.11 python3.10; do
|
||||
if command -v $py_ver &> /dev/null; then
|
||||
PYTHON_CMD=$py_ver
|
||||
break
|
||||
fi
|
||||
done
|
||||
|
||||
if [ -z "$PYTHON_CMD" ]; then
|
||||
# Fallback to default python3
|
||||
PYTHON_CMD=python3
|
||||
fi
|
||||
|
||||
PYTHON_VERSION=$($PYTHON_CMD --version 2>&1 | awk '{print $2}')
|
||||
echo "Using Python: $PYTHON_CMD ($PYTHON_VERSION)"
|
||||
|
||||
# Check if virtual environment exists
|
||||
if [ ! -d "venv" ]; then
|
||||
echo ""
|
||||
echo "Creating virtual environment with $PYTHON_CMD..."
|
||||
$PYTHON_CMD -m venv venv
|
||||
echo -e "${GREEN}✓ Virtual environment created${NC}"
|
||||
else
|
||||
echo -e "${YELLOW}Virtual environment already exists${NC}"
|
||||
fi
|
||||
|
||||
# Activate virtual environment
|
||||
echo ""
|
||||
echo "Activating virtual environment..."
|
||||
source venv/bin/activate
|
||||
|
||||
# Upgrade pip
|
||||
echo ""
|
||||
echo "Upgrading pip..."
|
||||
pip install --upgrade pip
|
||||
|
||||
# Check CUDA
|
||||
echo ""
|
||||
echo "Checking CUDA installation..."
|
||||
if command -v nvcc &> /dev/null; then
|
||||
CUDA_VERSION=$(nvcc --version | grep "release" | awk '{print $5}' | cut -c2-)
|
||||
echo -e "${GREEN}✓ CUDA found: $CUDA_VERSION${NC}"
|
||||
else
|
||||
echo -e "${YELLOW}⚠ CUDA compiler (nvcc) not found${NC}"
|
||||
echo " If you have a GPU, make sure CUDA is installed:"
|
||||
echo " https://developer.nvidia.com/cuda-downloads"
|
||||
fi
|
||||
|
||||
# Check NVIDIA GPU
|
||||
echo ""
|
||||
echo "Checking NVIDIA GPU..."
|
||||
if command -v nvidia-smi &> /dev/null; then
|
||||
echo -e "${GREEN}✓ NVIDIA GPU detected${NC}"
|
||||
nvidia-smi --query-gpu=name,memory.total --format=csv,noheader | while read line; do
|
||||
echo " $line"
|
||||
done
|
||||
else
|
||||
echo -e "${YELLOW}⚠ nvidia-smi not found${NC}"
|
||||
echo " Make sure NVIDIA drivers are installed if you have a GPU"
|
||||
fi
|
||||
|
||||
# Install dependencies
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Installing Python dependencies..."
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
# Check Python version for GPU support
|
||||
PYTHON_MAJOR=$(python3 -c 'import sys; print(sys.version_info.major)')
|
||||
PYTHON_MINOR=$(python3 -c 'import sys; print(sys.version_info.minor)')
|
||||
|
||||
if [ "$PYTHON_MAJOR" -eq 3 ] && [ "$PYTHON_MINOR" -ge 13 ]; then
|
||||
echo -e "${YELLOW}⚠ Python 3.13+ detected${NC}"
|
||||
echo " onnxruntime-gpu is not yet available for Python 3.13+"
|
||||
echo " Installing CPU version of onnxruntime..."
|
||||
echo " For GPU support, please use Python 3.10-3.12"
|
||||
USE_GPU=false
|
||||
else
|
||||
echo "Python version supports GPU acceleration"
|
||||
USE_GPU=true
|
||||
fi
|
||||
|
||||
# Install onnx-asr
|
||||
echo ""
|
||||
if [ "$USE_GPU" = true ]; then
|
||||
echo "Installing onnx-asr with GPU support..."
|
||||
pip install "onnx-asr[gpu,hub]"
|
||||
else
|
||||
echo "Installing onnx-asr (CPU version)..."
|
||||
pip install "onnx-asr[hub]" onnxruntime
|
||||
fi
|
||||
|
||||
# Install other dependencies
|
||||
echo ""
|
||||
echo "Installing additional dependencies..."
|
||||
pip install numpy\<2.0 websockets sounddevice soundfile
|
||||
|
||||
# Optional: Install TensorRT (if available)
|
||||
echo ""
|
||||
read -p "Do you want to install TensorRT for faster inference? (y/n) " -n 1 -r
|
||||
echo
|
||||
if [[ $REPLY =~ ^[Yy]$ ]]; then
|
||||
echo "Installing TensorRT..."
|
||||
pip install tensorrt tensorrt-cu12-libs || echo -e "${YELLOW}⚠ TensorRT installation failed (optional)${NC}"
|
||||
fi
|
||||
|
||||
# Run diagnostics
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Running system diagnostics..."
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
python3 tools/diagnose.py
|
||||
|
||||
# Test model download (optional)
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Model Download"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
echo "The Parakeet model (~600MB) will be downloaded on first use."
|
||||
read -p "Do you want to download the model now? (y/n) " -n 1 -r
|
||||
echo
|
||||
if [[ $REPLY =~ ^[Yy]$ ]]; then
|
||||
echo ""
|
||||
echo "Downloading model..."
|
||||
python3 -c "
|
||||
import onnx_asr
|
||||
print('Loading model (this will download ~600MB)...')
|
||||
model = onnx_asr.load_model('nemo-parakeet-tdt-0.6b-v3', 'models/parakeet')
|
||||
print('✓ Model downloaded successfully!')
|
||||
"
|
||||
else
|
||||
echo "Model will be downloaded when you first run the ASR pipeline."
|
||||
fi
|
||||
|
||||
# Create test audio directory
|
||||
mkdir -p test_audio
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Setup Complete!"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
echo -e "${GREEN}✓ Environment setup successful!${NC}"
|
||||
echo ""
|
||||
echo "Next steps:"
|
||||
echo " 1. Activate the virtual environment:"
|
||||
echo " source venv/bin/activate"
|
||||
echo ""
|
||||
echo " 2. Test offline transcription:"
|
||||
echo " python3 tools/test_offline.py your_audio.wav"
|
||||
echo ""
|
||||
echo " 3. Start the WebSocket server:"
|
||||
echo " python3 server/ws_server.py"
|
||||
echo ""
|
||||
echo " 4. In another terminal, start the microphone client:"
|
||||
echo " python3 client/mic_stream.py"
|
||||
echo ""
|
||||
echo "For more information, see README.md"
|
||||
echo ""
|
||||
@@ -1,56 +0,0 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# Start ASR Display Server with GPU support
|
||||
# This script sets up the environment properly for CUDA libraries
|
||||
#
|
||||
|
||||
# Get the directory where this script is located
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
cd "$SCRIPT_DIR"
|
||||
|
||||
# Activate virtual environment
|
||||
if [ -f "venv/bin/activate" ]; then
|
||||
source venv/bin/activate
|
||||
else
|
||||
echo "Error: Virtual environment not found at venv/bin/activate"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Get CUDA library paths from venv
|
||||
VENV_DIR="$SCRIPT_DIR/venv"
|
||||
CUDA_LIB_PATHS=(
|
||||
"$VENV_DIR/lib/python*/site-packages/nvidia/cublas/lib"
|
||||
"$VENV_DIR/lib/python*/site-packages/nvidia/cudnn/lib"
|
||||
"$VENV_DIR/lib/python*/site-packages/nvidia/cufft/lib"
|
||||
"$VENV_DIR/lib/python*/site-packages/nvidia/cuda_nvrtc/lib"
|
||||
"$VENV_DIR/lib/python*/site-packages/nvidia/cuda_runtime/lib"
|
||||
)
|
||||
|
||||
# Build LD_LIBRARY_PATH
|
||||
CUDA_LD_PATH=""
|
||||
for pattern in "${CUDA_LIB_PATHS[@]}"; do
|
||||
for path in $pattern; do
|
||||
if [ -d "$path" ]; then
|
||||
if [ -z "$CUDA_LD_PATH" ]; then
|
||||
CUDA_LD_PATH="$path"
|
||||
else
|
||||
CUDA_LD_PATH="$CUDA_LD_PATH:$path"
|
||||
fi
|
||||
fi
|
||||
done
|
||||
done
|
||||
|
||||
# Export library path
|
||||
if [ -n "$CUDA_LD_PATH" ]; then
|
||||
export LD_LIBRARY_PATH="$CUDA_LD_PATH:${LD_LIBRARY_PATH:-}"
|
||||
echo "CUDA libraries path set: $CUDA_LD_PATH"
|
||||
else
|
||||
echo "Warning: No CUDA libraries found in venv"
|
||||
fi
|
||||
|
||||
# Set Python path
|
||||
export PYTHONPATH="$SCRIPT_DIR:${PYTHONPATH:-}"
|
||||
|
||||
# Run the display server
|
||||
echo "Starting ASR Display Server with GPU support..."
|
||||
python server/display_server.py "$@"
|
||||
@@ -1,88 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Simple WebSocket client to test the ASR server
|
||||
Sends a test audio file to the server
|
||||
"""
|
||||
import asyncio
|
||||
import websockets
|
||||
import json
|
||||
import sys
|
||||
import soundfile as sf
|
||||
import numpy as np
|
||||
|
||||
|
||||
async def test_connection(audio_file="test.wav"):
|
||||
"""Test connection to ASR server."""
|
||||
uri = "ws://localhost:8766"
|
||||
|
||||
print(f"Connecting to {uri}...")
|
||||
|
||||
try:
|
||||
async with websockets.connect(uri) as websocket:
|
||||
print("Connected!")
|
||||
|
||||
# Receive welcome message
|
||||
message = await websocket.recv()
|
||||
data = json.loads(message)
|
||||
print(f"Server: {data}")
|
||||
|
||||
# Load audio file
|
||||
print(f"\nLoading audio file: {audio_file}")
|
||||
audio, sr = sf.read(audio_file, dtype='float32')
|
||||
|
||||
if audio.ndim > 1:
|
||||
audio = audio[:, 0] # Convert to mono
|
||||
|
||||
print(f"Sample rate: {sr} Hz")
|
||||
print(f"Duration: {len(audio)/sr:.2f} seconds")
|
||||
|
||||
# Convert to int16 for sending
|
||||
audio_int16 = (audio * 32767).astype(np.int16)
|
||||
|
||||
# Send audio in chunks
|
||||
chunk_size = int(sr * 0.5) # 0.5 second chunks
|
||||
|
||||
print("\nSending audio...")
|
||||
|
||||
# Send all audio chunks
|
||||
for i in range(0, len(audio_int16), chunk_size):
|
||||
chunk = audio_int16[i:i+chunk_size]
|
||||
await websocket.send(chunk.tobytes())
|
||||
print(f"Sent chunk {i//chunk_size + 1}", end='\r')
|
||||
|
||||
print("\nAll chunks sent. Sending final command...")
|
||||
|
||||
# Send final command
|
||||
await websocket.send(json.dumps({"type": "final"}))
|
||||
|
||||
# Now receive ALL responses
|
||||
print("\nWaiting for transcriptions...\n")
|
||||
timeout_count = 0
|
||||
while timeout_count < 3: # Wait for 3 timeouts (6 seconds total) before giving up
|
||||
try:
|
||||
response = await asyncio.wait_for(websocket.recv(), timeout=2.0)
|
||||
result = json.loads(response)
|
||||
if result.get('type') == 'transcript':
|
||||
text = result.get('text', '')
|
||||
is_final = result.get('is_final', False)
|
||||
prefix = "→ FINAL:" if is_final else "→ Progressive:"
|
||||
print(f"{prefix} {text}\n")
|
||||
timeout_count = 0 # Reset timeout counter when we get a message
|
||||
if is_final:
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
timeout_count += 1
|
||||
|
||||
print("\nTest completed!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
return 1
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
audio_file = sys.argv[1] if len(sys.argv) > 1 else "test.wav"
|
||||
exit_code = asyncio.run(test_connection(audio_file))
|
||||
sys.exit(exit_code)
|
||||
@@ -1,125 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test client for VAD-enabled server
|
||||
Simulates Discord bot audio streaming with speech detection
|
||||
"""
|
||||
import asyncio
|
||||
import websockets
|
||||
import json
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import sys
|
||||
|
||||
|
||||
async def test_vad_server(audio_file="test.wav"):
|
||||
"""Test VAD server with audio file."""
|
||||
uri = "ws://localhost:8766"
|
||||
|
||||
print(f"Connecting to {uri}...")
|
||||
|
||||
try:
|
||||
async with websockets.connect(uri) as websocket:
|
||||
print("✓ Connected!\n")
|
||||
|
||||
# Receive welcome message
|
||||
message = await websocket.recv()
|
||||
data = json.loads(message)
|
||||
print(f"Server says: {data.get('message')}")
|
||||
print(f"VAD enabled: {data.get('vad_enabled')}\n")
|
||||
|
||||
# Load audio file
|
||||
print(f"Loading audio: {audio_file}")
|
||||
audio, sr = sf.read(audio_file, dtype='float32')
|
||||
|
||||
if audio.ndim > 1:
|
||||
audio = audio[:, 0] # Mono
|
||||
|
||||
print(f"Duration: {len(audio)/sr:.2f}s")
|
||||
print(f"Sample rate: {sr} Hz\n")
|
||||
|
||||
# Convert to int16
|
||||
audio_int16 = (audio * 32767).astype(np.int16)
|
||||
|
||||
# Listen for responses in background
|
||||
async def receive_messages():
|
||||
"""Receive and display server messages."""
|
||||
try:
|
||||
while True:
|
||||
response = await websocket.recv()
|
||||
result = json.loads(response)
|
||||
|
||||
msg_type = result.get('type')
|
||||
|
||||
if msg_type == 'vad_status':
|
||||
is_speech = result.get('is_speech')
|
||||
if is_speech:
|
||||
print("\n🎤 VAD: Speech detected\n")
|
||||
else:
|
||||
print("\n🛑 VAD: Speech ended\n")
|
||||
|
||||
elif msg_type == 'transcript':
|
||||
text = result.get('text', '')
|
||||
duration = result.get('duration', 0)
|
||||
is_final = result.get('is_final', False)
|
||||
|
||||
if is_final:
|
||||
print(f"\n{'='*70}")
|
||||
print(f"✅ FINAL TRANSCRIPTION ({duration:.2f}s):")
|
||||
print(f" \"{text}\"")
|
||||
print(f"{'='*70}\n")
|
||||
else:
|
||||
print(f"📝 PARTIAL ({duration:.2f}s): {text}")
|
||||
|
||||
elif msg_type == 'info':
|
||||
print(f"ℹ️ {result.get('message')}")
|
||||
|
||||
elif msg_type == 'error':
|
||||
print(f"❌ Error: {result.get('message')}")
|
||||
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
# Start listener
|
||||
listen_task = asyncio.create_task(receive_messages())
|
||||
|
||||
# Send audio in small chunks (simulate streaming)
|
||||
chunk_size = int(sr * 0.1) # 100ms chunks
|
||||
print("Streaming audio...\n")
|
||||
|
||||
for i in range(0, len(audio_int16), chunk_size):
|
||||
chunk = audio_int16[i:i+chunk_size]
|
||||
await websocket.send(chunk.tobytes())
|
||||
await asyncio.sleep(0.05) # Simulate real-time
|
||||
|
||||
print("\nAll audio sent. Waiting for final transcription...")
|
||||
|
||||
# Wait for processing
|
||||
await asyncio.sleep(3.0)
|
||||
|
||||
# Force transcribe any remaining buffer
|
||||
print("Sending force_transcribe command...\n")
|
||||
await websocket.send(json.dumps({"type": "force_transcribe"}))
|
||||
|
||||
# Wait a bit more
|
||||
await asyncio.sleep(2.0)
|
||||
|
||||
# Cancel listener
|
||||
listen_task.cancel()
|
||||
try:
|
||||
await listen_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
print("\n✓ Test completed!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error: {e}")
|
||||
return 1
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
audio_file = sys.argv[1] if len(sys.argv) > 1 else "test.wav"
|
||||
exit_code = asyncio.run(test_vad_server(audio_file))
|
||||
sys.exit(exit_code)
|
||||
@@ -1,219 +0,0 @@
|
||||
"""
|
||||
System diagnostics for ASR setup
|
||||
"""
|
||||
import sys
|
||||
import subprocess
|
||||
|
||||
|
||||
def print_section(title):
|
||||
"""Print a section header."""
|
||||
print(f"\n{'='*80}")
|
||||
print(f" {title}")
|
||||
print(f"{'='*80}\n")
|
||||
|
||||
|
||||
def check_python():
|
||||
"""Check Python version."""
|
||||
print_section("Python Version")
|
||||
print(f"Python: {sys.version}")
|
||||
print(f"Executable: {sys.executable}")
|
||||
|
||||
|
||||
def check_packages():
|
||||
"""Check installed packages."""
|
||||
print_section("Installed Packages")
|
||||
|
||||
packages = [
|
||||
"onnx-asr",
|
||||
"onnxruntime",
|
||||
"onnxruntime-gpu",
|
||||
"numpy",
|
||||
"websockets",
|
||||
"sounddevice",
|
||||
"soundfile",
|
||||
]
|
||||
|
||||
for package in packages:
|
||||
try:
|
||||
if package == "onnx-asr":
|
||||
import onnx_asr
|
||||
version = getattr(onnx_asr, "__version__", "unknown")
|
||||
elif package == "onnxruntime":
|
||||
import onnxruntime
|
||||
version = onnxruntime.__version__
|
||||
elif package == "onnxruntime-gpu":
|
||||
try:
|
||||
import onnxruntime
|
||||
version = onnxruntime.__version__
|
||||
print(f"✓ {package}: {version}")
|
||||
except ImportError:
|
||||
print(f"✗ {package}: Not installed")
|
||||
continue
|
||||
elif package == "numpy":
|
||||
import numpy
|
||||
version = numpy.__version__
|
||||
elif package == "websockets":
|
||||
import websockets
|
||||
version = websockets.__version__
|
||||
elif package == "sounddevice":
|
||||
import sounddevice
|
||||
version = sounddevice.__version__
|
||||
elif package == "soundfile":
|
||||
import soundfile
|
||||
version = soundfile.__version__
|
||||
|
||||
print(f"✓ {package}: {version}")
|
||||
except ImportError:
|
||||
print(f"✗ {package}: Not installed")
|
||||
|
||||
|
||||
def check_cuda():
|
||||
"""Check CUDA availability."""
|
||||
print_section("CUDA Information")
|
||||
|
||||
# Check nvcc
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["nvcc", "--version"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
print("NVCC (CUDA Compiler):")
|
||||
print(result.stdout)
|
||||
except FileNotFoundError:
|
||||
print("✗ nvcc not found - CUDA may not be installed")
|
||||
|
||||
# Check nvidia-smi
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["nvidia-smi"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
print("NVIDIA GPU Information:")
|
||||
print(result.stdout)
|
||||
except FileNotFoundError:
|
||||
print("✗ nvidia-smi not found - NVIDIA drivers may not be installed")
|
||||
|
||||
|
||||
def check_onnxruntime():
|
||||
"""Check ONNX Runtime providers."""
|
||||
print_section("ONNX Runtime Providers")
|
||||
|
||||
try:
|
||||
import onnxruntime as ort
|
||||
|
||||
print("Available providers:")
|
||||
for provider in ort.get_available_providers():
|
||||
print(f" ✓ {provider}")
|
||||
|
||||
# Check if CUDA is available
|
||||
if "CUDAExecutionProvider" in ort.get_available_providers():
|
||||
print("\n✓ GPU acceleration available via CUDA")
|
||||
else:
|
||||
print("\n✗ GPU acceleration NOT available")
|
||||
print(" Make sure onnxruntime-gpu is installed and CUDA is working")
|
||||
|
||||
# Get device info
|
||||
print(f"\nONNX Runtime version: {ort.__version__}")
|
||||
|
||||
except ImportError:
|
||||
print("✗ onnxruntime not installed")
|
||||
|
||||
|
||||
def check_audio_devices():
|
||||
"""Check audio devices."""
|
||||
print_section("Audio Devices")
|
||||
|
||||
try:
|
||||
import sounddevice as sd
|
||||
|
||||
devices = sd.query_devices()
|
||||
|
||||
print("Input devices:")
|
||||
for i, device in enumerate(devices):
|
||||
if device['max_input_channels'] > 0:
|
||||
default = " [DEFAULT]" if i == sd.default.device[0] else ""
|
||||
print(f" [{i}] {device['name']}{default}")
|
||||
print(f" Channels: {device['max_input_channels']}")
|
||||
print(f" Sample rate: {device['default_samplerate']} Hz")
|
||||
|
||||
except ImportError:
|
||||
print("✗ sounddevice not installed")
|
||||
except Exception as e:
|
||||
print(f"✗ Error querying audio devices: {e}")
|
||||
|
||||
|
||||
def check_model_files():
|
||||
"""Check if model files exist."""
|
||||
print_section("Model Files")
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
model_dir = Path("models/parakeet")
|
||||
|
||||
expected_files = [
|
||||
"config.json",
|
||||
"encoder-parakeet-tdt-0.6b-v3.onnx",
|
||||
"decoder_joint-parakeet-tdt-0.6b-v3.onnx",
|
||||
"vocab.txt",
|
||||
]
|
||||
|
||||
if not model_dir.exists():
|
||||
print(f"✗ Model directory not found: {model_dir}")
|
||||
print(" Models will be downloaded on first run")
|
||||
return
|
||||
|
||||
print(f"Model directory: {model_dir.absolute()}")
|
||||
print("\nExpected files:")
|
||||
|
||||
for filename in expected_files:
|
||||
filepath = model_dir / filename
|
||||
if filepath.exists():
|
||||
size_mb = filepath.stat().st_size / (1024 * 1024)
|
||||
print(f" ✓ {filename} ({size_mb:.1f} MB)")
|
||||
else:
|
||||
print(f" ✗ {filename} (missing)")
|
||||
|
||||
|
||||
def test_onnx_asr():
|
||||
"""Test onnx-asr import and basic functionality."""
|
||||
print_section("onnx-asr Test")
|
||||
|
||||
try:
|
||||
import onnx_asr
|
||||
|
||||
print("✓ onnx-asr imported successfully")
|
||||
print(f" Version: {getattr(onnx_asr, '__version__', 'unknown')}")
|
||||
|
||||
# Test loading model info (without downloading)
|
||||
print("\n✓ onnx-asr is ready to use")
|
||||
print(" Run test_offline.py to download models and test transcription")
|
||||
|
||||
except ImportError as e:
|
||||
print(f"✗ Failed to import onnx-asr: {e}")
|
||||
except Exception as e:
|
||||
print(f"✗ Error testing onnx-asr: {e}")
|
||||
|
||||
|
||||
def main():
|
||||
"""Run all diagnostics."""
|
||||
print("\n" + "="*80)
|
||||
print(" ASR System Diagnostics")
|
||||
print("="*80)
|
||||
|
||||
check_python()
|
||||
check_packages()
|
||||
check_cuda()
|
||||
check_onnxruntime()
|
||||
check_audio_devices()
|
||||
check_model_files()
|
||||
test_onnx_asr()
|
||||
|
||||
print("\n" + "="*80)
|
||||
print(" Diagnostics Complete")
|
||||
print("="*80 + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,114 +0,0 @@
|
||||
"""
|
||||
Test offline ASR pipeline with onnx-asr
|
||||
"""
|
||||
import soundfile as sf
|
||||
import numpy as np
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from asr.asr_pipeline import ASRPipeline
|
||||
|
||||
|
||||
def test_transcription(audio_file: str, use_vad: bool = False, quantization: str = None):
|
||||
"""
|
||||
Test ASR transcription on an audio file.
|
||||
|
||||
Args:
|
||||
audio_file: Path to audio file
|
||||
use_vad: Whether to use VAD
|
||||
quantization: Optional quantization (e.g., "int8")
|
||||
"""
|
||||
print(f"\n{'='*80}")
|
||||
print(f"Testing ASR Pipeline with onnx-asr")
|
||||
print(f"{'='*80}")
|
||||
print(f"Audio file: {audio_file}")
|
||||
print(f"Use VAD: {use_vad}")
|
||||
print(f"Quantization: {quantization}")
|
||||
print(f"{'='*80}\n")
|
||||
|
||||
# Initialize pipeline
|
||||
print("Initializing ASR pipeline...")
|
||||
pipeline = ASRPipeline(
|
||||
model_name="nemo-parakeet-tdt-0.6b-v3",
|
||||
quantization=quantization,
|
||||
use_vad=use_vad,
|
||||
)
|
||||
print("Pipeline initialized successfully!\n")
|
||||
|
||||
# Read audio file
|
||||
print(f"Reading audio file: {audio_file}")
|
||||
audio, sr = sf.read(audio_file, dtype="float32")
|
||||
print(f"Sample rate: {sr} Hz")
|
||||
print(f"Audio shape: {audio.shape}")
|
||||
print(f"Audio duration: {len(audio) / sr:.2f} seconds")
|
||||
|
||||
# Ensure mono
|
||||
if audio.ndim > 1:
|
||||
print("Converting stereo to mono...")
|
||||
audio = audio[:, 0]
|
||||
|
||||
# Verify sample rate
|
||||
if sr != 16000:
|
||||
print(f"WARNING: Sample rate is {sr} Hz, expected 16000 Hz")
|
||||
print("Consider resampling the audio file")
|
||||
|
||||
print(f"\n{'='*80}")
|
||||
print("Transcribing...")
|
||||
print(f"{'='*80}\n")
|
||||
|
||||
# Transcribe
|
||||
result = pipeline.transcribe(audio, sample_rate=sr)
|
||||
|
||||
# Display results
|
||||
if use_vad and isinstance(result, list):
|
||||
print("TRANSCRIPTION (with VAD):")
|
||||
print("-" * 80)
|
||||
for i, segment in enumerate(result, 1):
|
||||
print(f"Segment {i}: {segment}")
|
||||
print("-" * 80)
|
||||
else:
|
||||
print("TRANSCRIPTION:")
|
||||
print("-" * 80)
|
||||
print(result)
|
||||
print("-" * 80)
|
||||
|
||||
# Audio statistics
|
||||
print(f"\nAUDIO STATISTICS:")
|
||||
print(f" dtype: {audio.dtype}")
|
||||
print(f" min: {audio.min():.6f}")
|
||||
print(f" max: {audio.max():.6f}")
|
||||
print(f" mean: {audio.mean():.6f}")
|
||||
print(f" std: {audio.std():.6f}")
|
||||
|
||||
print(f"\n{'='*80}")
|
||||
print("Test completed successfully!")
|
||||
print(f"{'='*80}\n")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Test offline ASR transcription")
|
||||
parser.add_argument("audio_file", help="Path to audio file (WAV format)")
|
||||
parser.add_argument("--use-vad", action="store_true", help="Enable VAD")
|
||||
parser.add_argument("--quantization", default=None, choices=["int8", "fp16"],
|
||||
help="Model quantization")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Check if file exists
|
||||
if not Path(args.audio_file).exists():
|
||||
print(f"ERROR: Audio file not found: {args.audio_file}")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
test_transcription(args.audio_file, args.use_vad, args.quantization)
|
||||
except Exception as e:
|
||||
print(f"\nERROR: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,6 +0,0 @@
|
||||
"""
|
||||
VAD module using onnx-asr library
|
||||
"""
|
||||
from .silero_vad import SileroVAD, load_vad
|
||||
|
||||
__all__ = ["SileroVAD", "load_vad"]
|
||||
@@ -1,114 +0,0 @@
|
||||
"""
|
||||
Silero VAD wrapper using onnx-asr library
|
||||
"""
|
||||
import numpy as np
|
||||
import onnx_asr
|
||||
from typing import Optional, Tuple
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SileroVAD:
|
||||
"""
|
||||
Voice Activity Detection using Silero VAD via onnx-asr.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
providers: Optional[list] = None,
|
||||
threshold: float = 0.5,
|
||||
min_speech_duration_ms: int = 250,
|
||||
min_silence_duration_ms: int = 100,
|
||||
window_size_samples: int = 512,
|
||||
speech_pad_ms: int = 30,
|
||||
):
|
||||
"""
|
||||
Initialize Silero VAD.
|
||||
|
||||
Args:
|
||||
providers: Optional ONNX runtime providers
|
||||
threshold: Speech probability threshold (0.0-1.0)
|
||||
min_speech_duration_ms: Minimum duration of speech segment
|
||||
min_silence_duration_ms: Minimum duration of silence to split segments
|
||||
window_size_samples: Window size for VAD processing
|
||||
speech_pad_ms: Padding around speech segments
|
||||
"""
|
||||
if providers is None:
|
||||
providers = [
|
||||
"CUDAExecutionProvider",
|
||||
"CPUExecutionProvider",
|
||||
]
|
||||
|
||||
logger.info("Loading Silero VAD model...")
|
||||
self.vad = onnx_asr.load_vad("silero", providers=providers)
|
||||
|
||||
# VAD parameters
|
||||
self.threshold = threshold
|
||||
self.min_speech_duration_ms = min_speech_duration_ms
|
||||
self.min_silence_duration_ms = min_silence_duration_ms
|
||||
self.window_size_samples = window_size_samples
|
||||
self.speech_pad_ms = speech_pad_ms
|
||||
|
||||
logger.info("Silero VAD initialized successfully")
|
||||
|
||||
def detect_speech(
|
||||
self,
|
||||
audio: np.ndarray,
|
||||
sample_rate: int = 16000,
|
||||
) -> list:
|
||||
"""
|
||||
Detect speech segments in audio.
|
||||
|
||||
Args:
|
||||
audio: Audio data as numpy array (float32)
|
||||
sample_rate: Sample rate of audio
|
||||
|
||||
Returns:
|
||||
List of tuples (start_sample, end_sample) for speech segments
|
||||
"""
|
||||
# Note: The actual VAD processing is typically done within
|
||||
# the onnx_asr model.with_vad() method, but we provide
|
||||
# this interface for direct VAD usage
|
||||
|
||||
# For direct VAD detection, you would use the vad model directly
|
||||
# However, onnx-asr integrates VAD into the recognition pipeline
|
||||
# So this is mainly for compatibility
|
||||
|
||||
logger.warning("Direct VAD detection - consider using model.with_vad() instead")
|
||||
return []
|
||||
|
||||
def is_speech(
|
||||
self,
|
||||
audio_chunk: np.ndarray,
|
||||
sample_rate: int = 16000,
|
||||
) -> Tuple[bool, float]:
|
||||
"""
|
||||
Check if audio chunk contains speech.
|
||||
|
||||
Args:
|
||||
audio_chunk: Audio chunk as numpy array (float32)
|
||||
sample_rate: Sample rate
|
||||
|
||||
Returns:
|
||||
Tuple of (is_speech: bool, probability: float)
|
||||
"""
|
||||
# Placeholder for direct VAD probability check
|
||||
# In practice, use model.with_vad() for automatic segmentation
|
||||
logger.warning("Direct speech detection not implemented - use model.with_vad()")
|
||||
return False, 0.0
|
||||
|
||||
def get_vad(self):
|
||||
"""
|
||||
Get the underlying onnx_asr VAD model.
|
||||
|
||||
Returns:
|
||||
The onnx_asr VAD model instance
|
||||
"""
|
||||
return self.vad
|
||||
|
||||
|
||||
# Convenience function
|
||||
def load_vad(**kwargs):
|
||||
"""Load and return Silero VAD with given configuration."""
|
||||
return SileroVAD(**kwargs)
|
||||
@@ -1,44 +0,0 @@
|
||||
FROM nvidia/cuda:12.1.0-base-ubuntu22.04
|
||||
|
||||
# Set working directory
|
||||
WORKDIR /app
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
python3.11 \
|
||||
python3-pip \
|
||||
ffmpeg \
|
||||
libsndfile1 \
|
||||
sox \
|
||||
libsox-dev \
|
||||
libsox-fmt-all \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy requirements
|
||||
COPY requirements.txt .
|
||||
|
||||
# Upgrade pip to avoid dependency resolution issues
|
||||
RUN pip3 install --upgrade pip
|
||||
|
||||
# Install dependencies for sox package (required by NeMo) in correct order
|
||||
RUN pip3 install --no-cache-dir numpy==2.2.2 typing-extensions
|
||||
|
||||
# Install Python dependencies with legacy resolver (NeMo has complex dependencies)
|
||||
RUN pip3 install --no-cache-dir --use-deprecated=legacy-resolver -r requirements.txt
|
||||
|
||||
# Copy application code
|
||||
COPY . .
|
||||
|
||||
# Create models directory
|
||||
RUN mkdir -p /models
|
||||
|
||||
# Expose port
|
||||
EXPOSE 8000
|
||||
|
||||
# Set environment variables
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
ENV CUDA_VISIBLE_DEVICES=0
|
||||
ENV LD_LIBRARY_PATH=/usr/local/lib/python3.11/dist-packages/nvidia/cudnn/lib:${LD_LIBRARY_PATH}
|
||||
|
||||
# Run the server
|
||||
CMD ["uvicorn", "stt_server:app", "--host", "0.0.0.0", "--port", "8000", "--log-level", "info"]
|
||||
@@ -1,114 +0,0 @@
|
||||
# NVIDIA Parakeet Migration
|
||||
|
||||
## Summary
|
||||
|
||||
Replaced Faster-Whisper with NVIDIA Parakeet TDT (Token-and-Duration Transducer) for real-time speech transcription.
|
||||
|
||||
## Changes Made
|
||||
|
||||
### 1. New Transcriber: `parakeet_transcriber.py`
|
||||
- **Model**: `nvidia/parakeet-tdt-0.6b-v3` (600M parameters)
|
||||
- **Features**:
|
||||
- Real-time streaming transcription
|
||||
- Word-level timestamps for LLM pre-computation
|
||||
- GPU-accelerated (CUDA)
|
||||
- Lower latency than Faster-Whisper
|
||||
- Native PyTorch (no CTranslate2 dependency)
|
||||
|
||||
### 2. Requirements Updated
|
||||
**Removed**:
|
||||
- `faster-whisper==1.2.1`
|
||||
- `ctranslate2==4.5.0`
|
||||
|
||||
**Added**:
|
||||
- `transformers==4.47.1` - HuggingFace model loading
|
||||
- `accelerate==1.2.1` - GPU optimization
|
||||
- `sentencepiece==0.2.0` - Tokenization
|
||||
|
||||
**Kept**:
|
||||
- `torch==2.9.1` & `torchaudio==2.9.1` - Core ML framework
|
||||
- `silero-vad==5.1.2` - VAD still uses Silero (CPU)
|
||||
|
||||
### 3. Server Updates: `stt_server.py`
|
||||
**Changes**:
|
||||
- Import `ParakeetTranscriber` instead of `WhisperTranscriber`
|
||||
- Partial transcripts now include `words` array with timestamps
|
||||
- Final transcripts include `words` array for LLM pre-computation
|
||||
- Startup logs show "Loading NVIDIA Parakeet TDT model"
|
||||
|
||||
**Word-level Token Format**:
|
||||
```json
|
||||
{
|
||||
"type": "partial",
|
||||
"text": "hello world",
|
||||
"words": [
|
||||
{"word": "hello", "start_time": 0.0, "end_time": 0.5},
|
||||
{"word": "world", "start_time": 0.5, "end_time": 1.0}
|
||||
],
|
||||
"user_id": "123",
|
||||
"timestamp": 1234.56
|
||||
}
|
||||
```
|
||||
|
||||
## Advantages Over Faster-Whisper
|
||||
|
||||
1. **Real-time Performance**: TDT architecture designed for streaming
|
||||
2. **No cuDNN Issues**: Native PyTorch, no CTranslate2 library loading problems
|
||||
3. **Word-level Tokens**: Enables LLM prompt pre-computation during speech
|
||||
4. **Lower Latency**: Optimized for real-time use cases
|
||||
5. **Better GPU Utilization**: Uses standard PyTorch CUDA
|
||||
6. **Simpler Dependencies**: No external compiled libraries
|
||||
|
||||
## Deployment
|
||||
|
||||
1. **Build Container**:
|
||||
```bash
|
||||
docker-compose build miku-stt
|
||||
```
|
||||
|
||||
2. **First Run** (downloads model ~600MB):
|
||||
```bash
|
||||
docker-compose up miku-stt
|
||||
```
|
||||
Model will be cached in `/models` volume for subsequent runs.
|
||||
|
||||
3. **Verify GPU Usage**:
|
||||
```bash
|
||||
docker exec miku-stt nvidia-smi
|
||||
```
|
||||
You should see `python3` process using VRAM (~1.5GB for model + inference).
|
||||
|
||||
## Testing
|
||||
|
||||
Same test procedure as before:
|
||||
1. Join voice channel
|
||||
2. `!miku listen`
|
||||
3. Speak clearly
|
||||
4. Check logs for "Parakeet model loaded"
|
||||
5. Verify transcripts appear faster than before
|
||||
|
||||
## Bot-Side Compatibility
|
||||
|
||||
No changes needed to bot code - STT WebSocket protocol is identical. The bot will automatically receive word-level tokens in partial/final transcript messages.
|
||||
|
||||
### Future Enhancement: LLM Pre-computation
|
||||
The `words` array can be used to start LLM inference before full transcript completes:
|
||||
- Send partial words to LLM as they arrive
|
||||
- LLM begins processing prompt tokens
|
||||
- Faster response time when user finishes speaking
|
||||
|
||||
## Rollback (if needed)
|
||||
|
||||
To revert to Faster-Whisper:
|
||||
1. Restore `requirements.txt` from git
|
||||
2. Restore `stt_server.py` from git
|
||||
3. Delete `parakeet_transcriber.py`
|
||||
4. Rebuild container
|
||||
|
||||
## Performance Expectations
|
||||
|
||||
- **Model Load Time**: ~5-10 seconds (first time downloads from HuggingFace)
|
||||
- **VRAM Usage**: ~1.5GB (vs ~800MB for Whisper small)
|
||||
- **Latency**: ~200-500ms for 2-second audio chunks
|
||||
- **GPU Utilization**: 30-60% during active transcription
|
||||
- **Accuracy**: Similar to Whisper small (designed for English)
|
||||
152
stt/README.md
152
stt/README.md
@@ -1,152 +0,0 @@
|
||||
# Miku STT (Speech-to-Text) Server
|
||||
|
||||
Real-time speech-to-text service for Miku voice chat using Silero VAD (CPU) and Faster-Whisper (GPU).
|
||||
|
||||
## Architecture
|
||||
|
||||
- **Silero VAD** (CPU): Lightweight voice activity detection, runs continuously
|
||||
- **Faster-Whisper** (GPU GTX 1660): Efficient speech transcription using CTranslate2
|
||||
- **FastAPI WebSocket**: Real-time bidirectional communication
|
||||
|
||||
## Features
|
||||
|
||||
- ✅ Real-time voice activity detection with conservative settings
|
||||
- ✅ Streaming partial transcripts during speech
|
||||
- ✅ Final transcript on speech completion
|
||||
- ✅ Interruption detection (user speaking over Miku)
|
||||
- ✅ Multi-user support with isolated sessions
|
||||
- ✅ KV cache optimization ready (partial text for LLM precomputation)
|
||||
|
||||
## API Endpoints
|
||||
|
||||
### WebSocket: `/ws/stt/{user_id}`
|
||||
|
||||
Real-time STT session for a specific user.
|
||||
|
||||
**Client sends:** Raw PCM audio (int16, 16kHz mono, 20ms chunks = 320 samples)
|
||||
|
||||
**Server sends:** JSON events:
|
||||
```json
|
||||
// VAD events
|
||||
{"type": "vad", "event": "speech_start", "speaking": true, "probability": 0.85, "timestamp": 1250.5}
|
||||
{"type": "vad", "event": "speaking", "speaking": true, "probability": 0.92, "timestamp": 1270.5}
|
||||
{"type": "vad", "event": "speech_end", "speaking": false, "probability": 0.35, "timestamp": 3500.0}
|
||||
|
||||
// Transcription events
|
||||
{"type": "partial", "text": "Hello how are", "user_id": "123", "timestamp": 2000.0}
|
||||
{"type": "final", "text": "Hello how are you?", "user_id": "123", "timestamp": 3500.0}
|
||||
|
||||
// Interruption detection
|
||||
{"type": "interruption", "probability": 0.92, "timestamp": 1500.0}
|
||||
```
|
||||
|
||||
### HTTP GET: `/health`
|
||||
|
||||
Health check with model status.
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{
|
||||
"status": "healthy",
|
||||
"models": {
|
||||
"vad": {"loaded": true, "device": "cpu"},
|
||||
"whisper": {"loaded": true, "model": "small", "device": "cuda"}
|
||||
},
|
||||
"sessions": {
|
||||
"active": 2,
|
||||
"users": ["user123", "user456"]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
### VAD Parameters (Conservative)
|
||||
|
||||
- **Threshold**: 0.5 (speech probability)
|
||||
- **Min speech duration**: 250ms (avoid false triggers)
|
||||
- **Min silence duration**: 500ms (don't cut off mid-sentence)
|
||||
- **Speech padding**: 30ms (context around speech)
|
||||
|
||||
### Whisper Parameters
|
||||
|
||||
- **Model**: small (balanced speed/quality, ~500MB VRAM)
|
||||
- **Compute**: float16 (GPU optimization)
|
||||
- **Language**: en (English)
|
||||
- **Beam size**: 5 (quality/speed balance)
|
||||
|
||||
## Usage Example
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
import websockets
|
||||
import numpy as np
|
||||
|
||||
async def stream_audio():
|
||||
uri = "ws://localhost:8001/ws/stt/user123"
|
||||
|
||||
async with websockets.connect(uri) as websocket:
|
||||
# Wait for ready
|
||||
ready = await websocket.recv()
|
||||
print(ready)
|
||||
|
||||
# Stream audio chunks (16kHz, 20ms chunks)
|
||||
for audio_chunk in audio_stream:
|
||||
# Convert to bytes (int16)
|
||||
audio_bytes = audio_chunk.astype(np.int16).tobytes()
|
||||
await websocket.send(audio_bytes)
|
||||
|
||||
# Receive events
|
||||
event = await websocket.recv()
|
||||
print(event)
|
||||
|
||||
asyncio.run(stream_audio())
|
||||
```
|
||||
|
||||
## Docker Setup
|
||||
|
||||
### Build
|
||||
```bash
|
||||
docker-compose build miku-stt
|
||||
```
|
||||
|
||||
### Run
|
||||
```bash
|
||||
docker-compose up -d miku-stt
|
||||
```
|
||||
|
||||
### Logs
|
||||
```bash
|
||||
docker-compose logs -f miku-stt
|
||||
```
|
||||
|
||||
### Test
|
||||
```bash
|
||||
curl http://localhost:8001/health
|
||||
```
|
||||
|
||||
## GPU Sharing with Soprano
|
||||
|
||||
Both STT (Whisper) and TTS (Soprano) run on GTX 1660 but at different times:
|
||||
|
||||
1. **User speaking** → Whisper active, Soprano idle
|
||||
2. **LLM processing** → Both idle
|
||||
3. **Miku speaking** → Soprano active, Whisper idle (VAD monitoring only)
|
||||
|
||||
Interruption detection runs VAD continuously but doesn't use GPU.
|
||||
|
||||
## Performance
|
||||
|
||||
- **VAD latency**: 10-20ms per chunk (CPU)
|
||||
- **Whisper latency**: ~1-2s for 2s audio (GPU)
|
||||
- **Memory usage**:
|
||||
- Silero VAD: ~100MB (CPU)
|
||||
- Faster-Whisper small: ~500MB (GPU VRAM)
|
||||
|
||||
## Future Improvements
|
||||
|
||||
- [ ] Multi-language support (auto-detect)
|
||||
- [ ] Word-level timestamps for better sync
|
||||
- [ ] Custom vocabulary/prompt tuning
|
||||
- [ ] Speaker diarization (multiple speakers)
|
||||
- [ ] Noise suppression preprocessing
|
||||
Binary file not shown.
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,239 +0,0 @@
|
||||
{
|
||||
"alignment_heads": [
|
||||
[
|
||||
5,
|
||||
3
|
||||
],
|
||||
[
|
||||
5,
|
||||
9
|
||||
],
|
||||
[
|
||||
8,
|
||||
0
|
||||
],
|
||||
[
|
||||
8,
|
||||
4
|
||||
],
|
||||
[
|
||||
8,
|
||||
7
|
||||
],
|
||||
[
|
||||
8,
|
||||
8
|
||||
],
|
||||
[
|
||||
9,
|
||||
0
|
||||
],
|
||||
[
|
||||
9,
|
||||
7
|
||||
],
|
||||
[
|
||||
9,
|
||||
9
|
||||
],
|
||||
[
|
||||
10,
|
||||
5
|
||||
]
|
||||
],
|
||||
"lang_ids": [
|
||||
50259,
|
||||
50260,
|
||||
50261,
|
||||
50262,
|
||||
50263,
|
||||
50264,
|
||||
50265,
|
||||
50266,
|
||||
50267,
|
||||
50268,
|
||||
50269,
|
||||
50270,
|
||||
50271,
|
||||
50272,
|
||||
50273,
|
||||
50274,
|
||||
50275,
|
||||
50276,
|
||||
50277,
|
||||
50278,
|
||||
50279,
|
||||
50280,
|
||||
50281,
|
||||
50282,
|
||||
50283,
|
||||
50284,
|
||||
50285,
|
||||
50286,
|
||||
50287,
|
||||
50288,
|
||||
50289,
|
||||
50290,
|
||||
50291,
|
||||
50292,
|
||||
50293,
|
||||
50294,
|
||||
50295,
|
||||
50296,
|
||||
50297,
|
||||
50298,
|
||||
50299,
|
||||
50300,
|
||||
50301,
|
||||
50302,
|
||||
50303,
|
||||
50304,
|
||||
50305,
|
||||
50306,
|
||||
50307,
|
||||
50308,
|
||||
50309,
|
||||
50310,
|
||||
50311,
|
||||
50312,
|
||||
50313,
|
||||
50314,
|
||||
50315,
|
||||
50316,
|
||||
50317,
|
||||
50318,
|
||||
50319,
|
||||
50320,
|
||||
50321,
|
||||
50322,
|
||||
50323,
|
||||
50324,
|
||||
50325,
|
||||
50326,
|
||||
50327,
|
||||
50328,
|
||||
50329,
|
||||
50330,
|
||||
50331,
|
||||
50332,
|
||||
50333,
|
||||
50334,
|
||||
50335,
|
||||
50336,
|
||||
50337,
|
||||
50338,
|
||||
50339,
|
||||
50340,
|
||||
50341,
|
||||
50342,
|
||||
50343,
|
||||
50344,
|
||||
50345,
|
||||
50346,
|
||||
50347,
|
||||
50348,
|
||||
50349,
|
||||
50350,
|
||||
50351,
|
||||
50352,
|
||||
50353,
|
||||
50354,
|
||||
50355,
|
||||
50356,
|
||||
50357
|
||||
],
|
||||
"suppress_ids": [
|
||||
1,
|
||||
2,
|
||||
7,
|
||||
8,
|
||||
9,
|
||||
10,
|
||||
14,
|
||||
25,
|
||||
26,
|
||||
27,
|
||||
28,
|
||||
29,
|
||||
31,
|
||||
58,
|
||||
59,
|
||||
60,
|
||||
61,
|
||||
62,
|
||||
63,
|
||||
90,
|
||||
91,
|
||||
92,
|
||||
93,
|
||||
359,
|
||||
503,
|
||||
522,
|
||||
542,
|
||||
873,
|
||||
893,
|
||||
902,
|
||||
918,
|
||||
922,
|
||||
931,
|
||||
1350,
|
||||
1853,
|
||||
1982,
|
||||
2460,
|
||||
2627,
|
||||
3246,
|
||||
3253,
|
||||
3268,
|
||||
3536,
|
||||
3846,
|
||||
3961,
|
||||
4183,
|
||||
4667,
|
||||
6585,
|
||||
6647,
|
||||
7273,
|
||||
9061,
|
||||
9383,
|
||||
10428,
|
||||
10929,
|
||||
11938,
|
||||
12033,
|
||||
12331,
|
||||
12562,
|
||||
13793,
|
||||
14157,
|
||||
14635,
|
||||
15265,
|
||||
15618,
|
||||
16553,
|
||||
16604,
|
||||
18362,
|
||||
18956,
|
||||
20075,
|
||||
21675,
|
||||
22520,
|
||||
26130,
|
||||
26161,
|
||||
26435,
|
||||
28279,
|
||||
29464,
|
||||
31650,
|
||||
32302,
|
||||
32470,
|
||||
36865,
|
||||
42863,
|
||||
47425,
|
||||
49870,
|
||||
50254,
|
||||
50258,
|
||||
50358,
|
||||
50359,
|
||||
50360,
|
||||
50361,
|
||||
50362
|
||||
],
|
||||
"suppress_ids_begin": [
|
||||
220,
|
||||
50257
|
||||
]
|
||||
}
|
||||
@@ -1 +0,0 @@
|
||||
536b0662742c02347bc0e980a01041f333bce120
|
||||
@@ -1 +0,0 @@
|
||||
../../blobs/e5047537059bd8f182d9ca64c470201585015187
|
||||
@@ -1 +0,0 @@
|
||||
../../blobs/3e305921506d8872816023e4c273e75d2419fb89b24da97b4fe7bce14170d671
|
||||
@@ -1 +0,0 @@
|
||||
../../blobs/7818adb6de9fa3064d3ff81226fdd675be1f6344
|
||||
@@ -1 +0,0 @@
|
||||
../../blobs/c9074644d9d1205686f16d411564729461324b75
|
||||
@@ -1 +0,0 @@
|
||||
6d590f77001d318fb17a0b5bf7ee329a91b52598
|
||||
@@ -1,229 +0,0 @@
|
||||
"""
|
||||
NVIDIA Parakeet TDT Transcriber
|
||||
|
||||
Real-time streaming ASR using NVIDIA's Parakeet TDT (Token-and-Duration Transducer) model.
|
||||
Supports streaming transcription with word-level timestamps for LLM pre-computation.
|
||||
|
||||
Model: nvidia/parakeet-tdt-0.6b-v3
|
||||
- 600M parameters
|
||||
- Real-time capable on GPU
|
||||
- Word-level timestamps
|
||||
- Streaming support via NVIDIA NeMo
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from nemo.collections.asr.models import EncDecRNNTBPEModel
|
||||
from typing import Optional, List, Dict
|
||||
import logging
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
logger = logging.getLogger('parakeet')
|
||||
|
||||
|
||||
class ParakeetTranscriber:
|
||||
"""
|
||||
NVIDIA Parakeet-based streaming transcription with word-level tokens.
|
||||
|
||||
Uses NVIDIA NeMo for proper model loading and inference.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = "nvidia/parakeet-tdt-0.6b-v3",
|
||||
device: str = "cuda",
|
||||
language: str = "en"
|
||||
):
|
||||
"""
|
||||
Initialize Parakeet transcriber.
|
||||
|
||||
Args:
|
||||
model_name: HuggingFace model identifier
|
||||
device: Device to run on (cuda or cpu)
|
||||
language: Language code (Parakeet primarily supports English)
|
||||
"""
|
||||
self.model_name = model_name
|
||||
self.device = device
|
||||
self.language = language
|
||||
|
||||
logger.info(f"Loading Parakeet model: {model_name} on {device}...")
|
||||
|
||||
# Set PyTorch memory allocator settings for better memory management
|
||||
if device == "cuda":
|
||||
# Enable expandable segments to reduce fragmentation
|
||||
import os
|
||||
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
|
||||
|
||||
# Clear cache before loading model
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Load model via NeMo from HuggingFace
|
||||
self.model = EncDecRNNTBPEModel.from_pretrained(
|
||||
model_name=model_name,
|
||||
map_location=device
|
||||
)
|
||||
|
||||
self.model.eval()
|
||||
if device == "cuda":
|
||||
self.model = self.model.cuda()
|
||||
# Enable memory efficient attention if available
|
||||
try:
|
||||
self.model.encoder.use_memory_efficient_attention = True
|
||||
except:
|
||||
pass
|
||||
|
||||
# Thread pool for blocking transcription calls
|
||||
self.executor = ThreadPoolExecutor(max_workers=2)
|
||||
|
||||
logger.info(f"Parakeet model loaded on {device}")
|
||||
|
||||
async def transcribe_async(
|
||||
self,
|
||||
audio: np.ndarray,
|
||||
sample_rate: int = 16000,
|
||||
return_timestamps: bool = False
|
||||
) -> str:
|
||||
"""
|
||||
Transcribe audio asynchronously (non-blocking).
|
||||
|
||||
Args:
|
||||
audio: Audio data as numpy array (float32)
|
||||
sample_rate: Audio sample rate (Parakeet expects 16kHz)
|
||||
return_timestamps: Whether to return word-level timestamps
|
||||
|
||||
Returns:
|
||||
Transcribed text (or dict with timestamps if return_timestamps=True)
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
# Run transcription in thread pool to avoid blocking
|
||||
result = await loop.run_in_executor(
|
||||
self.executor,
|
||||
self._transcribe_blocking,
|
||||
audio,
|
||||
sample_rate,
|
||||
return_timestamps
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _transcribe_blocking(
|
||||
self,
|
||||
audio: np.ndarray,
|
||||
sample_rate: int,
|
||||
return_timestamps: bool
|
||||
):
|
||||
"""
|
||||
Blocking transcription call (runs in thread pool).
|
||||
"""
|
||||
# Convert to float32 if needed
|
||||
if audio.dtype != np.float32:
|
||||
audio = audio.astype(np.float32) / 32768.0
|
||||
|
||||
# Ensure correct sample rate (Parakeet expects 16kHz)
|
||||
if sample_rate != 16000:
|
||||
logger.warning(f"Audio sample rate is {sample_rate}Hz, Parakeet expects 16kHz. Resampling...")
|
||||
import torchaudio
|
||||
audio_tensor = torch.from_numpy(audio).unsqueeze(0)
|
||||
resampler = torchaudio.transforms.Resample(sample_rate, 16000)
|
||||
audio_tensor = resampler(audio_tensor)
|
||||
audio = audio_tensor.squeeze(0).numpy()
|
||||
sample_rate = 16000
|
||||
|
||||
# Transcribe using NeMo model
|
||||
with torch.no_grad():
|
||||
# Convert to tensor and keep on GPU to avoid CPU/GPU bouncing
|
||||
audio_signal = torch.from_numpy(audio).unsqueeze(0)
|
||||
audio_signal_len = torch.tensor([len(audio)])
|
||||
|
||||
if self.device == "cuda":
|
||||
audio_signal = audio_signal.cuda()
|
||||
audio_signal_len = audio_signal_len.cuda()
|
||||
|
||||
# Get transcription
|
||||
# NeMo returns list of Hypothesis objects
|
||||
# Note: timestamps=True causes significant VRAM usage (~1-2GB extra)
|
||||
# Only enable for final transcriptions, not streaming partials
|
||||
transcriptions = self.model.transcribe(
|
||||
audio=[audio], # Pass NumPy array directly (NeMo handles it efficiently)
|
||||
batch_size=1,
|
||||
timestamps=return_timestamps # Only use timestamps when explicitly requested
|
||||
)
|
||||
|
||||
# Extract text from Hypothesis object
|
||||
hypothesis = transcriptions[0] if transcriptions else None
|
||||
if hypothesis is None:
|
||||
text = ""
|
||||
words = []
|
||||
else:
|
||||
# Hypothesis object has .text attribute
|
||||
text = hypothesis.text.strip() if hasattr(hypothesis, 'text') else str(hypothesis).strip()
|
||||
|
||||
# Extract word-level timestamps if available and requested
|
||||
words = []
|
||||
if return_timestamps and hasattr(hypothesis, 'timestamp') and hypothesis.timestamp:
|
||||
# timestamp is a dict with 'word' key containing list of word timestamps
|
||||
word_timestamps = hypothesis.timestamp.get('word', [])
|
||||
for word_info in word_timestamps:
|
||||
words.append({
|
||||
"word": word_info.get('word', ''),
|
||||
"start_time": word_info.get('start', 0.0),
|
||||
"end_time": word_info.get('end', 0.0)
|
||||
})
|
||||
|
||||
logger.debug(f"Transcribed: '{text}' with {len(words)} words")
|
||||
|
||||
if return_timestamps:
|
||||
return {
|
||||
"text": text,
|
||||
"words": words
|
||||
}
|
||||
else:
|
||||
return text
|
||||
|
||||
# Note: We do NOT call torch.cuda.empty_cache() here
|
||||
# That breaks PyTorch's memory allocator and causes fragmentation
|
||||
# Let PyTorch manage its own memory pool
|
||||
|
||||
async def transcribe_streaming(
|
||||
self,
|
||||
audio_chunks: List[np.ndarray],
|
||||
sample_rate: int = 16000,
|
||||
chunk_size_ms: int = 500
|
||||
) -> Dict[str, any]:
|
||||
"""
|
||||
Transcribe audio chunks with streaming support.
|
||||
|
||||
Args:
|
||||
audio_chunks: List of audio chunks to process
|
||||
sample_rate: Audio sample rate
|
||||
chunk_size_ms: Size of each chunk in milliseconds
|
||||
|
||||
Returns:
|
||||
Dict with partial and word-level results
|
||||
"""
|
||||
if not audio_chunks:
|
||||
return {"text": "", "words": []}
|
||||
|
||||
# Concatenate all chunks
|
||||
audio_data = np.concatenate(audio_chunks)
|
||||
|
||||
# Transcribe with timestamps for streaming
|
||||
result = await self.transcribe_async(
|
||||
audio_data,
|
||||
sample_rate,
|
||||
return_timestamps=True
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def get_supported_languages(self) -> List[str]:
|
||||
"""Get list of supported language codes."""
|
||||
# Parakeet TDT v3 primarily supports English
|
||||
return ["en"]
|
||||
|
||||
def cleanup(self):
|
||||
"""Cleanup resources."""
|
||||
self.executor.shutdown(wait=True)
|
||||
logger.info("Parakeet transcriber cleaned up")
|
||||
@@ -1,29 +0,0 @@
|
||||
# STT Container Requirements
|
||||
|
||||
# Core dependencies
|
||||
fastapi==0.115.6
|
||||
uvicorn[standard]==0.32.1
|
||||
websockets==14.1
|
||||
aiohttp==3.11.11
|
||||
|
||||
# Audio processing (install numpy first for sox dependency)
|
||||
numpy==2.2.2
|
||||
soundfile==0.12.1
|
||||
librosa==0.10.2.post1
|
||||
|
||||
# VAD (CPU)
|
||||
torch==2.9.1 # Latest PyTorch
|
||||
torchaudio==2.9.1
|
||||
silero-vad==5.1.2
|
||||
|
||||
# STT (GPU) - NVIDIA NeMo for Parakeet
|
||||
# Parakeet TDT 0.6b-v3 requires NeMo 2.4
|
||||
# Fix huggingface-hub version conflict with transformers
|
||||
huggingface-hub>=0.30.0,<1.0
|
||||
nemo_toolkit[asr]==2.4.0
|
||||
omegaconf==2.3.0
|
||||
cuda-python>=12.3 # Enable CUDA graphs for faster decoding
|
||||
|
||||
# Utilities
|
||||
python-multipart==0.0.20
|
||||
pydantic==2.10.4
|
||||
@@ -1,396 +0,0 @@
|
||||
"""
|
||||
STT Server
|
||||
|
||||
FastAPI WebSocket server for real-time speech-to-text.
|
||||
Combines Silero VAD (CPU) and NVIDIA Parakeet (GPU) for efficient transcription.
|
||||
|
||||
Architecture:
|
||||
- VAD runs continuously on every audio chunk (CPU)
|
||||
- Parakeet transcribes only when VAD detects speech (GPU)
|
||||
- Supports multiple concurrent users
|
||||
- Sends partial and final transcripts via WebSocket with word-level tokens
|
||||
"""
|
||||
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
import numpy as np
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from vad_processor import VADProcessor
|
||||
from parakeet_transcriber import ParakeetTranscriber
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='[%(levelname)s] [%(name)s] %(message)s'
|
||||
)
|
||||
logger = logging.getLogger('stt_server')
|
||||
|
||||
# Initialize FastAPI app
|
||||
app = FastAPI(title="Miku STT Server", version="1.0.0")
|
||||
|
||||
# Global instances (initialized on startup)
|
||||
vad_processor: Optional[VADProcessor] = None
|
||||
parakeet_transcriber: Optional[ParakeetTranscriber] = None
|
||||
|
||||
# User session tracking
|
||||
user_sessions: Dict[str, dict] = {}
|
||||
|
||||
|
||||
class UserSTTSession:
|
||||
"""Manages STT state for a single user."""
|
||||
|
||||
def __init__(self, user_id: str, websocket: WebSocket):
|
||||
self.user_id = user_id
|
||||
self.websocket = websocket
|
||||
self.audio_buffer = []
|
||||
self.is_speaking = False
|
||||
self.timestamp_ms = 0.0
|
||||
self.transcript_buffer = []
|
||||
self.last_transcript = ""
|
||||
self.last_partial_duration = 0.0 # Track when we last sent a partial
|
||||
self.last_speech_timestamp = 0.0 # Track last time we detected speech
|
||||
self.speech_timeout_ms = 3000 # Force finalization after 3s of no new speech
|
||||
|
||||
logger.info(f"Created STT session for user {user_id}")
|
||||
|
||||
async def process_audio_chunk(self, audio_data: bytes):
|
||||
"""
|
||||
Process incoming audio chunk.
|
||||
|
||||
Args:
|
||||
audio_data: Raw PCM audio (int16, 16kHz mono)
|
||||
"""
|
||||
# Convert bytes to numpy array (int16)
|
||||
audio_np = np.frombuffer(audio_data, dtype=np.int16)
|
||||
|
||||
# Calculate timestamp (assuming 16kHz, 20ms chunks = 320 samples)
|
||||
chunk_duration_ms = (len(audio_np) / 16000) * 1000
|
||||
self.timestamp_ms += chunk_duration_ms
|
||||
|
||||
# Run VAD on chunk
|
||||
vad_event = vad_processor.detect_speech_segment(audio_np, self.timestamp_ms)
|
||||
|
||||
if vad_event:
|
||||
event_type = vad_event["event"]
|
||||
probability = vad_event["probability"]
|
||||
|
||||
logger.debug(f"VAD event for user {self.user_id}: {event_type} (prob={probability:.3f})")
|
||||
|
||||
# Send VAD event to client
|
||||
await self.websocket.send_json({
|
||||
"type": "vad",
|
||||
"event": event_type,
|
||||
"speaking": event_type in ["speech_start", "speaking"],
|
||||
"probability": probability,
|
||||
"timestamp": self.timestamp_ms
|
||||
})
|
||||
|
||||
# Handle speech events
|
||||
if event_type == "speech_start":
|
||||
self.is_speaking = True
|
||||
self.audio_buffer = [audio_np]
|
||||
self.last_partial_duration = 0.0
|
||||
self.last_speech_timestamp = self.timestamp_ms
|
||||
logger.info(f"[STT] User {self.user_id} SPEECH START")
|
||||
|
||||
elif event_type == "speaking":
|
||||
if self.is_speaking:
|
||||
self.audio_buffer.append(audio_np)
|
||||
self.last_speech_timestamp = self.timestamp_ms # Update speech timestamp
|
||||
|
||||
# Transcribe partial every ~1 second for streaming (reduced from 2s)
|
||||
total_samples = sum(len(chunk) for chunk in self.audio_buffer)
|
||||
duration_s = total_samples / 16000
|
||||
|
||||
# More frequent partials for better responsiveness
|
||||
if duration_s >= 1.0:
|
||||
logger.debug(f"Triggering partial transcription at {duration_s:.1f}s")
|
||||
await self._transcribe_partial()
|
||||
# Keep buffer for final transcription, but mark progress
|
||||
self.last_partial_duration = duration_s
|
||||
|
||||
elif event_type == "speech_end":
|
||||
self.is_speaking = False
|
||||
|
||||
logger.info(f"[STT] User {self.user_id} SPEECH END (VAD detected) - transcribing final")
|
||||
|
||||
# Transcribe final
|
||||
await self._transcribe_final()
|
||||
|
||||
# Clear buffer
|
||||
self.audio_buffer = []
|
||||
self.last_partial_duration = 0.0
|
||||
logger.debug(f"User {self.user_id} stopped speaking")
|
||||
|
||||
else:
|
||||
# No VAD event - still accumulate audio if speaking
|
||||
if self.is_speaking:
|
||||
self.audio_buffer.append(audio_np)
|
||||
|
||||
# Check for timeout
|
||||
time_since_speech = self.timestamp_ms - self.last_speech_timestamp
|
||||
|
||||
if time_since_speech >= self.speech_timeout_ms:
|
||||
# Timeout - user probably stopped but VAD didn't detect it
|
||||
logger.warning(f"[STT] User {self.user_id} SPEECH TIMEOUT after {time_since_speech:.0f}ms - forcing finalization")
|
||||
self.is_speaking = False
|
||||
|
||||
# Force final transcription
|
||||
await self._transcribe_final()
|
||||
|
||||
# Clear buffer
|
||||
self.audio_buffer = []
|
||||
self.last_partial_duration = 0.0
|
||||
|
||||
async def _transcribe_partial(self):
|
||||
"""Transcribe accumulated audio and send partial result (no timestamps to save VRAM)."""
|
||||
if not self.audio_buffer:
|
||||
return
|
||||
|
||||
# Concatenate audio
|
||||
audio_full = np.concatenate(self.audio_buffer)
|
||||
|
||||
# Transcribe asynchronously WITHOUT timestamps for partials (saves 1-2GB VRAM)
|
||||
try:
|
||||
result = await parakeet_transcriber.transcribe_async(
|
||||
audio_full,
|
||||
sample_rate=16000,
|
||||
return_timestamps=False # Disable timestamps for partials to reduce VRAM usage
|
||||
)
|
||||
|
||||
# Result is just a string when timestamps=False
|
||||
text = result if isinstance(result, str) else result.get("text", "")
|
||||
|
||||
if text and text != self.last_transcript:
|
||||
self.last_transcript = text
|
||||
|
||||
# Send partial transcript without word tokens (saves memory)
|
||||
await self.websocket.send_json({
|
||||
"type": "partial",
|
||||
"text": text,
|
||||
"words": [], # No word tokens for partials
|
||||
"user_id": self.user_id,
|
||||
"timestamp": self.timestamp_ms
|
||||
})
|
||||
|
||||
logger.info(f"Partial [{self.user_id}]: {text}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Partial transcription failed: {e}", exc_info=True)
|
||||
|
||||
async def _transcribe_final(self):
|
||||
"""Transcribe final accumulated audio with word tokens."""
|
||||
if not self.audio_buffer:
|
||||
return
|
||||
|
||||
# Concatenate all audio
|
||||
audio_full = np.concatenate(self.audio_buffer)
|
||||
|
||||
try:
|
||||
result = await parakeet_transcriber.transcribe_async(
|
||||
audio_full,
|
||||
sample_rate=16000,
|
||||
return_timestamps=True
|
||||
)
|
||||
|
||||
if result and result.get("text"):
|
||||
self.last_transcript = result["text"]
|
||||
|
||||
# Send final transcript with word tokens
|
||||
await self.websocket.send_json({
|
||||
"type": "final",
|
||||
"text": result["text"],
|
||||
"words": result.get("words", []), # Word-level tokens for LLM
|
||||
"user_id": self.user_id,
|
||||
"timestamp": self.timestamp_ms
|
||||
})
|
||||
|
||||
logger.info(f"Final [{self.user_id}]: {result['text']}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Final transcription failed: {e}", exc_info=True)
|
||||
|
||||
async def check_interruption(self, audio_data: bytes) -> bool:
|
||||
"""
|
||||
Check if user is interrupting (for use during Miku's speech).
|
||||
|
||||
Args:
|
||||
audio_data: Raw PCM audio chunk
|
||||
|
||||
Returns:
|
||||
True if interruption detected
|
||||
"""
|
||||
audio_np = np.frombuffer(audio_data, dtype=np.int16)
|
||||
speech_prob, is_speaking = vad_processor.process_chunk(audio_np)
|
||||
|
||||
# Interruption: high probability sustained for threshold duration
|
||||
if speech_prob > 0.7: # Higher threshold for interruption
|
||||
await self.websocket.send_json({
|
||||
"type": "interruption",
|
||||
"probability": speech_prob,
|
||||
"timestamp": self.timestamp_ms
|
||||
})
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
"""Initialize models on server startup."""
|
||||
global vad_processor, parakeet_transcriber
|
||||
|
||||
logger.info("=" * 50)
|
||||
logger.info("Initializing Miku STT Server")
|
||||
logger.info("=" * 50)
|
||||
|
||||
# Initialize VAD (CPU)
|
||||
logger.info("Loading Silero VAD model (CPU)...")
|
||||
vad_processor = VADProcessor(
|
||||
sample_rate=16000,
|
||||
threshold=0.5,
|
||||
min_speech_duration_ms=250, # Conservative - wait 250ms before starting
|
||||
min_silence_duration_ms=300 # Reduced from 500ms - detect silence faster
|
||||
)
|
||||
logger.info("✓ VAD ready")
|
||||
|
||||
# Initialize Parakeet (GPU)
|
||||
logger.info("Loading NVIDIA Parakeet TDT model (GPU)...")
|
||||
parakeet_transcriber = ParakeetTranscriber(
|
||||
model_name="nvidia/parakeet-tdt-0.6b-v3",
|
||||
device="cuda",
|
||||
language="en"
|
||||
)
|
||||
logger.info("✓ Parakeet ready")
|
||||
|
||||
logger.info("=" * 50)
|
||||
logger.info("STT Server ready to accept connections")
|
||||
logger.info("=" * 50)
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
"""Cleanup on server shutdown."""
|
||||
logger.info("Shutting down STT server...")
|
||||
|
||||
if parakeet_transcriber:
|
||||
parakeet_transcriber.cleanup()
|
||||
|
||||
logger.info("STT server shutdown complete")
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""Health check endpoint."""
|
||||
return {
|
||||
"service": "Miku STT Server",
|
||||
"status": "running",
|
||||
"vad_ready": vad_processor is not None,
|
||||
"whisper_ready": whisper_transcriber is not None,
|
||||
"active_sessions": len(user_sessions)
|
||||
}
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
"""Detailed health check."""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"models": {
|
||||
"vad": {
|
||||
"loaded": vad_processor is not None,
|
||||
"device": "cpu"
|
||||
},
|
||||
"whisper": {
|
||||
"loaded": whisper_transcriber is not None,
|
||||
"model": "small",
|
||||
"device": "cuda"
|
||||
}
|
||||
},
|
||||
"sessions": {
|
||||
"active": len(user_sessions),
|
||||
"users": list(user_sessions.keys())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@app.websocket("/ws/stt/{user_id}")
|
||||
async def websocket_stt(websocket: WebSocket, user_id: str):
|
||||
"""
|
||||
WebSocket endpoint for real-time STT.
|
||||
|
||||
Client sends: Raw PCM audio (int16, 16kHz mono, 20ms chunks)
|
||||
Server sends: JSON events:
|
||||
- {"type": "vad", "event": "speech_start|speaking|speech_end", ...}
|
||||
- {"type": "partial", "text": "...", ...}
|
||||
- {"type": "final", "text": "...", ...}
|
||||
- {"type": "interruption", "probability": 0.xx}
|
||||
"""
|
||||
await websocket.accept()
|
||||
logger.info(f"STT WebSocket connected: user {user_id}")
|
||||
|
||||
# Create session
|
||||
session = UserSTTSession(user_id, websocket)
|
||||
user_sessions[user_id] = session
|
||||
|
||||
try:
|
||||
# Send ready message
|
||||
await websocket.send_json({
|
||||
"type": "ready",
|
||||
"user_id": user_id,
|
||||
"message": "STT session started"
|
||||
})
|
||||
|
||||
# Main loop: receive audio chunks
|
||||
while True:
|
||||
# Receive binary audio data
|
||||
data = await websocket.receive_bytes()
|
||||
|
||||
# Process audio chunk
|
||||
await session.process_audio_chunk(data)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"User {user_id} disconnected")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in STT WebSocket for user {user_id}: {e}", exc_info=True)
|
||||
|
||||
finally:
|
||||
# Cleanup session
|
||||
if user_id in user_sessions:
|
||||
del user_sessions[user_id]
|
||||
logger.info(f"STT session ended for user {user_id}")
|
||||
|
||||
|
||||
@app.post("/interrupt/check")
|
||||
async def check_interruption(user_id: str):
|
||||
"""
|
||||
Check if user is interrupting (for use during Miku's speech).
|
||||
|
||||
Query param:
|
||||
user_id: Discord user ID
|
||||
|
||||
Returns:
|
||||
{"interrupting": bool, "probability": float}
|
||||
"""
|
||||
session = user_sessions.get(user_id)
|
||||
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="User session not found")
|
||||
|
||||
# Get current VAD state
|
||||
vad_state = vad_processor.get_state()
|
||||
|
||||
return {
|
||||
"interrupting": vad_state["speaking"],
|
||||
"user_id": user_id
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")
|
||||
206
stt/test_stt.py
206
stt/test_stt.py
@@ -1,206 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for STT WebSocket server.
|
||||
Sends test audio and receives VAD/transcription events.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import websockets
|
||||
import numpy as np
|
||||
import json
|
||||
import wave
|
||||
|
||||
|
||||
async def test_websocket():
|
||||
"""Test STT WebSocket with generated audio."""
|
||||
|
||||
uri = "ws://localhost:8001/ws/stt/test_user"
|
||||
|
||||
print("🔌 Connecting to STT WebSocket...")
|
||||
|
||||
async with websockets.connect(uri) as websocket:
|
||||
# Wait for ready message
|
||||
ready_msg = await websocket.recv()
|
||||
ready = json.loads(ready_msg)
|
||||
print(f"✅ {ready}")
|
||||
|
||||
# Generate test audio: 2 seconds of 440Hz tone (A note)
|
||||
# This simulates speech-like audio
|
||||
print("\n🎵 Generating test audio (2 seconds, 440Hz tone)...")
|
||||
sample_rate = 16000
|
||||
duration = 2.0
|
||||
frequency = 440 # A4 note
|
||||
|
||||
t = np.linspace(0, duration, int(sample_rate * duration), False)
|
||||
audio = np.sin(frequency * 2 * np.pi * t)
|
||||
|
||||
# Convert to int16
|
||||
audio_int16 = (audio * 32767).astype(np.int16)
|
||||
|
||||
# Send in 20ms chunks (320 samples at 16kHz)
|
||||
chunk_size = 320 # 20ms chunks
|
||||
total_chunks = len(audio_int16) // chunk_size
|
||||
|
||||
print(f"📤 Sending {total_chunks} audio chunks (20ms each)...\n")
|
||||
|
||||
# Send chunks and receive events
|
||||
for i in range(0, len(audio_int16), chunk_size):
|
||||
chunk = audio_int16[i:i+chunk_size]
|
||||
|
||||
# Send audio chunk
|
||||
await websocket.send(chunk.tobytes())
|
||||
|
||||
# Try to receive events (non-blocking)
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
websocket.recv(),
|
||||
timeout=0.01
|
||||
)
|
||||
event = json.loads(response)
|
||||
|
||||
# Print VAD events
|
||||
if event['type'] == 'vad':
|
||||
emoji = "🟢" if event['speaking'] else "⚪"
|
||||
print(f"{emoji} VAD: {event['event']} "
|
||||
f"(prob={event['probability']:.3f}, "
|
||||
f"t={event['timestamp']:.1f}ms)")
|
||||
|
||||
# Print transcription events
|
||||
elif event['type'] == 'partial':
|
||||
print(f"📝 Partial: \"{event['text']}\"")
|
||||
|
||||
elif event['type'] == 'final':
|
||||
print(f"✅ Final: \"{event['text']}\"")
|
||||
|
||||
elif event['type'] == 'interruption':
|
||||
print(f"⚠️ Interruption detected! (prob={event['probability']:.3f})")
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
pass # No event yet
|
||||
|
||||
# Small delay between chunks
|
||||
await asyncio.sleep(0.02)
|
||||
|
||||
print("\n✅ Test audio sent successfully!")
|
||||
|
||||
# Wait a bit for final transcription
|
||||
print("⏳ Waiting for final transcription...")
|
||||
|
||||
for _ in range(50): # Wait up to 1 second
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
websocket.recv(),
|
||||
timeout=0.02
|
||||
)
|
||||
event = json.loads(response)
|
||||
|
||||
if event['type'] == 'final':
|
||||
print(f"\n✅ FINAL TRANSCRIPT: \"{event['text']}\"")
|
||||
break
|
||||
elif event['type'] == 'vad':
|
||||
emoji = "🟢" if event['speaking'] else "⚪"
|
||||
print(f"{emoji} VAD: {event['event']} (prob={event['probability']:.3f})")
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
print("\n✅ WebSocket test complete!")
|
||||
|
||||
|
||||
async def test_with_sample_audio():
|
||||
"""Test with actual speech audio file (if available)."""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
if len(sys.argv) > 1 and os.path.exists(sys.argv[1]):
|
||||
audio_file = sys.argv[1]
|
||||
print(f"📂 Loading audio from: {audio_file}")
|
||||
|
||||
# Load WAV file
|
||||
with wave.open(audio_file, 'rb') as wav:
|
||||
sample_rate = wav.getframerate()
|
||||
n_channels = wav.getnchannels()
|
||||
audio_data = wav.readframes(wav.getnframes())
|
||||
|
||||
# Convert to numpy array
|
||||
audio_np = np.frombuffer(audio_data, dtype=np.int16)
|
||||
|
||||
# If stereo, convert to mono
|
||||
if n_channels == 2:
|
||||
audio_np = audio_np.reshape(-1, 2).mean(axis=1).astype(np.int16)
|
||||
|
||||
# Resample to 16kHz if needed
|
||||
if sample_rate != 16000:
|
||||
print(f"⚠️ Resampling from {sample_rate}Hz to 16000Hz...")
|
||||
import librosa
|
||||
audio_float = audio_np.astype(np.float32) / 32768.0
|
||||
audio_resampled = librosa.resample(
|
||||
audio_float,
|
||||
orig_sr=sample_rate,
|
||||
target_sr=16000
|
||||
)
|
||||
audio_np = (audio_resampled * 32767).astype(np.int16)
|
||||
|
||||
print(f"✅ Audio loaded: {len(audio_np)/16000:.2f} seconds")
|
||||
|
||||
# Send to STT
|
||||
uri = "ws://localhost:8001/ws/stt/test_user"
|
||||
|
||||
async with websockets.connect(uri) as websocket:
|
||||
ready_msg = await websocket.recv()
|
||||
print(f"✅ {json.loads(ready_msg)}")
|
||||
|
||||
# Send in chunks
|
||||
chunk_size = 320 # 20ms at 16kHz
|
||||
|
||||
for i in range(0, len(audio_np), chunk_size):
|
||||
chunk = audio_np[i:i+chunk_size]
|
||||
await websocket.send(chunk.tobytes())
|
||||
|
||||
# Receive events
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
websocket.recv(),
|
||||
timeout=0.01
|
||||
)
|
||||
event = json.loads(response)
|
||||
|
||||
if event['type'] == 'vad':
|
||||
emoji = "🟢" if event['speaking'] else "⚪"
|
||||
print(f"{emoji} VAD: {event['event']} (prob={event['probability']:.3f})")
|
||||
elif event['type'] in ['partial', 'final']:
|
||||
print(f"📝 {event['type'].title()}: \"{event['text']}\"")
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
await asyncio.sleep(0.02)
|
||||
|
||||
# Wait for final
|
||||
for _ in range(100):
|
||||
try:
|
||||
response = await asyncio.wait_for(websocket.recv(), timeout=0.02)
|
||||
event = json.loads(response)
|
||||
if event['type'] == 'final':
|
||||
print(f"\n✅ FINAL: \"{event['text']}\"")
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
print("=" * 60)
|
||||
print(" Miku STT WebSocket Test")
|
||||
print("=" * 60)
|
||||
print()
|
||||
|
||||
if len(sys.argv) > 1:
|
||||
print("📁 Testing with audio file...")
|
||||
asyncio.run(test_with_sample_audio())
|
||||
else:
|
||||
print("🎵 Testing with generated tone...")
|
||||
print(" (To test with audio file: python test_stt.py audio.wav)")
|
||||
print()
|
||||
asyncio.run(test_websocket())
|
||||
@@ -1,204 +0,0 @@
|
||||
"""
|
||||
Silero VAD Processor
|
||||
|
||||
Lightweight CPU-based Voice Activity Detection for real-time speech detection.
|
||||
Runs continuously on audio chunks to determine when users are speaking.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from typing import Tuple, Optional
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger('vad')
|
||||
|
||||
|
||||
class VADProcessor:
|
||||
"""
|
||||
Voice Activity Detection using Silero VAD model.
|
||||
|
||||
Processes audio chunks and returns speech probability.
|
||||
Conservative settings to avoid cutting off speech.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sample_rate: int = 16000,
|
||||
threshold: float = 0.5,
|
||||
min_speech_duration_ms: int = 250,
|
||||
min_silence_duration_ms: int = 500,
|
||||
speech_pad_ms: int = 30
|
||||
):
|
||||
"""
|
||||
Initialize VAD processor.
|
||||
|
||||
Args:
|
||||
sample_rate: Audio sample rate (must be 8000 or 16000)
|
||||
threshold: Speech probability threshold (0.0-1.0)
|
||||
min_speech_duration_ms: Minimum speech duration to trigger (conservative)
|
||||
min_silence_duration_ms: Minimum silence to end speech (conservative)
|
||||
speech_pad_ms: Padding around speech segments
|
||||
"""
|
||||
self.sample_rate = sample_rate
|
||||
self.threshold = threshold
|
||||
self.min_speech_duration_ms = min_speech_duration_ms
|
||||
self.min_silence_duration_ms = min_silence_duration_ms
|
||||
self.speech_pad_ms = speech_pad_ms
|
||||
|
||||
# Load Silero VAD model (CPU only)
|
||||
logger.info("Loading Silero VAD model (CPU)...")
|
||||
self.model, utils = torch.hub.load(
|
||||
repo_or_dir='snakers4/silero-vad',
|
||||
model='silero_vad',
|
||||
force_reload=False,
|
||||
onnx=False # Use PyTorch model
|
||||
)
|
||||
|
||||
# Extract utility functions
|
||||
(self.get_speech_timestamps,
|
||||
self.save_audio,
|
||||
self.read_audio,
|
||||
self.VADIterator,
|
||||
self.collect_chunks) = utils
|
||||
|
||||
# State tracking
|
||||
self.speaking = False
|
||||
self.speech_start_time = None
|
||||
self.silence_start_time = None
|
||||
self.audio_buffer = []
|
||||
|
||||
# Chunk buffer for VAD (Silero needs at least 512 samples)
|
||||
self.vad_buffer = []
|
||||
self.min_vad_samples = 512 # Minimum samples for VAD processing
|
||||
|
||||
logger.info(f"VAD initialized: threshold={threshold}, "
|
||||
f"min_speech={min_speech_duration_ms}ms, "
|
||||
f"min_silence={min_silence_duration_ms}ms")
|
||||
|
||||
def process_chunk(self, audio_chunk: np.ndarray) -> Tuple[float, bool]:
|
||||
"""
|
||||
Process single audio chunk and return speech probability.
|
||||
Buffers small chunks to meet VAD minimum size requirement.
|
||||
|
||||
Args:
|
||||
audio_chunk: Audio data as numpy array (int16 or float32)
|
||||
|
||||
Returns:
|
||||
(speech_probability, is_speaking): Probability and current speaking state
|
||||
"""
|
||||
# Convert to float32 if needed
|
||||
if audio_chunk.dtype == np.int16:
|
||||
audio_chunk = audio_chunk.astype(np.float32) / 32768.0
|
||||
|
||||
# Add to buffer
|
||||
self.vad_buffer.append(audio_chunk)
|
||||
|
||||
# Check if we have enough samples
|
||||
total_samples = sum(len(chunk) for chunk in self.vad_buffer)
|
||||
|
||||
if total_samples < self.min_vad_samples:
|
||||
# Not enough samples yet, return neutral probability
|
||||
return 0.0, False
|
||||
|
||||
# Concatenate buffer
|
||||
audio_full = np.concatenate(self.vad_buffer)
|
||||
|
||||
# Process with VAD
|
||||
audio_tensor = torch.from_numpy(audio_full)
|
||||
|
||||
with torch.no_grad():
|
||||
speech_prob = self.model(audio_tensor, self.sample_rate).item()
|
||||
|
||||
# Clear buffer after processing
|
||||
self.vad_buffer = []
|
||||
|
||||
# Update speaking state based on probability
|
||||
is_speaking = speech_prob > self.threshold
|
||||
|
||||
return speech_prob, is_speaking
|
||||
|
||||
def detect_speech_segment(
|
||||
self,
|
||||
audio_chunk: np.ndarray,
|
||||
timestamp_ms: float
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Process chunk and detect speech start/end events.
|
||||
|
||||
Args:
|
||||
audio_chunk: Audio data
|
||||
timestamp_ms: Current timestamp in milliseconds
|
||||
|
||||
Returns:
|
||||
Event dict or None:
|
||||
- {"event": "speech_start", "timestamp": float, "probability": float}
|
||||
- {"event": "speech_end", "timestamp": float, "probability": float}
|
||||
- {"event": "speaking", "probability": float} # Ongoing speech
|
||||
"""
|
||||
speech_prob, is_speaking = self.process_chunk(audio_chunk)
|
||||
|
||||
# Speech started
|
||||
if is_speaking and not self.speaking:
|
||||
if self.speech_start_time is None:
|
||||
self.speech_start_time = timestamp_ms
|
||||
|
||||
# Check if speech duration exceeds minimum
|
||||
speech_duration = timestamp_ms - self.speech_start_time
|
||||
if speech_duration >= self.min_speech_duration_ms:
|
||||
self.speaking = True
|
||||
self.silence_start_time = None
|
||||
logger.debug(f"Speech started at {timestamp_ms}ms, prob={speech_prob:.3f}")
|
||||
return {
|
||||
"event": "speech_start",
|
||||
"timestamp": timestamp_ms,
|
||||
"probability": speech_prob
|
||||
}
|
||||
|
||||
# Speech ongoing
|
||||
elif is_speaking and self.speaking:
|
||||
self.silence_start_time = None # Reset silence timer
|
||||
return {
|
||||
"event": "speaking",
|
||||
"probability": speech_prob,
|
||||
"timestamp": timestamp_ms
|
||||
}
|
||||
|
||||
# Silence detected during speech
|
||||
elif not is_speaking and self.speaking:
|
||||
if self.silence_start_time is None:
|
||||
self.silence_start_time = timestamp_ms
|
||||
|
||||
# Check if silence duration exceeds minimum
|
||||
silence_duration = timestamp_ms - self.silence_start_time
|
||||
if silence_duration >= self.min_silence_duration_ms:
|
||||
self.speaking = False
|
||||
self.speech_start_time = None
|
||||
logger.debug(f"Speech ended at {timestamp_ms}ms, prob={speech_prob:.3f}")
|
||||
return {
|
||||
"event": "speech_end",
|
||||
"timestamp": timestamp_ms,
|
||||
"probability": speech_prob
|
||||
}
|
||||
|
||||
# No speech or insufficient duration
|
||||
else:
|
||||
if not is_speaking:
|
||||
self.speech_start_time = None
|
||||
|
||||
return None
|
||||
|
||||
def reset(self):
|
||||
"""Reset VAD state."""
|
||||
self.speaking = False
|
||||
self.speech_start_time = None
|
||||
self.silence_start_time = None
|
||||
self.audio_buffer.clear()
|
||||
logger.debug("VAD state reset")
|
||||
|
||||
def get_state(self) -> dict:
|
||||
"""Get current VAD state."""
|
||||
return {
|
||||
"speaking": self.speaking,
|
||||
"speech_start_time": self.speech_start_time,
|
||||
"silence_start_time": self.silence_start_time
|
||||
}
|
||||
@@ -1,193 +0,0 @@
|
||||
"""
|
||||
Faster-Whisper Transcriber
|
||||
|
||||
GPU-accelerated speech-to-text using faster-whisper (CTranslate2).
|
||||
Supports streaming transcription with partial results.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from faster_whisper import WhisperModel
|
||||
from typing import Iterator, Optional, List
|
||||
import logging
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
logger = logging.getLogger('whisper')
|
||||
|
||||
|
||||
class WhisperTranscriber:
|
||||
"""
|
||||
Faster-Whisper based transcription with streaming support.
|
||||
|
||||
Runs on GPU (GTX 1660) with small model for balance of speed/quality.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_size: str = "small",
|
||||
device: str = "cuda",
|
||||
compute_type: str = "float16",
|
||||
language: str = "en",
|
||||
beam_size: int = 5
|
||||
):
|
||||
"""
|
||||
Initialize Whisper transcriber.
|
||||
|
||||
Args:
|
||||
model_size: Model size (tiny, base, small, medium, large)
|
||||
device: Device to run on (cuda or cpu)
|
||||
compute_type: Compute precision (float16, int8, int8_float16)
|
||||
language: Language code for transcription
|
||||
beam_size: Beam search size (higher = better quality, slower)
|
||||
"""
|
||||
self.model_size = model_size
|
||||
self.device = device
|
||||
self.compute_type = compute_type
|
||||
self.language = language
|
||||
self.beam_size = beam_size
|
||||
|
||||
logger.info(f"Loading Faster-Whisper model: {model_size} on {device}...")
|
||||
|
||||
# Load model
|
||||
self.model = WhisperModel(
|
||||
model_size,
|
||||
device=device,
|
||||
compute_type=compute_type,
|
||||
download_root="/models"
|
||||
)
|
||||
|
||||
# Thread pool for blocking transcription calls
|
||||
self.executor = ThreadPoolExecutor(max_workers=2)
|
||||
|
||||
logger.info(f"Whisper model loaded: {model_size} ({compute_type})")
|
||||
|
||||
async def transcribe_async(
|
||||
self,
|
||||
audio: np.ndarray,
|
||||
sample_rate: int = 16000,
|
||||
initial_prompt: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
Transcribe audio asynchronously (non-blocking).
|
||||
|
||||
Args:
|
||||
audio: Audio data as numpy array (float32)
|
||||
sample_rate: Audio sample rate
|
||||
initial_prompt: Optional prompt to guide transcription
|
||||
|
||||
Returns:
|
||||
Transcribed text
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
# Run transcription in thread pool to avoid blocking
|
||||
result = await loop.run_in_executor(
|
||||
self.executor,
|
||||
self._transcribe_blocking,
|
||||
audio,
|
||||
sample_rate,
|
||||
initial_prompt
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _transcribe_blocking(
|
||||
self,
|
||||
audio: np.ndarray,
|
||||
sample_rate: int,
|
||||
initial_prompt: Optional[str]
|
||||
) -> str:
|
||||
"""
|
||||
Blocking transcription call (runs in thread pool).
|
||||
"""
|
||||
# Convert to float32 if needed
|
||||
if audio.dtype != np.float32:
|
||||
audio = audio.astype(np.float32) / 32768.0
|
||||
|
||||
# Transcribe
|
||||
segments, info = self.model.transcribe(
|
||||
audio,
|
||||
language=self.language,
|
||||
beam_size=self.beam_size,
|
||||
initial_prompt=initial_prompt,
|
||||
vad_filter=False, # We handle VAD separately
|
||||
word_timestamps=False # Can enable for word-level timing
|
||||
)
|
||||
|
||||
# Collect all segments
|
||||
text_parts = []
|
||||
for segment in segments:
|
||||
text_parts.append(segment.text.strip())
|
||||
|
||||
full_text = " ".join(text_parts).strip()
|
||||
|
||||
logger.debug(f"Transcribed: '{full_text}' (language: {info.language}, "
|
||||
f"probability: {info.language_probability:.2f})")
|
||||
|
||||
return full_text
|
||||
|
||||
async def transcribe_streaming(
|
||||
self,
|
||||
audio_stream: Iterator[np.ndarray],
|
||||
sample_rate: int = 16000,
|
||||
chunk_duration_s: float = 2.0
|
||||
) -> Iterator[dict]:
|
||||
"""
|
||||
Transcribe audio stream with partial results.
|
||||
|
||||
Args:
|
||||
audio_stream: Iterator yielding audio chunks
|
||||
sample_rate: Audio sample rate
|
||||
chunk_duration_s: Duration of each chunk to transcribe
|
||||
|
||||
Yields:
|
||||
{"type": "partial", "text": "partial transcript"}
|
||||
{"type": "final", "text": "complete transcript"}
|
||||
"""
|
||||
accumulated_audio = []
|
||||
chunk_samples = int(chunk_duration_s * sample_rate)
|
||||
|
||||
async for audio_chunk in audio_stream:
|
||||
accumulated_audio.append(audio_chunk)
|
||||
|
||||
# Check if we have enough audio for transcription
|
||||
total_samples = sum(len(chunk) for chunk in accumulated_audio)
|
||||
|
||||
if total_samples >= chunk_samples:
|
||||
# Concatenate accumulated audio
|
||||
audio_data = np.concatenate(accumulated_audio)
|
||||
|
||||
# Transcribe current accumulated audio
|
||||
text = await self.transcribe_async(audio_data, sample_rate)
|
||||
|
||||
if text:
|
||||
yield {
|
||||
"type": "partial",
|
||||
"text": text,
|
||||
"duration": total_samples / sample_rate
|
||||
}
|
||||
|
||||
# Final transcription of remaining audio
|
||||
if accumulated_audio:
|
||||
audio_data = np.concatenate(accumulated_audio)
|
||||
text = await self.transcribe_async(audio_data, sample_rate)
|
||||
|
||||
if text:
|
||||
yield {
|
||||
"type": "final",
|
||||
"text": text,
|
||||
"duration": len(audio_data) / sample_rate
|
||||
}
|
||||
|
||||
def get_supported_languages(self) -> List[str]:
|
||||
"""Get list of supported language codes."""
|
||||
return [
|
||||
"en", "zh", "de", "es", "ru", "ko", "fr", "ja", "pt", "tr",
|
||||
"pl", "ca", "nl", "ar", "sv", "it", "id", "hi", "fi", "vi",
|
||||
"he", "uk", "el", "ms", "cs", "ro", "da", "hu", "ta", "no"
|
||||
]
|
||||
|
||||
def cleanup(self):
|
||||
"""Cleanup resources."""
|
||||
self.executor.shutdown(wait=True)
|
||||
logger.info("Whisper transcriber cleaned up")
|
||||
Reference in New Issue
Block a user