# This module handles model inference and evaluation.
from datetime import datetime
from typing import Optional

import torch
from transformers import AutoProcessor, AutoModelForCTC

from data import timit_manager
from phone_metrics import PhoneErrorMetrics

# Initialize evaluation metric
phone_errors = PhoneErrorMetrics()

class ModelManager:
    """Handles model loading and inference."""

    def __init__(self):
        self.models = {}
        self.processors = {}
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.batch_size = 32
    def get_model_and_processor(self, model_name: str):
        """Get the model and processor for `model_name`, loading and caching them on first use."""
        if model_name not in self.models:
            print("Loading processor with phoneme tokenizer...")
            processor = AutoProcessor.from_pretrained(model_name)
            print(f"Loading model {model_name}...")
            model = AutoModelForCTC.from_pretrained(model_name).to(self.device)
            self.models[model_name] = model
            self.processors[model_name] = processor
        return self.models[model_name], self.processors[model_name]
    def transcribe(self, audio_list: list[torch.Tensor], model_name: str) -> list[str]:
        """Transcribe a batch of 16 kHz audio tensors using the specified model."""
        model, processor = self.get_model_and_processor(model_name)
        if model is None or processor is None:
            raise RuntimeError("Model and processor not loaded")

        # Process audio in batches
        all_predictions = []
        for i in range(0, len(audio_list), self.batch_size):
            batch_audio = audio_list[i : i + self.batch_size]

            # Zero-pad every clip to the longest clip in the batch and build a
            # matching attention mask (1 = real samples, 0 = padding)
            max_length = max(audio.shape[-1] for audio in batch_audio)
            padded_audio = torch.zeros((len(batch_audio), 1, max_length))
            attention_mask = torch.zeros((len(batch_audio), max_length))
            for j, audio in enumerate(batch_audio):
                padded_audio[j, :, : audio.shape[-1]] = audio
                attention_mask[j, : audio.shape[-1]] = 1
            # Process the padded batch; prefer the processor's attention mask
            # if it returns one, otherwise fall back to the manual mask above
            inputs = processor(
                padded_audio.squeeze(1).numpy(),
                sampling_rate=16000,
                return_tensors="pt",
                padding=True,
            )
            input_values = inputs.input_values.to(self.device)
            attention_mask = inputs.get("attention_mask", attention_mask).to(
                self.device
            )

            # Greedy CTC decoding: take the argmax token at every frame;
            # batch_decode collapses repeats and blanks into token strings
            with torch.no_grad():
                outputs = model(
                    input_values=input_values, attention_mask=attention_mask
                )
            logits = outputs.logits
            predicted_ids = torch.argmax(logits, dim=-1)
            predictions = processor.batch_decode(
                predicted_ids, skip_special_tokens=True
            )
            # Drop separator spaces so each prediction is a bare phoneme string
            predictions = [pred.replace(" ", "") for pred in predictions]
            all_predictions.extend(predictions)
        return all_predictions
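
# A minimal usage sketch (illustrative, not part of the original flow): given
# 16 kHz mono clips as (1, T) tensors, a call like
#   preds = model_manager.transcribe([clip_a, clip_b], "facebook/wav2vec2-lv-60-espeak-cv-ft")
# returns one space-stripped phoneme string per clip. The checkpoint name is
# only an example of a CTC model with a phoneme tokenizer; any such model works.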

def evaluate_model(
    model_name: str,
    subset: str = "test",
    max_samples: Optional[int] = None,
):
    """Evaluate a model on the TIMIT dataset."""
    files = timit_manager.get_file_list(subset)
    if max_samples:
        files = files[:max_samples]

    results = []
    total_per = total_pwed = 0
    # Process files in batches
    batch_size = model_manager.batch_size
    for i in range(0, len(files), batch_size):
        batch_files = files[i : i + batch_size]

        # Load batch audio and ground truth
        batch_audio = []
        batch_ground_truth = []
        for wav_file in batch_files:
            audio = timit_manager.load_audio(wav_file)
            ground_truth = timit_manager.get_phonemes(wav_file)
            batch_audio.append(audio)
            batch_ground_truth.append(ground_truth)

        # Get predictions for the batch
        predictions = model_manager.transcribe(batch_audio, model_name)
        # Calculate metrics for each file in the batch
        for wav_file, prediction, ground_truth in zip(
            batch_files, predictions, batch_ground_truth
        ):
            metrics = phone_errors.compute(
                predictions=[prediction],
                references=[ground_truth],
                is_normalize_pfer=True,
            )
            per = metrics["phone_error_rates"][0]
            pwed = metrics["phone_feature_error_rates"][0]
            results.append(
                {
                    "file": wav_file,
                    "ground_truth": ground_truth,
                    "prediction": prediction,
                    "per": per,
                    "pwed": pwed,
                }
            )
            total_per += per
            total_pwed += pwed
    if not results:
        raise RuntimeError("No files were successfully processed")

    avg_per = total_per / len(results)
    avg_pwed = total_pwed / len(results)
    return {
        "model": model_name,
        "subset": subset,
        "num_files": len(results),
        "average_per": avg_per,
        "average_pwed": avg_pwed,
        "detailed_results": results[:5],  # cap per-file detail at five entries
        "timestamp": datetime.now().isoformat(),
    }

# Initialize managers
model_manager = ModelManager()
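
# Minimal smoke-test sketch (assumption: the checkpoint below is just an
# example phoneme-CTC model from the Hugging Face Hub, not one prescribed by
# this module; swap in any CTC model whose tokenizer emits phonemes).
if __name__ == "__main__":
    summary = evaluate_model(
        "facebook/wav2vec2-lv-60-espeak-cv-ft",  # example checkpoint (assumption)
        subset="test",
        max_samples=8,  # small sample count keeps the smoke test fast
    )
    print(f"PER: {summary['average_per']:.3f}  PWED: {summary['average_pwed']:.3f}")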