"""This module handles model inference and evaluation."""
from datetime import datetime
from typing import Optional
import torch
from transformers import AutoProcessor, AutoModelForCTC
from data import timit_manager
from phone_metrics import PhoneErrorMetrics
# Initialize evaluation metric
phone_errors = PhoneErrorMetrics()
class ModelManager:
"""Handles model loading and inference"""
def __init__(self):
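        # Cache models and processors by name so repeated requests avoid reloading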
self.models = {}
self.processors = {}
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.batch_size = 32
def get_model_and_processor(self, model_name: str):
"""Get or load model and processor"""
if model_name not in self.models:
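            # First request for this checkpoint: load once, then serve from cache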
print("Loading processor with phoneme tokenizer...")
processor = AutoProcessor.from_pretrained(model_name)
print("Loading model...", {model_name})
model = AutoModelForCTC.from_pretrained(model_name).to(self.device)
self.models[model_name] = model
self.processors[model_name] = processor
return self.models[model_name], self.processors[model_name]
def transcribe(self, audio_list: list[torch.Tensor], model_name: str) -> list[str]:
"""Transcribe a batch of audio using specified model"""
model, processor = self.get_model_and_processor(model_name)
        if model is None or processor is None:
            raise RuntimeError(f"Model and processor for {model_name} not loaded")
# Process audio in batches
all_predictions = []
for i in range(0, len(audio_list), self.batch_size):
batch_audio = audio_list[i : i + self.batch_size]
            # Let the processor pad the batch and build the attention mask, so
            # padded frames are masked out. (Zero-padding by hand and then
            # calling the processor with padding=True would yield an all-ones
            # mask that treats the padding as real audio.)
            inputs = processor(
                [audio.squeeze().numpy() for audio in batch_audio],
                sampling_rate=16000,
                return_tensors="pt",
                padding=True,
            )
            input_values = inputs.input_values.to(self.device)
            attention_mask = inputs.get("attention_mask")
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.device)
with torch.no_grad():
outputs = model(
input_values=input_values, attention_mask=attention_mask
)
logits = outputs.logits
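            # Greedy CTC decoding: take the most likely token per frame;
            # batch_decode collapses repeats and removes blank tokens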
predicted_ids = torch.argmax(logits, dim=-1)
predictions = processor.batch_decode(
predicted_ids, skip_special_tokens=True
)
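            # Drop the spaces the tokenizer inserts between phoneme tokens,
            # yielding one contiguous phone string per utterance for scoring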
predictions = [pred.replace(" ", "") for pred in predictions]
all_predictions.extend(predictions)
return all_predictions
def evaluate_model(
model_name: str,
subset: str = "test",
max_samples: Optional[int] = None,
):
"""Evaluate model on TIMIT dataset"""
files = timit_manager.get_file_list(subset)
    if max_samples is not None:
        files = files[:max_samples]
results = []
total_per = total_pwed = 0
# Process files in batches
batch_size = model_manager.batch_size
for i in range(0, len(files), batch_size):
batch_files = files[i : i + batch_size]
# Load batch audio and ground truth
batch_audio = []
batch_ground_truth = []
for wav_file in batch_files:
audio = timit_manager.load_audio(wav_file)
ground_truth = timit_manager.get_phonemes(wav_file)
batch_audio.append(audio)
batch_ground_truth.append(ground_truth)
# Get predictions for batch
predictions = model_manager.transcribe(batch_audio, model_name)
# Calculate metrics for each file in batch
        for wav_file, prediction, ground_truth in zip(
            batch_files, predictions, batch_ground_truth
        ):
metrics = phone_errors.compute(
predictions=[prediction],
references=[ground_truth],
is_normalize_pfer=True,
)
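            # Unpack the per-utterance phone error rate and (normalized)
            # phone-feature error rate reported by the metric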
per = metrics["phone_error_rates"][0]
pwed = metrics["phone_feature_error_rates"][0]
results.append(
{
"file": wav_file,
"ground_truth": ground_truth,
"prediction": prediction,
"per": per,
"pwed": pwed,
}
)
total_per += per
total_pwed += pwed
if not results:
raise Exception("No files were successfully processed")
avg_per = total_per / len(results)
avg_pwed = total_pwed / len(results)
return {
"model": model_name,
"subset": subset,
"num_files": len(results),
"average_per": avg_per,
"average_pwed": avg_pwed,
"detailed_results": results[:5],
"timestamp": datetime.now().isoformat(),
}
# Initialize managers
model_manager = ModelManager()
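
# Example usage (a minimal sketch): assumes a phoneme-level CTC checkpoint such
# as "facebook/wav2vec2-lv-60-espeak-cv-ft", and that the data and phone_metrics
# modules imported above provide timit_manager and PhoneErrorMetrics.
if __name__ == "__main__":
    report = evaluate_model(
        "facebook/wav2vec2-lv-60-espeak-cv-ft", subset="test", max_samples=8
    )
    print(f"PER: {report['average_per']:.3f}  PWED: {report['average_pwed']:.3f}")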