# app/inference.py
"""This module handles model inference and evaluation."""

from datetime import datetime
from typing import Optional

import torch
from transformers import AutoProcessor, AutoModelForCTC

from data import timit_manager
from phone_metrics import PhoneErrorMetrics

# Initialize the evaluation metric (phone error rate and phone feature error rate)
phone_errors = PhoneErrorMetrics()


class ModelManager:
    """Handles model loading and inference."""

    def __init__(self):
        self.models = {}
        self.processors = {}
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.batch_size = 32

    def get_model_and_processor(self, model_name: str):
        """Return the (model, processor) pair for `model_name`, loading and caching it on first use."""
        if model_name not in self.models:
            print("Loading processor with phoneme tokenizer...")
            processor = AutoProcessor.from_pretrained(model_name)

            print(f"Loading model {model_name}...")
            model = AutoModelForCTC.from_pretrained(model_name).to(self.device)
            model.eval()  # inference only: disable dropout

            self.models[model_name] = model
            self.processors[model_name] = processor
        return self.models[model_name], self.processors[model_name]

    def transcribe(self, audio_list: list[torch.Tensor], model_name: str) -> list[str]:
        """Transcribe a batch of 16 kHz audio clips using the specified model."""
        model, processor = self.get_model_and_processor(model_name)

        # Process audio in batches
        all_predictions = []
        for i in range(0, len(audio_list), self.batch_size):
            batch_audio = audio_list[i : i + self.batch_size]

            # Let the processor pad the batch to its longest clip; checkpoints
            # trained with an attention mask also get one for the padded frames.
            inputs = processor(
                [audio.squeeze().numpy() for audio in batch_audio],
                sampling_rate=16000,
                return_tensors="pt",
                padding=True,
            )
            input_values = inputs.input_values.to(self.device)
            attention_mask = inputs.get("attention_mask")
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.device)

            with torch.no_grad():
                logits = model(
                    input_values=input_values, attention_mask=attention_mask
                ).logits

            # Greedy CTC decoding: best token per frame, collapsed by the tokenizer
            predicted_ids = torch.argmax(logits, dim=-1)
            predictions = processor.batch_decode(
                predicted_ids, skip_special_tokens=True
            )

            # Strip the spaces the tokenizer inserts between phonemes
            predictions = [pred.replace(" ", "") for pred in predictions]
            all_predictions.extend(predictions)
        return all_predictions
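

# Note: transcribe() returns one IPA string per clip with inter-phoneme spaces
# removed, so predictions can be scored directly against TIMIT's ground-truth
# phoneme strings in evaluate_model() below.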


def evaluate_model(
    model_name: str,
    subset: str = "test",
    max_samples: Optional[int] = None,
):
    """Evaluate a model on a TIMIT subset and return aggregate metrics."""
    files = timit_manager.get_file_list(subset)
    if max_samples is not None:
        files = files[:max_samples]

    results = []
    total_per = total_pwed = 0.0

    # Process files in batches sized to match the model manager
    batch_size = model_manager.batch_size
    for i in range(0, len(files), batch_size):
        batch_files = files[i : i + batch_size]

        # Load the audio and ground-truth phonemes for this batch
        batch_audio = []
        batch_ground_truth = []
        for wav_file in batch_files:
            batch_audio.append(timit_manager.load_audio(wav_file))
            batch_ground_truth.append(timit_manager.get_phonemes(wav_file))

        # Get predictions for the batch
        predictions = model_manager.transcribe(batch_audio, model_name)

        # Score each file: phone error rate (PER) and phone feature error rate (PWED)
        for wav_file, prediction, ground_truth in zip(
            batch_files, predictions, batch_ground_truth
        ):
            metrics = phone_errors.compute(
                predictions=[prediction],
                references=[ground_truth],
                is_normalize_pfer=True,
            )
            per = metrics["phone_error_rates"][0]
            pwed = metrics["phone_feature_error_rates"][0]
            results.append(
                {
                    "file": wav_file,
                    "ground_truth": ground_truth,
                    "prediction": prediction,
                    "per": per,
                    "pwed": pwed,
                }
            )
            total_per += per
            total_pwed += pwed

    if not results:
        raise RuntimeError("No files were successfully processed")

    return {
        "model": model_name,
        "subset": subset,
        "num_files": len(results),
        "average_per": total_per / len(results),
        "average_pwed": total_pwed / len(results),
        "detailed_results": results[:5],  # sample of per-file results
        "timestamp": datetime.now().isoformat(),
    }


# Initialize managers
model_manager = ModelManager()
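

if __name__ == "__main__":
    # Quick smoke test (a sketch: the checkpoint below is an assumed example of
    # a phoneme-level CTC model on the Hub, not necessarily one this app serves).
    report = evaluate_model(
        "facebook/wav2vec2-lv-60-espeak-cv-ft",
        subset="test",
        max_samples=8,
    )
    print(
        f"{report['model']} on {report['num_files']} file(s): "
        f"PER={report['average_per']:.3f}, PWED={report['average_pwed']:.3f}"
    )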