"""Test the offline ASR pipeline with onnx-asr."""
|
|
import argparse
import sys
import traceback
from pathlib import Path

import numpy as np
import soundfile as sf

from asr.asr_pipeline import ASRPipeline
|
|
|
|
|
|
def test_transcription(audio_file: str, use_vad: bool = False, quantization: str = None):
    """
    Run the ASR pipeline on one audio file and print a human-readable report.

    The file is read as float32 with ``soundfile``, reduced to a single
    channel if it has more than one, and passed to
    ``ASRPipeline.transcribe`` together with its native sample rate. The
    transcription and some basic sample statistics are printed to stdout.

    Args:
        audio_file: Path to audio file
        use_vad: Whether to use VAD (forwarded to ``ASRPipeline``); when the
            result is a list, segments are printed one per line
        quantization: Optional quantization (e.g., "int8"), forwarded to
            ``ASRPipeline``

    Returns:
        Whatever ``pipeline.transcribe`` returned, unchanged.
    """
    rule = "=" * 80
    divider = "-" * 80

    print(f"\n{rule}")
    print("Testing ASR Pipeline with onnx-asr")
    print(rule)
    print(f"Audio file: {audio_file}")
    print(f"Use VAD: {use_vad}")
    print(f"Quantization: {quantization}")
    print(f"{rule}\n")

    # Initialize pipeline
    print("Initializing ASR pipeline...")
    pipeline = ASRPipeline(
        model_name="nemo-parakeet-tdt-0.6b-v3",
        quantization=quantization,
        use_vad=use_vad,
    )
    print("Pipeline initialized successfully!\n")

    # Read audio file
    print(f"Reading audio file: {audio_file}")
    audio, sr = sf.read(audio_file, dtype="float32")
    print(f"Sample rate: {sr} Hz")
    print(f"Audio shape: {audio.shape}")
    print(f"Audio duration: {len(audio) / sr:.2f} seconds")

    # Ensure mono: only the first channel is kept, the others are discarded
    # (no averaging across channels).
    if audio.ndim > 1:
        print("Converting stereo to mono...")
        audio = audio[:, 0]

    # Verify sample rate. This is only a warning: the audio is still passed
    # to the pipeline with its native rate.
    if sr != 16000:
        print(f"WARNING: Sample rate is {sr} Hz, expected 16000 Hz")
        print("Consider resampling the audio file")

    print(f"\n{rule}")
    print("Transcribing...")
    print(f"{rule}\n")

    # Transcribe
    result = pipeline.transcribe(audio, sample_rate=sr)

    # Display results
    if use_vad and isinstance(result, list):
        print("TRANSCRIPTION (with VAD):")
        print(divider)
        for i, segment in enumerate(result, 1):
            print(f"Segment {i}: {segment}")
        print(divider)
    else:
        print("TRANSCRIPTION:")
        print(divider)
        print(result)
        print(divider)

    # Audio statistics
    print("\nAUDIO STATISTICS:")
    print(f" dtype: {audio.dtype}")
    if audio.size:
        print(f" min: {audio.min():.6f}")
        print(f" max: {audio.max():.6f}")
        print(f" mean: {audio.mean():.6f}")
        print(f" std: {audio.std():.6f}")
    else:
        # min()/max() raise ValueError on a zero-size array; report the
        # condition instead of crashing after a completed transcription.
        print(" (no samples)")

    print(f"\n{rule}")
    print("Test completed successfully!")
    print(f"{rule}\n")

    return result
|
|
|
|
|
|
def main():
    """
    Command-line entry point: parse arguments and run ``test_transcription``.

    Positional ``audio_file`` is required; ``--use-vad`` and
    ``--quantization`` (``int8`` or ``fp16``) are optional. Diagnostics are
    written to stderr.

    Raises:
        SystemExit: with status 1 when the path does not exist, is not a
            regular file, or the transcription raises; argparse itself
            exits on invalid arguments.
    """
    parser = argparse.ArgumentParser(description="Test offline ASR transcription")
    parser.add_argument("audio_file", help="Path to audio file (WAV format)")
    parser.add_argument("--use-vad", action="store_true", help="Enable VAD")
    parser.add_argument(
        "--quantization",
        default=None,
        choices=["int8", "fp16"],
        help="Model quantization",
    )

    args = parser.parse_args()

    # Validate the path up front so the user gets a short message instead of
    # a traceback from the audio reader.
    audio_path = Path(args.audio_file)
    if not audio_path.exists():
        print(f"ERROR: Audio file not found: {args.audio_file}", file=sys.stderr)
        sys.exit(1)
    if not audio_path.is_file():
        # exists() alone would let a directory through.
        print(f"ERROR: Not a regular file: {args.audio_file}", file=sys.stderr)
        sys.exit(1)

    try:
        test_transcription(args.audio_file, args.use_vad, args.quantization)
    except Exception as e:
        # Top-level boundary: report the failure with its traceback and turn
        # it into a non-zero exit status.
        print(f"\nERROR: {e}", file=sys.stderr)
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()
|