"""Model inference and evaluation for phoneme recognition on TIMIT."""

from datetime import datetime
from typing import Optional

import torch
from transformers import AutoProcessor, AutoModelForCTC

from data import timit_manager
from phone_metrics import PhoneErrorMetrics

# Initialize evaluation metric
phone_errors = PhoneErrorMetrics()


class ModelManager:
    """Handles model loading and inference"""

    def __init__(self):
        self.models = {}
        self.processors = {}
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.batch_size = 32

    def get_model_and_processor(self, model_name: str):
        """Get or load model and processor"""
        if model_name not in self.models:
            print("Loading processor with phoneme tokenizer...")
            processor = AutoProcessor.from_pretrained(model_name)

            print("Loading model...", {model_name})
            model = AutoModelForCTC.from_pretrained(model_name).to(self.device)

            self.models[model_name] = model
            self.processors[model_name] = processor

        return self.models[model_name], self.processors[model_name]

    def transcribe(self, audio_list: list[torch.Tensor], model_name: str) -> list[str]:
        """Transcribe a batch of audio using specified model"""
        model, processor = self.get_model_and_processor(model_name)
        if model is None or processor is None:
            raise RuntimeError("Model and processor not loaded")

        # Process audio in batches
        all_predictions = []
        for i in range(0, len(audio_list), self.batch_size):
            batch_audio = audio_list[i : i + self.batch_size]

            # Pad sequence within batch
            max_length = max(audio.shape[-1] for audio in batch_audio)
            padded_audio = torch.zeros((len(batch_audio), 1, max_length))
            attention_mask = torch.zeros((len(batch_audio), max_length))

            for j, audio in enumerate(batch_audio):
                padded_audio[j, :, : audio.shape[-1]] = audio
                attention_mask[j, : audio.shape[-1]] = 1

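            # NOTE: the processor below pads again; if its feature extractor
            # returns an attention_mask of its own, that mask overrides the
            # manual one built here (see the inputs.get fallback).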
            # Process batch
            inputs = processor(
                padded_audio.squeeze(1).numpy(),
                sampling_rate=16000,
                return_tensors="pt",
                padding=True,
            )

            input_values = inputs.input_values.to(self.device)
            attention_mask = inputs.get("attention_mask", attention_mask).to(
                self.device
            )

            with torch.no_grad():
                outputs = model(
                    input_values=input_values, attention_mask=attention_mask
                )
                logits = outputs.logits
                predicted_ids = torch.argmax(logits, dim=-1)
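                # For CTC tokenizers, batch_decode merges repeated tokens and
                # drops blank/special tokens, yielding plain phoneme strings.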
                predictions = processor.batch_decode(
                    predicted_ids, skip_special_tokens=True
                )
                predictions = [pred.replace(" ", "") for pred in predictions]
                all_predictions.extend(predictions)

        return all_predictions


def evaluate_model(
    model_name: str,
    subset: str = "test",
    max_samples: Optional[int] = None,
):
    """Evaluate model on TIMIT dataset"""

    files = timit_manager.get_file_list(subset)
    if max_samples is not None:
        files = files[:max_samples]

    results = []
    total_per = total_pwed = 0

    # Process files in batches
    batch_size = model_manager.batch_size
    for i in range(0, len(files), batch_size):
        batch_files = files[i : i + batch_size]

        # Load batch audio and ground truth
        batch_audio = []
        batch_ground_truth = []
        for wav_file in batch_files:
            audio = timit_manager.load_audio(wav_file)
            ground_truth = timit_manager.get_phonemes(wav_file)
            batch_audio.append(audio)
            batch_ground_truth.append(ground_truth)

        # Get predictions for batch
        predictions = model_manager.transcribe(batch_audio, model_name)

        # Calculate metrics for each file in batch
        for wav_file, prediction, ground_truth in zip(
            batch_files, predictions, batch_ground_truth
        ):
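            # "per": phone error rate; "pwed": phone-feature error rate
            # (from phone_feature_error_rates, normalized per the flag below).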
            metrics = phone_errors.compute(
                predictions=[prediction],
                references=[ground_truth],
                is_normalize_pfer=True,
            )

            per = metrics["phone_error_rates"][0]
            pwed = metrics["phone_feature_error_rates"][0]

            results.append(
                {
                    "file": wav_file,
                    "ground_truth": ground_truth,
                    "prediction": prediction,
                    "per": per,
                    "pwed": pwed,
                }
            )

            total_per += per
            total_pwed += pwed

    if not results:
        raise RuntimeError("No files were successfully processed")

    avg_per = total_per / len(results)
    avg_pwed = total_pwed / len(results)

    return {
        "model": model_name,
        "subset": subset,
        "num_files": len(results),
        "average_per": avg_per,
        "average_pwed": avg_pwed,
        "detailed_results": results[:5],
        "timestamp": datetime.now().isoformat(),
    }


# Initialize managers
model_manager = ModelManager()
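

# Minimal usage sketch, assuming this file is run as a script. The checkpoint
# name is an example phoneme-level CTC model, not something this module pins.
if __name__ == "__main__":
    report = evaluate_model(
        "facebook/wav2vec2-lv-60-espeak-cv-ft",  # example checkpoint
        subset="test",
        max_samples=8,  # small sample for a quick smoke test
    )
    print(
        f"{report['model']} on TIMIT {report['subset']}: "
        f"PER={report['average_per']:.3f}, PWED={report['average_pwed']:.3f} "
        f"({report['num_files']} files)"
    )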