""" |
|
Example script demonstrating how to use the Cahya Whisper Medium ONNX model |
|
for Indonesian speech recognition. |
|
|
|
This script shows how to: |
|
1. Load the quantized ONNX model (encoder + decoder) |
|
2. Process audio files for inference |
|
3. Generate transcriptions |
|
|
|
Requirements: |
|
- onnxruntime |
|
- transformers |
|
- librosa |
|
- numpy |
|
""" |

import os
import json
import numpy as np
import librosa
import onnxruntime as ort
from transformers import WhisperProcessor
from pathlib import Path
import argparse
import time


|
class CahyaWhisperONNX:
    """ONNX inference wrapper for Cahya Whisper Medium Indonesian model"""

    def __init__(self, model_dir="./"):
        """
        Initialize the ONNX Whisper model

        Args:
            model_dir (str): Directory containing the ONNX model files
        """
        self.model_dir = Path(model_dir)
        self.encoder_path = self.model_dir / "encoder_model_quantized.onnx"
        self.decoder_path = self.model_dir / "decoder_model_quantized.onnx"
        self.config_path = self.model_dir / "config.json"

        if not self.encoder_path.exists():
            raise FileNotFoundError(f"Encoder model not found: {self.encoder_path}")
        if not self.decoder_path.exists():
            raise FileNotFoundError(f"Decoder model not found: {self.decoder_path}")
        if not self.config_path.exists():
            raise FileNotFoundError(f"Config file not found: {self.config_path}")

        print("Loading ONNX models...")

        session_options = ort.SessionOptions()
        session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

        providers = ['CPUExecutionProvider']
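
        # Only the CPU execution provider is requested. With a GPU-enabled
        # onnxruntime build, "CUDAExecutionProvider" could be added in front of
        # it; that is an optional tweak, not something this script requires.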

        try:
            self.encoder_session = ort.InferenceSession(
                str(self.encoder_path),
                sess_options=session_options,
                providers=providers
            )
            print("✓ Encoder model loaded successfully")
        except Exception as e:
            print(f"✗ Failed to load encoder: {e}")
            raise

        try:
            self.decoder_session = ort.InferenceSession(
                str(self.decoder_path),
                sess_options=session_options,
                providers=providers
            )
            print("✓ Decoder model loaded successfully")
        except Exception as e:
            print(f"✗ Failed to load decoder: {e}")
            raise

        print("Loading processor...")
        self.processor = WhisperProcessor.from_pretrained("openai/whisper-medium")

        with open(self.config_path, 'r') as f:
            self.config = json.load(f)

        print("Model loaded successfully!")
        print(f"Model type: {self.config.get('model_type', 'whisper')}")
        print(f"Vocab size: {self.config.get('vocab_size', 'unknown')}")

    def preprocess_audio(self, audio_path, max_duration=30.0):
        """
        Preprocess audio file for inference

        Args:
            audio_path (str): Path to audio file
            max_duration (float): Maximum audio duration in seconds

        Returns:
            np.ndarray: Preprocessed 16 kHz audio waveform
        """
        audio, sr = librosa.load(audio_path, sr=16000)

        max_samples = int(max_duration * 16000)
        if len(audio) > max_samples:
            audio = audio[:max_samples]
            print(f"Audio trimmed to {max_duration} seconds")

        print(f"Audio duration: {len(audio) / 16000:.2f} seconds")
        return audio

    def transcribe(self, audio_input, max_new_tokens=128):
        """
        Transcribe audio to text

        Args:
            audio_input: Audio array or path to audio file
            max_new_tokens (int): Maximum number of tokens to generate

        Returns:
            str: Transcribed text
        """
        if isinstance(audio_input, str):
            audio_array = self.preprocess_audio(audio_input)
        else:
            audio_array = audio_input

        input_features = self.processor(
            audio_array,
            sampling_rate=16000,
            return_tensors="np"
        ).input_features
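
        # The processor pads/truncates the waveform to a 30-second window and
        # returns a log-mel spectrogram; for whisper-medium this typically has
        # shape (1, 80, 3000).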

        print(f"Input features shape: {input_features.shape}")

        print("Running encoder...")
        start_time = time.time()
        encoder_outputs = self.encoder_session.run(
            None,
            {"input_features": input_features}
        )[0]
        encoder_time = time.time() - start_time
        print(f"Encoder inference time: {encoder_time:.3f}s")
        print(f"Encoder output shape: {encoder_outputs.shape}")

        decoder_input_ids = np.array([[self.config["decoder_start_token_id"]]], dtype=np.int64)
        generated_tokens = [self.config["decoder_start_token_id"]]
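        # Only the decoder start token is forced here; the language and task
        # tokens (e.g. the Indonesian language token) are left for the model to
        # predict during the first decoding steps.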

        print("Running decoder...")
        decoder_start_time = time.time()
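
        # Greedy autoregressive decoding: this decoder export has no past-key-value
        # cache, so the full decoder is re-run on the growing token sequence at every
        # step and only the logits for the last position are used.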
        for step in range(max_new_tokens):
            decoder_outputs = self.decoder_session.run(
                None,
                {
                    "input_ids": decoder_input_ids,
                    "encoder_hidden_states": encoder_outputs
                }
            )[0]

            next_token_logits = decoder_outputs[0, -1, :]
            next_token = np.argmax(next_token_logits)

            if next_token == self.config["eos_token_id"]:
                break

            generated_tokens.append(int(next_token))
            decoder_input_ids = np.array([generated_tokens], dtype=np.int64)

        decoder_time = time.time() - decoder_start_time
        print(f"Decoder inference time: {decoder_time:.3f}s")
        print(f"Generated {len(generated_tokens)} tokens")

        transcription = self.processor.batch_decode(
            [generated_tokens],
            skip_special_tokens=True
        )[0]

        total_time = encoder_time + decoder_time
        print(f"Total inference time: {total_time:.3f}s")

        return transcription.strip()

    def get_model_info(self):
        """Return basic model information and ONNX file sizes in MB"""
        info = {
            "model_type": self.config.get("model_type", "whisper"),
            "vocab_size": self.config.get("vocab_size"),
            "encoder_layers": self.config.get("encoder_layers"),
            "decoder_layers": self.config.get("decoder_layers"),
            "d_model": self.config.get("d_model"),
            "encoder_file_size": self.encoder_path.stat().st_size / (1024**2),
            "decoder_file_size": self.decoder_path.stat().st_size / (1024**2),
        }
        return info
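

# Sketch of using the wrapper directly from Python rather than via the CLI in
# main() below (the audio file name is illustrative):
#
#   model = CahyaWhisperONNX("./")
#   text = model.transcribe("sample.wav")
#   print(text)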
def main():
    """Example usage"""
    parser = argparse.ArgumentParser(description="Cahya Whisper ONNX Example")
    parser.add_argument("--audio", type=str, required=True, help="Path to audio file")
    parser.add_argument("--model-dir", type=str, default="./", help="Model directory")
    parser.add_argument("--max-tokens", type=int, default=128, help="Max tokens to generate")

    args = parser.parse_args()

    if not os.path.exists(args.audio):
        print(f"Error: Audio file not found: {args.audio}")
        return

    print("=" * 50)
    print("Cahya Whisper Medium ONNX Example")
    print("=" * 50)

    try:
        model = CahyaWhisperONNX(args.model_dir)

        print("\nModel Information:")
        info = model.get_model_info()
        for key, value in info.items():
            if key.endswith('_size'):
                print(f" {key}: {value:.1f} MB")
            else:
                print(f" {key}: {value}")

        print(f"\nTranscribing: {args.audio}")
        print("-" * 50)

        transcription = model.transcribe(args.audio, max_new_tokens=args.max_tokens)

        print("\nTranscription:")
        print(f"'{transcription}'")
        print("-" * 50)
        print("Done!")

    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()