"""This module handles model inference and evaluation."""

from datetime import datetime
from typing import Optional

import torch
from transformers import AutoProcessor, AutoModelForCTC

from data import timit_manager
from phone_metrics import PhoneErrorMetrics

# Initialize evaluation metric
phone_errors = PhoneErrorMetrics()


class ModelManager:
    """Handles model loading and inference."""

    def __init__(self):
        self.models = {}
        self.processors = {}
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.batch_size = 32

    def get_model_and_processor(self, model_name: str):
        """Get or load the model and processor, caching them by name."""
        if model_name not in self.models:
            print("Loading processor with phoneme tokenizer...")
            processor = AutoProcessor.from_pretrained(model_name)

            print(f"Loading model {model_name}...")
            model = AutoModelForCTC.from_pretrained(model_name).to(self.device)

            self.models[model_name] = model
            self.processors[model_name] = processor

        return self.models[model_name], self.processors[model_name]

    def transcribe(self, audio_list: list[torch.Tensor], model_name: str) -> list[str]:
        """Transcribe a batch of audio using the specified model."""
        model, processor = self.get_model_and_processor(model_name)
        if model is None or processor is None:
            raise RuntimeError("Model and processor not loaded")

        # Process audio in batches
        all_predictions = []
        for i in range(0, len(audio_list), self.batch_size):
            batch_audio = audio_list[i : i + self.batch_size]

            # Pad each clip to the longest clip in the batch and build a
            # matching attention mask (1 = real samples, 0 = padding)
            max_length = max(audio.shape[-1] for audio in batch_audio)
            padded_audio = torch.zeros((len(batch_audio), 1, max_length))
            attention_mask = torch.zeros((len(batch_audio), max_length))
            for j, audio in enumerate(batch_audio):
                padded_audio[j, :, : audio.shape[-1]] = audio
                attention_mask[j, : audio.shape[-1]] = 1

            # Process batch
            inputs = processor(
                padded_audio.squeeze(1).numpy(),
                sampling_rate=16000,
                return_tensors="pt",
                padding=True,
            )
            input_values = inputs.input_values.to(self.device)
            # Prefer the processor's attention mask; fall back to the manual one
            attention_mask = inputs.get("attention_mask", attention_mask).to(
                self.device
            )

            with torch.no_grad():
                outputs = model(
                    input_values=input_values, attention_mask=attention_mask
                )
                logits = outputs.logits

            predicted_ids = torch.argmax(logits, dim=-1)
            predictions = processor.batch_decode(
                predicted_ids, skip_special_tokens=True
            )
            # Strip spaces so each prediction is a bare phoneme string
            predictions = [pred.replace(" ", "") for pred in predictions]
            all_predictions.extend(predictions)

        return all_predictions


def evaluate_model(
    model_name: str,
    subset: str = "test",
    max_samples: Optional[int] = None,
):
    """Evaluate a model on the TIMIT dataset."""
    files = timit_manager.get_file_list(subset)
    if max_samples:
        files = files[:max_samples]

    results = []
    total_per = total_pwed = 0

    # Process files in batches
    batch_size = model_manager.batch_size
    for i in range(0, len(files), batch_size):
        batch_files = files[i : i + batch_size]

        # Load batch audio and ground truth
        batch_audio = []
        batch_ground_truth = []
        for wav_file in batch_files:
            audio = timit_manager.load_audio(wav_file)
            ground_truth = timit_manager.get_phonemes(wav_file)
            batch_audio.append(audio)
            batch_ground_truth.append(ground_truth)

        # Get predictions for the batch
        predictions = model_manager.transcribe(batch_audio, model_name)

        # Calculate metrics for each file in the batch
        for wav_file, prediction, ground_truth in zip(
            batch_files, predictions, batch_ground_truth
        ):
            metrics = phone_errors.compute(
                predictions=[prediction],
                references=[ground_truth],
                is_normalize_pfer=True,
            )
            per = metrics["phone_error_rates"][0]
            pwed = metrics["phone_feature_error_rates"][0]
            results.append(
                {
                    "file": wav_file,
                    "ground_truth": ground_truth,
                    "prediction": prediction,
                    "per": per,
                    "pwed": pwed,
                }
            )
            total_per += per
            total_pwed += pwed

    if not results:
        raise RuntimeError("No files were successfully processed")

    avg_per = total_per / len(results)
    avg_pwed = total_pwed / len(results)

    return {
        "model": model_name,
        "subset": subset,
        "num_files": len(results),
        "average_per": avg_per,
        "average_pwed": avg_pwed,
        "detailed_results": results[:5],
        "timestamp": datetime.now().isoformat(),
    }


# Initialize managers
model_manager = ModelManager()
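

# --- Usage sketch (illustrative, not part of the module's API) ---
# A minimal example of how evaluate_model might be driven. The checkpoint
# name below is an assumption for illustration: any Hugging Face CTC model
# with a phoneme-level tokenizer should work in its place.
if __name__ == "__main__":
    report = evaluate_model(
        "facebook/wav2vec2-lv-60-espeak-cv-ft",  # assumed phoneme CTC checkpoint
        subset="test",
        max_samples=8,  # keep the smoke test small
    )
    print(
        f"{report['model']} on {report['subset']}: "
        f"PER={report['average_per']:.3f}, PWED={report['average_pwed']:.3f} "
        f"over {report['num_files']} files"
    )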