"""
|
|
Silero VAD Processor
|
|
|
|
Lightweight CPU-based Voice Activity Detection for real-time speech detection.
|
|
Runs continuously on audio chunks to determine when users are speaking.
|
|
"""
|
|
|
|
import torch
|
|
import numpy as np
|
|
from typing import Tuple, Optional
|
|
import logging
|
|
|
|
logger = logging.getLogger('vad')
|
|
|
|
|
|

class VADProcessor:
    """
    Voice Activity Detection using the Silero VAD model.

    Processes audio chunks and returns a speech probability.
    Conservative settings to avoid cutting off speech.
    """

    def __init__(
        self,
        sample_rate: int = 16000,
        threshold: float = 0.5,
        min_speech_duration_ms: int = 250,
        min_silence_duration_ms: int = 500,
        speech_pad_ms: int = 30
    ):
        """
        Initialize the VAD processor.

        Args:
            sample_rate: Audio sample rate (must be 8000 or 16000)
            threshold: Speech probability threshold (0.0-1.0)
            min_speech_duration_ms: Minimum speech duration to trigger (conservative)
            min_silence_duration_ms: Minimum silence to end speech (conservative)
            speech_pad_ms: Padding around speech segments
        """
        self.sample_rate = sample_rate
        self.threshold = threshold
        self.min_speech_duration_ms = min_speech_duration_ms
        self.min_silence_duration_ms = min_silence_duration_ms
        self.speech_pad_ms = speech_pad_ms

        # Load Silero VAD model (CPU only)
        logger.info("Loading Silero VAD model (CPU)...")
        self.model, utils = torch.hub.load(
            repo_or_dir='snakers4/silero-vad',
            model='silero_vad',
            force_reload=False,
            onnx=False  # Use the PyTorch model rather than the ONNX export
        )

        # Unpack the helper functions bundled with the model
        (self.get_speech_timestamps,
         self.save_audio,
         self.read_audio,
         self.VADIterator,
         self.collect_chunks) = utils

        # State tracking for speech start/end detection
        self.speaking = False
        self.speech_start_time = None
        self.silence_start_time = None
        self.audio_buffer = []  # currently unused; reserved for segment capture

        # Chunk buffer for VAD (recent Silero releases expect fixed-size
        # windows: 512 samples at 16 kHz)
        self.vad_buffer = []
        self.min_vad_samples = 512  # VAD window size in samples

        logger.info(f"VAD initialized: threshold={threshold}, "
                    f"min_speech={min_speech_duration_ms}ms, "
                    f"min_silence={min_silence_duration_ms}ms")

    def process_chunk(self, audio_chunk: np.ndarray) -> Tuple[float, bool]:
        """
        Process a single audio chunk and return the speech probability.
        Buffers small chunks until the VAD's window size is reached.

        Args:
            audio_chunk: Audio data as a numpy array (int16 or float32)

        Returns:
            (speech_probability, is_speaking): probability for the latest
            window and whether it exceeds the threshold
        """
        # Convert int16 PCM to float32 in [-1.0, 1.0)
        if audio_chunk.dtype == np.int16:
            audio_chunk = audio_chunk.astype(np.float32) / 32768.0

        # Accumulate chunks until a full VAD window is available
        self.vad_buffer.append(audio_chunk)
        total_samples = sum(len(chunk) for chunk in self.vad_buffer)
        if total_samples < self.min_vad_samples:
            # Not enough samples yet; report no speech for this call
            return 0.0, False

        # Concatenate the buffered chunks. Recent Silero releases expect a
        # fixed-size window (512 samples at 16 kHz), so feed exactly one
        # window and carry any overflow into the next call.
        audio_full = np.concatenate(self.vad_buffer)
        window = audio_full[:self.min_vad_samples]
        remainder = audio_full[self.min_vad_samples:]
        self.vad_buffer = [remainder] if remainder.size else []

        audio_tensor = torch.from_numpy(window)
        with torch.no_grad():
            speech_prob = self.model(audio_tensor, self.sample_rate).item()

        # Speaking if the probability clears the configured threshold
        is_speaking = speech_prob > self.threshold
        return speech_prob, is_speaking

    def detect_speech_segment(
        self,
        audio_chunk: np.ndarray,
        timestamp_ms: float
    ) -> Optional[dict]:
        """
        Process a chunk and detect speech start/end events.

        Args:
            audio_chunk: Audio data
            timestamp_ms: Current timestamp in milliseconds

        Returns:
            Event dict or None:
            - {"event": "speech_start", "timestamp": float, "probability": float}
            - {"event": "speech_end", "timestamp": float, "probability": float}
            - {"event": "speaking", "timestamp": float, "probability": float}  # ongoing speech
        """
        speech_prob, is_speaking = self.process_chunk(audio_chunk)

        # Candidate speech start: only fire the event once speech has
        # persisted for min_speech_duration_ms
        if is_speaking and not self.speaking:
            if self.speech_start_time is None:
                self.speech_start_time = timestamp_ms

            speech_duration = timestamp_ms - self.speech_start_time
            if speech_duration >= self.min_speech_duration_ms:
                self.speaking = True
                self.silence_start_time = None
                logger.debug(f"Speech started at {timestamp_ms}ms, prob={speech_prob:.3f}")
                return {
                    "event": "speech_start",
                    "timestamp": timestamp_ms,
                    "probability": speech_prob
                }

        # Speech ongoing
        elif is_speaking and self.speaking:
            self.silence_start_time = None  # Reset the silence timer
            return {
                "event": "speaking",
                "probability": speech_prob,
                "timestamp": timestamp_ms
            }

        # Silence during speech: only end the segment once silence has
        # persisted for min_silence_duration_ms
        elif not is_speaking and self.speaking:
            if self.silence_start_time is None:
                self.silence_start_time = timestamp_ms

            silence_duration = timestamp_ms - self.silence_start_time
            if silence_duration >= self.min_silence_duration_ms:
                self.speaking = False
                self.speech_start_time = None
                logger.debug(f"Speech ended at {timestamp_ms}ms, prob={speech_prob:.3f}")
                return {
                    "event": "speech_end",
                    "timestamp": timestamp_ms,
                    "probability": speech_prob
                }

        # No speech, or a candidate start that did not persist
        else:
            if not is_speaking:
                self.speech_start_time = None

        return None

    def reset(self):
        """Reset VAD state."""
        self.speaking = False
        self.speech_start_time = None
        self.silence_start_time = None
        self.audio_buffer.clear()
        self.vad_buffer.clear()  # also drop any partially filled VAD window
        logger.debug("VAD state reset")

    def get_state(self) -> dict:
        """Get the current VAD state."""
        return {
            "speaking": self.speaking,
            "speech_start_time": self.speech_start_time,
            "silence_start_time": self.silence_start_time
        }
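
# --- Usage sketch -----------------------------------------------------------
# A minimal example of driving VADProcessor: stream fixed-size chunks through
# detect_speech_segment and print any events. The silent placeholder chunks,
# chunk size, and timestamps are assumptions for this sketch; a real caller
# would feed 16 kHz mono audio from a microphone or network stream.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    vad = VADProcessor(sample_rate=16000, threshold=0.5)

    chunk_size = 512                      # one VAD window per chunk
    chunk_ms = chunk_size / 16000 * 1000  # 32 ms of audio per chunk

    # Placeholder input: silence produces no events; substitute real speech
    # audio (int16 or float32) to see speech_start/speaking/speech_end.
    for i in range(50):
        chunk = np.zeros(chunk_size, dtype=np.int16)
        event = vad.detect_speech_segment(chunk, timestamp_ms=i * chunk_ms)
        if event:
            print(event)

    print("final state:", vad.get_state())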