#!/usr/bin/env python3
"""
Example script demonstrating how to use the Cahya Whisper Medium ONNX model
for Indonesian speech recognition.
This script shows how to:
1. Load the quantized ONNX model (encoder + decoder)
2. Process audio files for inference
3. Generate transcriptions
Requirements:
- onnxruntime
- transformers
- librosa
- numpy
"""
import os
import json
import numpy as np
import librosa
import onnxruntime as ort
from transformers import WhisperProcessor
from pathlib import Path
import argparse
import time
class CahyaWhisperONNX:
"""ONNX inference wrapper for Cahya Whisper Medium Indonesian model"""
def __init__(self, model_dir="./"):
"""
Initialize the ONNX Whisper model
Args:
model_dir (str): Directory containing the ONNX model files
"""
self.model_dir = Path(model_dir)
self.encoder_path = self.model_dir / "encoder_model_quantized.onnx"
self.decoder_path = self.model_dir / "decoder_model_quantized.onnx"
self.config_path = self.model_dir / "config.json"
# Validate model files exist
if not self.encoder_path.exists():
raise FileNotFoundError(f"Encoder model not found: {self.encoder_path}")
if not self.decoder_path.exists():
raise FileNotFoundError(f"Decoder model not found: {self.decoder_path}")
if not self.config_path.exists():
raise FileNotFoundError(f"Config file not found: {self.config_path}")
# Load ONNX models with quantization support
print("Loading ONNX models...")
# Configure session options for quantized models
session_options = ort.SessionOptions()
session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        # Run on the CPU execution provider; quantized INT8 kernels are
        # broadly supported there.
providers = ['CPUExecutionProvider']
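        # If onnxruntime-gpu is installed, 'CUDAExecutionProvider' could be
        # prepended to this list; ort.get_available_providers() reports what
        # the current build actually supports.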
try:
self.encoder_session = ort.InferenceSession(
str(self.encoder_path),
sess_options=session_options,
providers=providers
)
print("✓ Encoder model loaded successfully")
except Exception as e:
print(f"✗ Failed to load encoder: {e}")
raise
try:
self.decoder_session = ort.InferenceSession(
str(self.decoder_path),
sess_options=session_options,
providers=providers
)
print("✓ Decoder model loaded successfully")
except Exception as e:
print(f"✗ Failed to load decoder: {e}")
raise
# Load processor for tokenization (using base Whisper processor)
print("Loading processor...")
self.processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
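        # Assumption: the fine-tune keeps the base whisper-medium tokenizer.
        # If the model directory ships its own processor/tokenizer files,
        # WhisperProcessor.from_pretrained(self.model_dir) may be preferable.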
# Load model config
with open(self.config_path, 'r') as f:
self.config = json.load(f)
print("Model loaded successfully!")
print(f"Model type: {self.config.get('model_type', 'whisper')}")
print(f"Vocab size: {self.config.get('vocab_size', 'unknown')}")
def preprocess_audio(self, audio_path, max_duration=30.0):
"""
Preprocess audio file for inference
Args:
audio_path (str): Path to audio file
max_duration (float): Maximum audio duration in seconds
Returns:
            np.ndarray: Mono 16 kHz audio waveform (feature extraction
                happens later, in transcribe())
"""
# Load audio
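        # librosa.load resamples to 16 kHz and downmixes to mono by default,
        # matching Whisper's expected input rate.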
audio, sr = librosa.load(audio_path, sr=16000)
# Trim to max duration
max_samples = int(max_duration * 16000)
if len(audio) > max_samples:
audio = audio[:max_samples]
print(f"Audio trimmed to {max_duration} seconds")
print(f"Audio duration: {len(audio) / 16000:.2f} seconds")
return audio
def transcribe(self, audio_input, max_new_tokens=128):
"""
Transcribe audio to text
Args:
audio_input: Audio array or path to audio file
max_new_tokens (int): Maximum number of tokens to generate
Returns:
str: Transcribed text
"""
# Handle both file path and audio array inputs
        if isinstance(audio_input, (str, Path)):
audio_array = self.preprocess_audio(audio_input)
else:
audio_array = audio_input
# Prepare input features
input_features = self.processor(
audio_array,
sampling_rate=16000,
return_tensors="np"
).input_features
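        # The processor pads/truncates to a 30 s window and returns log-mel
        # features; expected shape for whisper-medium is (1, 80, 3000).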
print(f"Input features shape: {input_features.shape}")
# Encoder forward pass
print("Running encoder...")
start_time = time.time()
encoder_outputs = self.encoder_session.run(
None,
{"input_features": input_features}
)[0]
encoder_time = time.time() - start_time
print(f"Encoder inference time: {encoder_time:.3f}s")
print(f"Encoder output shape: {encoder_outputs.shape}")
# Initialize decoder with start token
decoder_input_ids = np.array([[self.config["decoder_start_token_id"]]], dtype=np.int64)
generated_tokens = [self.config["decoder_start_token_id"]]
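        # NOTE: this demo seeds the decoder with decoder_start_token_id only.
        # Whisper generation normally also forces language/task/no-timestamp
        # tokens. A sketch using the processor API (assumes the tokenizer
        # matches the fine-tuned checkpoint):
        #
        #   forced = self.processor.get_decoder_prompt_ids(
        #       language="indonesian", task="transcribe"
        #   )  # -> [(position, token_id), ...]
        #   generated_tokens += [tok for _, tok in forced]
        #   decoder_input_ids = np.array([generated_tokens], dtype=np.int64)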
print("Running decoder...")
decoder_start_time = time.time()
# Simple greedy decoding (for demonstration)
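        # The full token sequence is re-fed each step because this decoder
        # export exposes no past-key-value (KV cache) inputs, so per-step cost
        # grows with sequence length. A decoder exported with past key values
        # (e.g. optimum's decoder_with_past variant) would avoid the rework.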
for step in range(max_new_tokens):
# Decoder forward pass
decoder_outputs = self.decoder_session.run(
None,
{
"input_ids": decoder_input_ids,
"encoder_hidden_states": encoder_outputs
}
)[0]
# Get next token (greedy selection)
next_token_logits = decoder_outputs[0, -1, :] # Last token logits
next_token = np.argmax(next_token_logits)
# Check for end token
if next_token == self.config["eos_token_id"]:
break
generated_tokens.append(int(next_token))
# Update input for next iteration
decoder_input_ids = np.array([generated_tokens], dtype=np.int64)
decoder_time = time.time() - decoder_start_time
print(f"Decoder inference time: {decoder_time:.3f}s")
print(f"Generated {len(generated_tokens)} tokens")
# Decode tokens to text
transcription = self.processor.batch_decode(
[generated_tokens],
skip_special_tokens=True
)[0]
total_time = encoder_time + decoder_time
print(f"Total inference time: {total_time:.3f}s")
return transcription.strip()
def get_model_info(self):
"""Get model information"""
info = {
"model_type": self.config.get("model_type", "whisper"),
"vocab_size": self.config.get("vocab_size"),
"encoder_layers": self.config.get("encoder_layers"),
"decoder_layers": self.config.get("decoder_layers"),
"d_model": self.config.get("d_model"),
"encoder_file_size": self.encoder_path.stat().st_size / (1024**2), # MB
"decoder_file_size": self.decoder_path.stat().st_size / (1024**2), # MB
}
return info
def main():
"""Example usage"""
parser = argparse.ArgumentParser(description="Cahya Whisper ONNX Example")
parser.add_argument("--audio", type=str, required=True, help="Path to audio file")
parser.add_argument("--model-dir", type=str, default="./", help="Model directory")
parser.add_argument("--max-tokens", type=int, default=128, help="Max tokens to generate")
args = parser.parse_args()
# Check if audio file exists
if not os.path.exists(args.audio):
print(f"Error: Audio file not found: {args.audio}")
return
print("="*50)
print("Cahya Whisper Medium ONNX Example")
print("="*50)
try:
# Initialize model
model = CahyaWhisperONNX(args.model_dir)
# Show model info
print("\nModel Information:")
info = model.get_model_info()
for key, value in info.items():
            if key.endswith('_file_size'):
print(f" {key}: {value:.1f} MB")
else:
print(f" {key}: {value}")
print(f"\nTranscribing: {args.audio}")
print("-" * 50)
# Transcribe
transcription = model.transcribe(args.audio, max_new_tokens=args.max_tokens)
print(f"\nTranscription:")
print(f"'{transcription}'")
print("-" * 50)
print("Done!")
except Exception as e:
print(f"Error: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()