Automatic Speech Recognition
Transformers
Safetensors
Swahili
English
whisper
Generated from Trainer
Jacaranda committed on
Commit d657b96 · verified · 1 Parent(s): 4fca98d

Upload Custom WER.py

Files changed (1)
  1. Custom WER.py +171 -0
Custom WER.py ADDED
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from typing import List, Optional
import pandas as pd


def cosine_sim_wer(references: List[str], predictions: List[str]) -> float:
    """
    Calculate a WER-like metric based on cosine similarity between reference and prediction texts.

    This function computes character-level n-gram similarities between each reference-prediction
    pair and returns an error rate (100% minus the average similarity). It skips empty inputs and
    prints summary similarity statistics.

    Args:
        references: List of reference transcript strings
        predictions: List of model prediction strings

    Returns:
        float: Error rate based on cosine similarity (100% minus the average similarity)

    Example:
        >>> references = ["hello world", "good morning"]
        >>> predictions = ["hello world", "good evening"]
        >>> error_rate = cosine_sim_wer(references, predictions)
    """
    # Validate and clean inputs
    valid_refs, valid_preds = [], []

    for ref, pred in zip(references, predictions):
        if not ref.strip() or not pred.strip():
            continue  # Skip empty strings
        valid_refs.append(ref.strip())
        valid_preds.append(pred.strip())

    # Handle the case with no valid pairs
    if not valid_refs:
        print("Warning: No valid reference-prediction pairs found")
        return 100.0  # Maximum error if no valid data

    # Calculate pairwise similarities
    similarities = []
    for ref, pred in zip(valid_refs, valid_preds):
        try:
            # Use character-level n-grams (2-3 chars) for robust comparison
            vectorizer = CountVectorizer(
                analyzer='char_wb',  # Word-boundary-aware character n-grams
                ngram_range=(2, 3)   # Bigrams and trigrams
            )

            # Create document-term matrices
            vectors = vectorizer.fit_transform([ref, pred])

            # Compute cosine similarity between the two rows
            similarity = cosine_similarity(vectors[0:1], vectors[1:2])[0][0]
            similarities.append(similarity * 100)  # Convert to percentage

        except Exception as e:
            print(f"Error calculating similarity: {e}")
            similarities.append(0.0)  # Default to 0% similarity on error

    # Compute statistics
    avg_similarity = np.mean(similarities)
    min_similarity = np.min(similarities)
    max_similarity = np.max(similarities)
    error_rate = 100.0 - avg_similarity

    # Print diagnostics
    print("Similarity Statistics:")
    print(f"  - Average: {avg_similarity:.2f}%")
    print(f"  - Range: {min_similarity:.2f}% to {max_similarity:.2f}%")
    print(f"  - Valid samples: {len(similarities)}/{len(references)}")

    return error_rate
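

# --- Usage sketch (added for illustration; not part of the original upload) ---
# A minimal, made-up Swahili/English check of cosine_sim_wer. Near-identical
# pairs score close to 100% similarity (low error rate), while unrelated texts
# approach 0% similarity because their character n-grams barely overlap.
def _demo_cosine_sim_wer() -> float:
    references = ["habari ya asubuhi", "good morning everyone"]
    predictions = ["habari za asubuhi", "good morning every one"]
    return cosine_sim_wer(references, predictions)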


def create_wer_analysis_dataframe(
    references: List[str],
    predictions: List[str],
    normalized_references: Optional[List[str]] = None,
    normalized_predictions: Optional[List[str]] = None,
    output_csv: Optional[str] = "wer_analysis.csv"
) -> pd.DataFrame:
    """
    Create a comprehensive DataFrame comparing reference and prediction texts with multiple metrics.

    For each sample, calculates:
    - Word Error Rate (WER) for original and normalized texts
    - Cosine similarity for original and normalized texts
    - Length statistics and differences

    Args:
        references: List of original reference texts
        predictions: List of original prediction texts
        normalized_references: Optional list of normalized reference texts
        normalized_predictions: Optional list of normalized prediction texts
        output_csv: Path to save results (None to skip saving)

    Returns:
        pd.DataFrame: Analysis results with one row per sample

    Example:
        >>> df = create_wer_analysis_dataframe(
        ...     references=["hello world"],
        ...     predictions=["hello there"],
        ...     output_csv="analysis.csv"
        ... )
    """
    from jiwer import wer  # Imported here so jiwer is only required when WER is computed

    records = []

    for i, (ref, pred) in enumerate(zip(references, predictions)):
        # Skip empty samples
        if not ref.strip() or not pred.strip():
            continue

        # Use normalized versions if provided, otherwise fall back to the originals
        norm_ref = normalized_references[i] if normalized_references else ref
        norm_pred = normalized_predictions[i] if normalized_predictions else pred

        # Collect per-sample texts and word-length statistics
        metrics = {
            'index': i,
            'reference': ref,
            'prediction': pred,
            'normalized_reference': norm_ref,
            'normalized_prediction': norm_pred,
            'ref_length': len(ref.split()),
            'pred_length': len(pred.split()),
            'length_difference': len(pred.split()) - len(ref.split())
        }

        # Calculate WER metrics with error handling
        try:
            metrics['wer'] = wer(ref, pred) * 100
            metrics['normalized_wer'] = wer(norm_ref, norm_pred) * 100
        except Exception as e:
            print(f"WER calculation failed for sample {i}: {e}")
            metrics.update({'wer': np.nan, 'normalized_wer': np.nan})

        # Calculate cosine similarities for original and normalized texts
        for prefix, text1, text2 in [
            ('', ref, pred),
            ('normalized_', norm_ref, norm_pred)
        ]:
            try:
                vectorizer = CountVectorizer(
                    analyzer='char_wb',
                    ngram_range=(2, 3)
                )
                vectors = vectorizer.fit_transform([text1, text2])
                similarity = cosine_similarity(vectors[0:1], vectors[1:2])[0][0] * 100
                metrics[f'{prefix}similarity'] = similarity
            except Exception as e:
                print(f"Similarity calculation failed for sample {i}: {e}")
                metrics[f'{prefix}similarity'] = np.nan

        records.append(metrics)

    # Create DataFrame
    df = pd.DataFrame(records)

    # Save to CSV if requested
    if output_csv:
        try:
            df.to_csv(output_csv, index=False)
            print(f"Analysis saved to {output_csv}")
        except Exception as e:
            print(f"Failed to save CSV: {e}")

    return df
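
To drive the analysis end to end, a short entry point along the following lines could be appended to the script. This is a sketch, assuming jiwer, scikit-learn, and pandas are installed; the sample texts and the simple_normalize helper are hypothetical stand-ins, not part of the uploaded file or the model's evaluation pipeline.

if __name__ == "__main__":
    # Hypothetical stand-in for the pipeline's real text normalizer
    def simple_normalize(text: str) -> str:
        return text.lower().strip()

    # Made-up Swahili/English samples, not from the model's evaluation data
    refs = ["Habari ya asubuhi", "Good morning everyone"]
    preds = ["habari za asubuhi", "good morning every one"]

    df = create_wer_analysis_dataframe(
        references=refs,
        predictions=preds,
        normalized_references=[simple_normalize(t) for t in refs],
        normalized_predictions=[simple_normalize(t) for t in preds],
        output_csv="wer_analysis.csv",
    )
    # Column names follow the metrics dict built above
    print(df[['wer', 'normalized_wer', 'similarity', 'normalized_similarity']])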