#!/usr/bin/env python3
"""
Example script demonstrating how to use the Cahya Whisper Medium ONNX model
for Indonesian speech recognition.

This script shows how to:
1. Load the quantized ONNX model (encoder + decoder)
2. Process audio files for inference
3. Generate transcriptions

Requirements:
- onnxruntime
- transformers
- librosa
- numpy
"""

import argparse
import json
import os
import time
from pathlib import Path

import librosa
import numpy as np
import onnxruntime as ort
from transformers import WhisperProcessor


class CahyaWhisperONNX:
    """ONNX inference wrapper for the Cahya Whisper Medium Indonesian model."""

    def __init__(self, model_dir="./"):
        """
        Initialize the ONNX Whisper model.

        Args:
            model_dir (str): Directory containing the ONNX model files
        """
        self.model_dir = Path(model_dir)
        self.encoder_path = self.model_dir / "encoder_model_quantized.onnx"
        self.decoder_path = self.model_dir / "decoder_model_quantized.onnx"
        self.config_path = self.model_dir / "config.json"

        # Validate that all model files exist before loading anything
        if not self.encoder_path.exists():
            raise FileNotFoundError(f"Encoder model not found: {self.encoder_path}")
        if not self.decoder_path.exists():
            raise FileNotFoundError(f"Decoder model not found: {self.decoder_path}")
        if not self.config_path.exists():
            raise FileNotFoundError(f"Config file not found: {self.config_path}")

        print("Loading ONNX models...")

        # Configure session options for quantized models
        session_options = ort.SessionOptions()
        session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

        # CPUExecutionProvider is the safe default for quantized models;
        # add other execution providers here if your build supports them
        providers = ["CPUExecutionProvider"]

        try:
            self.encoder_session = ort.InferenceSession(
                str(self.encoder_path),
                sess_options=session_options,
                providers=providers,
            )
            print("✓ Encoder model loaded successfully")
        except Exception as e:
            print(f"✗ Failed to load encoder: {e}")
            raise

        try:
            self.decoder_session = ort.InferenceSession(
                str(self.decoder_path),
                sess_options=session_options,
                providers=providers,
            )
            print("✓ Decoder model loaded successfully")
        except Exception as e:
            print(f"✗ Failed to load decoder: {e}")
            raise

        # Load processor for feature extraction and tokenization
        # (using the base Whisper processor)
        print("Loading processor...")
        self.processor = WhisperProcessor.from_pretrained("openai/whisper-medium")

        # Load model config
        with open(self.config_path, "r") as f:
            self.config = json.load(f)

        print("Model loaded successfully!")
        print(f"Model type: {self.config.get('model_type', 'whisper')}")
        print(f"Vocab size: {self.config.get('vocab_size', 'unknown')}")

    def preprocess_audio(self, audio_path, max_duration=30.0):
        """
        Load and prepare an audio file for inference.

        Args:
            audio_path (str): Path to audio file
            max_duration (float): Maximum audio duration in seconds

        Returns:
            np.ndarray: Mono 16 kHz audio waveform
        """
        # Load audio, resampling to Whisper's expected 16 kHz
        audio, _ = librosa.load(audio_path, sr=16000)

        # Trim to max duration (Whisper processes 30-second windows)
        max_samples = int(max_duration * 16000)
        if len(audio) > max_samples:
            audio = audio[:max_samples]
            print(f"Audio trimmed to {max_duration} seconds")

        print(f"Audio duration: {len(audio) / 16000:.2f} seconds")
        return audio

    def transcribe(self, audio_input, max_new_tokens=128):
        """
        Transcribe audio to text.

        Args:
            audio_input: Audio array or path to audio file
            max_new_tokens (int): Maximum number of tokens to generate

        Returns:
            str: Transcribed text
        """
        # Accept either a file path or a preloaded audio array
        if isinstance(audio_input, str):
            audio_array = self.preprocess_audio(audio_input)
        else:
            audio_array = audio_input

        # Prepare input features
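        # The processor converts the waveform to a log-Mel spectrogram and
        # pads/truncates it to Whisper's fixed 30-second window; for
        # whisper-medium the result has shape (1, 80, 3000).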
        input_features = self.processor(
            audio_array,
            sampling_rate=16000,
            return_tensors="np",
        ).input_features

        print(f"Input features shape: {input_features.shape}")

        # Encoder forward pass
        print("Running encoder...")
        start_time = time.time()
        encoder_outputs = self.encoder_session.run(
            None, {"input_features": input_features}
        )[0]
        encoder_time = time.time() - start_time
        print(f"Encoder inference time: {encoder_time:.3f}s")
        print(f"Encoder output shape: {encoder_outputs.shape}")

        # Initialize the decoder with the start-of-transcript token.
        # NOTE: no language/task tokens are forced here; see the note at the
        # end of this file if transcriptions come out empty or wrong-language.
        start_id = self.config["decoder_start_token_id"]
        decoder_input_ids = np.array([[start_id]], dtype=np.int64)
        generated_tokens = [start_id]

        print("Running decoder...")
        decoder_start_time = time.time()

        # Simple greedy decoding for demonstration. The full token sequence
        # is re-fed on every step (no KV cache), so cost grows quadratically
        # with output length.
        for _ in range(max_new_tokens):
            # Decoder forward pass over the sequence generated so far
            decoder_outputs = self.decoder_session.run(
                None,
                {
                    "input_ids": decoder_input_ids,
                    "encoder_hidden_states": encoder_outputs,
                },
            )[0]

            # Greedily pick the most likely next token from the last position
            next_token_logits = decoder_outputs[0, -1, :]
            next_token = int(np.argmax(next_token_logits))

            # Stop at the end-of-sequence token
            if next_token == self.config["eos_token_id"]:
                break

            generated_tokens.append(next_token)

            # Feed the extended sequence back in on the next iteration
            decoder_input_ids = np.array([generated_tokens], dtype=np.int64)

        decoder_time = time.time() - decoder_start_time
        print(f"Decoder inference time: {decoder_time:.3f}s")
        print(f"Generated {len(generated_tokens)} tokens")

        # Decode token ids back to text
        transcription = self.processor.batch_decode(
            [generated_tokens], skip_special_tokens=True
        )[0]

        total_time = encoder_time + decoder_time
        print(f"Total inference time: {total_time:.3f}s")

        return transcription.strip()

    def get_model_info(self):
        """Return basic model metadata and on-disk file sizes."""
        info = {
            "model_type": self.config.get("model_type", "whisper"),
            "vocab_size": self.config.get("vocab_size"),
            "encoder_layers": self.config.get("encoder_layers"),
            "decoder_layers": self.config.get("decoder_layers"),
            "d_model": self.config.get("d_model"),
            "encoder_file_size": self.encoder_path.stat().st_size / (1024 ** 2),  # MB
            "decoder_file_size": self.decoder_path.stat().st_size / (1024 ** 2),  # MB
        }
        return info


def main():
    """Example usage"""
    parser = argparse.ArgumentParser(description="Cahya Whisper ONNX Example")
    parser.add_argument("--audio", type=str, required=True, help="Path to audio file")
    parser.add_argument("--model-dir", type=str, default="./", help="Model directory")
    parser.add_argument("--max-tokens", type=int, default=128, help="Max tokens to generate")
    args = parser.parse_args()

    # Check that the audio file exists
    if not os.path.exists(args.audio):
        print(f"Error: Audio file not found: {args.audio}")
        return

    print("=" * 50)
    print("Cahya Whisper Medium ONNX Example")
    print("=" * 50)

    try:
        # Initialize model
        model = CahyaWhisperONNX(args.model_dir)

        # Show model info
        print("\nModel Information:")
        info = model.get_model_info()
        for key, value in info.items():
            if key.endswith("_size"):
                print(f"  {key}: {value:.1f} MB")
            else:
                print(f"  {key}: {value}")

        print(f"\nTranscribing: {args.audio}")
        print("-" * 50)

        # Transcribe
        transcription = model.transcribe(args.audio, max_new_tokens=args.max_tokens)

        print("\nTranscription:")
        print(f"'{transcription}'")
        print("-" * 50)
        print("Done!")

    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()
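

# ---------------------------------------------------------------------------
# Example invocation. The file, module, and audio names below are
# illustrative placeholders, not files shipped with the model:
#
#   python onnx_inference_example.py --audio recording.wav --model-dir ./
#
# Or, from another script (assuming this file is importable as
# onnx_inference_example):
#
#   from onnx_inference_example import CahyaWhisperONNX
#   model = CahyaWhisperONNX(model_dir="./")
#   print(model.transcribe("recording.wav", max_new_tokens=128))
# ---------------------------------------------------------------------------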
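# ---------------------------------------------------------------------------
# Note on decoding: transcribe() seeds the decoder with
# decoder_start_token_id only. Stock Whisper checkpoints normally expect
# language and task tokens (e.g. <|id|>, <|transcribe|>) right after
# <|startoftranscript|>, so if output comes back empty or in the wrong
# language, forcing those tokens is a reasonable first fix. A minimal
# sketch, assuming the exported decoder uses the same token ids as the
# upstream tokenizer:
#
#   prompt_ids = model.processor.get_decoder_prompt_ids(
#       language="indonesian", task="transcribe"
#   )  # returns [(position, token_id), ...] pairs
#   seed = [model.config["decoder_start_token_id"]] + [t for _, t in prompt_ids]
#
# and then start the greedy loop from `seed` instead of the single start id.
# ---------------------------------------------------------------------------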