Implemented an experimental but production-ready voice chat flow; relegated the old flow to a voice debug mode. Added a new Web UI panel for Voice Chat.

2026-01-20 23:06:17 +02:00
parent 362108f4b0
commit 2934efba22
31 changed files with 5408 additions and 357 deletions


@@ -49,6 +49,15 @@ class ParakeetTranscriber:
logger.info(f"Loading Parakeet model: {model_name} on {device}...")
# Set PyTorch memory allocator settings for better memory management
if device == "cuda":
# Enable expandable segments to reduce fragmentation
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
# Clear cache before loading model
torch.cuda.empty_cache()
# Load model via NeMo from HuggingFace
self.model = EncDecRNNTBPEModel.from_pretrained(
model_name=model_name,
@@ -58,6 +67,11 @@ class ParakeetTranscriber:
self.model.eval()
if device == "cuda":
self.model = self.model.cuda()
# Enable memory efficient attention if available
try:
self.model.encoder.use_memory_efficient_attention = True
except:
pass
# Thread pool for blocking transcription calls
self.executor = ThreadPoolExecutor(max_workers=2)
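For context, a minimal sketch of the load-time setup this hunk performs, assuming nothing has touched the GPU before the transcriber is constructed; the bare except from the diff is narrowed to AttributeError here, and the NeMo import path is the standard one.

import os

# The allocator config must be set before the first CUDA allocation in the
# process, otherwise expandable segments have no effect (assumption: the
# transcriber is constructed before any other GPU work happens).
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch
from nemo.collections.asr.models import EncDecRNNTBPEModel

def load_parakeet(model_name: str, device: str = "cuda") -> EncDecRNNTBPEModel:
    if device == "cuda":
        torch.cuda.empty_cache()  # start from an empty allocator pool
    model = EncDecRNNTBPEModel.from_pretrained(model_name=model_name)
    model.eval()
    if device == "cuda":
        model = model.cuda()
        try:
            # Not every NeMo encoder exposes this attribute, hence the guard.
            model.encoder.use_memory_efficient_attention = True
        except AttributeError:
            pass
    return model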
@@ -119,7 +133,7 @@ class ParakeetTranscriber:
# Transcribe using NeMo model
with torch.no_grad():
# Convert to tensor
# Convert to tensor and keep on GPU to avoid CPU/GPU bouncing
audio_signal = torch.from_numpy(audio).unsqueeze(0)
audio_signal_len = torch.tensor([len(audio)])
@@ -127,12 +141,14 @@ class ParakeetTranscriber:
audio_signal = audio_signal.cuda()
audio_signal_len = audio_signal_len.cuda()
# Get transcription with timestamps
# NeMo returns list of Hypothesis objects when timestamps=True
# Get transcription
# NeMo returns list of Hypothesis objects
# Note: timestamps=True causes significant VRAM usage (~1-2GB extra)
# Only enable for final transcriptions, not streaming partials
transcriptions = self.model.transcribe(
audio=[audio_signal.squeeze(0).cpu().numpy()],
audio=[audio], # Pass NumPy array directly (NeMo handles it efficiently)
batch_size=1,
timestamps=True # Enable timestamps to get word-level data
timestamps=return_timestamps # Only use timestamps when explicitly requested
)
# Extract text from Hypothesis object
@@ -144,9 +160,9 @@ class ParakeetTranscriber:
# Hypothesis object has .text attribute
text = hypothesis.text.strip() if hasattr(hypothesis, 'text') else str(hypothesis).strip()
# Extract word-level timestamps if available
# Extract word-level timestamps if available and requested
words = []
if hasattr(hypothesis, 'timestamp') and hypothesis.timestamp:
if return_timestamps and hasattr(hypothesis, 'timestamp') and hypothesis.timestamp:
# timestamp is a dict with 'word' key containing list of word timestamps
word_timestamps = hypothesis.timestamp.get('word', [])
for word_info in word_timestamps:
@@ -165,6 +181,10 @@ class ParakeetTranscriber:
}
else:
return text
# Note: We do NOT call torch.cuda.empty_cache() here
# That breaks PyTorch's memory allocator and causes fragmentation
# Let PyTorch manage its own memory pool
async def transcribe_streaming(
self,

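A hypothetical caller-side sketch (the helper name and the transcriber parameter are assumptions; the transcribe_async signature and the str-vs-dict return shapes follow the diff): partials skip timestamps to save VRAM, finals request them.

async def transcribe_chunk(transcriber, audio, *, final: bool) -> dict:
    # Timestamps only for final transcriptions; partials return plain text.
    result = await transcriber.transcribe_async(
        audio,
        sample_rate=16000,
        return_timestamps=final,
    )
    if isinstance(result, str):
        # return_timestamps=False -> plain string, no word-level data
        return {"text": result, "words": []}
    # return_timestamps=True -> dict with "text" and word-level "words"
    return {"text": result.get("text", ""), "words": result.get("words", [])}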

@@ -22,6 +22,7 @@ silero-vad==5.1.2
huggingface-hub>=0.30.0,<1.0
nemo_toolkit[asr]==2.4.0
omegaconf==2.3.0
cuda-python>=12.3 # Enable CUDA graphs for faster decoding
# Utilities
python-multipart==0.0.20


@@ -51,6 +51,9 @@ class UserSTTSession:
self.timestamp_ms = 0.0
self.transcript_buffer = []
self.last_transcript = ""
self.last_partial_duration = 0.0 # Track when we last sent a partial
self.last_speech_timestamp = 0.0 # Track last time we detected speech
self.speech_timeout_ms = 3000 # Force finalization after 3s of no new speech
logger.info(f"Created STT session for user {user_id}")
@@ -75,6 +78,8 @@ class UserSTTSession:
event_type = vad_event["event"]
probability = vad_event["probability"]
logger.debug(f"VAD event for user {self.user_id}: {event_type} (prob={probability:.3f})")
# Send VAD event to client
await self.websocket.send_json({
"type": "vad",
@@ -88,63 +93,91 @@ class UserSTTSession:
if event_type == "speech_start":
self.is_speaking = True
self.audio_buffer = [audio_np]
logger.debug(f"User {self.user_id} started speaking")
self.last_partial_duration = 0.0
self.last_speech_timestamp = self.timestamp_ms
logger.info(f"[STT] User {self.user_id} SPEECH START")
elif event_type == "speaking":
if self.is_speaking:
self.audio_buffer.append(audio_np)
self.last_speech_timestamp = self.timestamp_ms # Update speech timestamp
# Transcribe partial every ~2 seconds for streaming
# Transcribe partial every ~1 second for streaming (reduced from 2s)
total_samples = sum(len(chunk) for chunk in self.audio_buffer)
duration_s = total_samples / 16000
if duration_s >= 2.0:
# More frequent partials for better responsiveness
if duration_s >= 1.0:
logger.debug(f"Triggering partial transcription at {duration_s:.1f}s")
await self._transcribe_partial()
# Keep buffer for final transcription, but mark progress
self.last_partial_duration = duration_s
elif event_type == "speech_end":
self.is_speaking = False
logger.info(f"[STT] User {self.user_id} SPEECH END (VAD detected) - transcribing final")
# Transcribe final
await self._transcribe_final()
# Clear buffer
self.audio_buffer = []
self.last_partial_duration = 0.0
logger.debug(f"User {self.user_id} stopped speaking")
else:
# Still accumulate audio if speaking
# No VAD event - still accumulate audio if speaking
if self.is_speaking:
self.audio_buffer.append(audio_np)
# Check for timeout
time_since_speech = self.timestamp_ms - self.last_speech_timestamp
if time_since_speech >= self.speech_timeout_ms:
# Timeout - user probably stopped but VAD didn't detect it
logger.warning(f"[STT] User {self.user_id} SPEECH TIMEOUT after {time_since_speech:.0f}ms - forcing finalization")
self.is_speaking = False
# Force final transcription
await self._transcribe_final()
# Clear buffer
self.audio_buffer = []
self.last_partial_duration = 0.0
async def _transcribe_partial(self):
"""Transcribe accumulated audio and send partial result with word tokens."""
"""Transcribe accumulated audio and send partial result (no timestamps to save VRAM)."""
if not self.audio_buffer:
return
# Concatenate audio
audio_full = np.concatenate(self.audio_buffer)
# Transcribe asynchronously with word-level timestamps
# Transcribe asynchronously WITHOUT timestamps for partials (saves 1-2GB VRAM)
try:
result = await parakeet_transcriber.transcribe_async(
audio_full,
sample_rate=16000,
return_timestamps=True
return_timestamps=False # Disable timestamps for partials to reduce VRAM usage
)
if result and result.get("text") and result["text"] != self.last_transcript:
self.last_transcript = result["text"]
# Result is just a string when timestamps=False
text = result if isinstance(result, str) else result.get("text", "")
if text and text != self.last_transcript:
self.last_transcript = text
# Send partial transcript with word tokens for LLM pre-computation
# Send partial transcript without word tokens (saves memory)
await self.websocket.send_json({
"type": "partial",
"text": result["text"],
"words": result.get("words", []), # Word-level tokens
"text": text,
"words": [], # No word tokens for partials
"user_id": self.user_id,
"timestamp": self.timestamp_ms
})
logger.info(f"Partial [{self.user_id}]: {result['text']}")
logger.info(f"Partial [{self.user_id}]: {text}")
except Exception as e:
logger.error(f"Partial transcription failed: {e}", exc_info=True)
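On the receiving end, a hypothetical client sketch (the endpoint URL is an assumption; the payload fields mirror the messages the session sends above) that separates "vad" events from "partial" transcripts:

import asyncio
import json

import websockets  # third-party client: pip install websockets

async def listen(url: str = "ws://localhost:8000/ws/stt") -> None:  # URL is hypothetical
    async with websockets.connect(url) as ws:
        async for raw in ws:
            msg = json.loads(raw)
            if msg.get("type") == "vad":
                # VAD state change forwarded by the server
                print("VAD event:", msg)
            elif msg.get("type") == "partial":
                # Streaming transcript; "words" is now empty for partials
                print(f"[{msg.get('user_id')}] {msg.get('text')}")

if __name__ == "__main__":
    asyncio.run(listen())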
@@ -220,8 +253,8 @@ async def startup_event():
vad_processor = VADProcessor(
sample_rate=16000,
threshold=0.5,
min_speech_duration_ms=250, # Conservative
min_silence_duration_ms=500 # Conservative
min_speech_duration_ms=250, # Conservative - wait 250ms before starting
min_silence_duration_ms=300 # Reduced from 500ms - detect silence faster
)
logger.info("✓ VAD ready")
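A rough latency budget for the tuning above, assuming (not measured) a few hundred milliseconds of GPU transcription time for the final pass:

min_silence_ms = 300        # reduced from 500 ms: speech_end fires ~200 ms sooner
speech_timeout_ms = 3000    # fallback finalization when VAD misses speech_end
assumed_transcribe_ms = 400  # assumption: final-pass transcription time on GPU

typical_finalize_ms = min_silence_ms + assumed_transcribe_ms        # ~700 ms
worst_case_finalize_ms = speech_timeout_ms + assumed_transcribe_ms  # ~3400 ms
print(typical_finalize_ms, worst_case_finalize_ms)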