from typing import List, Optional

import numpy as np
import pandas as pd
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics.pairwise import cosine_similarity


def cosine_sim_wer(references: List[str], predictions: List[str]) -> float:
    """
    Calculate a WER-like metric based on cosine similarity between reference and prediction texts.

    This function computes character-level n-gram similarities between each
    reference-prediction pair and returns an error rate (100% - average similarity).
    Pairs with an empty reference or prediction are skipped, and summary similarity
    statistics are printed.

    Args:
        references: List of reference transcript strings
        predictions: List of model prediction strings

    Returns:
        float: Error rate based on cosine similarity (100% - average similarity)

    Example:
        >>> references = ["hello world", "good morning"]
        >>> predictions = ["hello world", "good evening"]
        >>> error_rate = cosine_sim_wer(references, predictions)
    """
    # Keep only pairs where both the reference and the prediction are non-empty.
    valid_refs, valid_preds = [], []
    for ref, pred in zip(references, predictions):
        if not ref.strip() or not pred.strip():
            continue
        valid_refs.append(ref.strip())
        valid_preds.append(pred.strip())

    if not valid_refs:
        print("Warning: No valid reference-prediction pairs found")
        return 100.0

    similarities = []
    for ref, pred in zip(valid_refs, valid_preds):
        try:
            # Character n-grams (2-3 characters, word-boundary aware) tolerate small
            # spelling differences that word-level tokens would count as full errors.
            vectorizer = CountVectorizer(
                analyzer='char_wb',
                ngram_range=(2, 3)
            )
            vectors = vectorizer.fit_transform([ref, pred])

            # Cosine similarity between the two n-gram count vectors, as a percentage.
            similarity = cosine_similarity(vectors[0:1], vectors[1:2])[0][0]
            similarities.append(similarity * 100)
        except Exception as e:
            print(f"Error calculating similarity: {e}")
            similarities.append(0.0)

    avg_similarity = np.mean(similarities)
    min_similarity = np.min(similarities)
    max_similarity = np.max(similarities)
    error_rate = 100.0 - avg_similarity

    print("Similarity Statistics:")
    print(f"  - Average: {avg_similarity:.2f}%")
    print(f"  - Range: {min_similarity:.2f}% to {max_similarity:.2f}%")
    print(f"  - Valid samples: {len(similarities)}/{len(references)}")

    return error_rate


def create_wer_analysis_dataframe(
    references: List[str],
    predictions: List[str],
    normalized_references: Optional[List[str]] = None,
    normalized_predictions: Optional[List[str]] = None,
    output_csv: Optional[str] = "wer_analysis.csv"
) -> pd.DataFrame:
    """
    Create a comprehensive DataFrame comparing reference and prediction texts with multiple metrics.

    For each sample, calculates:
    - Word Error Rate (WER) for original and normalized texts
    - Cosine similarity for original and normalized texts
    - Length statistics and differences

    Args:
        references: List of original reference texts
        predictions: List of original prediction texts
        normalized_references: Optional list of normalized reference texts
        normalized_predictions: Optional list of normalized prediction texts
        output_csv: Path to save results (None to skip saving)

    Returns:
        pd.DataFrame: Analysis results with one row per sample

    Example:
        >>> df = create_wer_analysis_dataframe(
        ...     references=["hello world"],
        ...     predictions=["hello there"],
        ...     output_csv="analysis.csv"
        ... )
    """
    # Imported locally so jiwer is only required when this function is used.
    from jiwer import wer

    records = []
    for i, (ref, pred) in enumerate(zip(references, predictions)):
        # Skip pairs where either side is empty.
        if not ref.strip() or not pred.strip():
            continue

        # Fall back to the original texts when no normalized versions are supplied.
        norm_ref = normalized_references[i] if normalized_references else ref
        norm_pred = normalized_predictions[i] if normalized_predictions else pred

        metrics = {
            'index': i,
            'reference': ref,
            'prediction': pred,
            'normalized_reference': norm_ref,
            'normalized_prediction': norm_pred,
            'ref_length': len(ref.split()),
            'pred_length': len(pred.split()),
            'length_difference': len(pred.split()) - len(ref.split())
        }

        # Word Error Rate on the original and normalized texts, as percentages.
        try:
            metrics['wer'] = wer(ref, pred) * 100
            metrics['normalized_wer'] = wer(norm_ref, norm_pred) * 100
        except Exception as e:
            print(f"WER calculation failed for sample {i}: {e}")
            metrics.update({'wer': np.nan, 'normalized_wer': np.nan})

        # Character n-gram cosine similarity for the original and normalized texts.
        for prefix, text1, text2 in [
            ('', ref, pred),
            ('normalized_', norm_ref, norm_pred)
        ]:
            try:
                vectorizer = CountVectorizer(
                    analyzer='char_wb',
                    ngram_range=(2, 3)
                )
                vectors = vectorizer.fit_transform([text1, text2])
                similarity = cosine_similarity(vectors[0:1], vectors[1:2])[0][0] * 100
                metrics[f'{prefix}similarity'] = similarity
            except Exception as e:
                print(f"Similarity calculation failed for sample {i}: {e}")
                metrics[f'{prefix}similarity'] = np.nan

        records.append(metrics)

    df = pd.DataFrame(records)

    if output_csv:
        try:
            df.to_csv(output_csv, index=False)
            print(f"Analysis saved to {output_csv}")
        except Exception as e:
            print(f"Failed to save CSV: {e}")

    return df
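

# Minimal usage sketch (illustrative only, not part of the original module's API):
# the sample transcripts and the output CSV path below are made-up values, and
# create_wer_analysis_dataframe additionally assumes the jiwer package is installed.
if __name__ == "__main__":
    sample_refs = ["hello world", "good morning everyone"]
    sample_preds = ["hello world", "good evening everyone"]

    # Corpus-level error rate from character n-gram cosine similarity.
    error_rate = cosine_sim_wer(sample_refs, sample_preds)
    print(f"Cosine-similarity error rate: {error_rate:.2f}%")

    # Per-sample breakdown with WER and similarity columns.
    analysis_df = create_wer_analysis_dataframe(
        sample_refs,
        sample_preds,
        output_csv="wer_analysis_demo.csv",
    )
    print(analysis_df[['index', 'wer', 'similarity']])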