from typing import List, Optional

import numpy as np
import pandas as pd
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics.pairwise import cosine_similarity


def cosine_sim_wer(references: List[str], predictions: List[str]) -> float:
    """
    Calculate a WER-like metric based on cosine similarity between reference
    and prediction texts.

    Computes character-level n-gram similarity for each reference-prediction
    pair and returns an error rate (100% minus the average similarity). Empty
    strings are skipped, and summary similarity statistics are printed.

    Args:
        references: List of reference transcript strings
        predictions: List of model prediction strings

    Returns:
        float: Error rate based on cosine similarity (100% - average similarity)

    Example:
        >>> references = ["hello world", "good morning"]
        >>> predictions = ["hello world", "good evening"]
        >>> error_rate = cosine_sim_wer(references, predictions)
    """
    # Mismatched lengths would otherwise be silently truncated by zip()
    if len(references) != len(predictions):
        raise ValueError("references and predictions must have the same length")

    # Validate and clean inputs
    valid_refs, valid_preds = [], []
    for ref, pred in zip(references, predictions):
        if not ref.strip() or not pred.strip():
            continue  # Skip empty strings
        valid_refs.append(ref.strip())
        valid_preds.append(pred.strip())

    # Handle case with no valid pairs
    if not valid_refs:
        print("Warning: No valid reference-prediction pairs found")
        return 100.0  # Maximum error if no valid data
    # Calculate pairwise similarities
    similarities = []
    for ref, pred in zip(valid_refs, valid_preds):
        try:
            # Use character-level n-grams (2-3 chars) for robust comparison
            vectorizer = CountVectorizer(
                analyzer='char_wb',   # Word-boundary-aware character n-grams
                ngram_range=(2, 3)    # Bigrams and trigrams
            )
            # Create document-term matrices
            vectors = vectorizer.fit_transform([ref, pred])
            # Compute cosine similarity
            similarity = cosine_similarity(vectors[0:1], vectors[1:2])[0][0]
            similarities.append(similarity * 100)  # Convert to percentage
        except Exception as e:
            print(f"Error calculating similarity: {e}")
            similarities.append(0.0)  # Default to 0% similarity on error
    # Compute statistics
    avg_similarity = np.mean(similarities)
    min_similarity = np.min(similarities)
    max_similarity = np.max(similarities)
    error_rate = 100.0 - avg_similarity

    # Print diagnostics
    print("Similarity Statistics:")
    print(f"  - Average: {avg_similarity:.2f}%")
    print(f"  - Range: {min_similarity:.2f}% to {max_similarity:.2f}%")
    print(f"  - Valid samples: {len(similarities)}/{len(references)}")

    return error_rate


def create_wer_analysis_dataframe(
    references: List[str],
    predictions: List[str],
    normalized_references: Optional[List[str]] = None,
    normalized_predictions: Optional[List[str]] = None,
    output_csv: Optional[str] = "wer_analysis.csv"
) -> pd.DataFrame:
    """
    Create a comprehensive DataFrame comparing reference and prediction texts
    with multiple metrics.

    For each sample, calculates:
    - Word Error Rate (WER) for original and normalized texts
    - Cosine similarity for original and normalized texts
    - Length statistics and differences

    Args:
        references: List of original reference texts
        predictions: List of original prediction texts
        normalized_references: Optional list of normalized reference texts
        normalized_predictions: Optional list of normalized prediction texts
        output_csv: Path to save results (None to skip saving)

    Returns:
        pd.DataFrame: Analysis results with one row per sample

    Example:
        >>> df = create_wer_analysis_dataframe(
        ...     references=["hello world"],
        ...     predictions=["hello there"],
        ...     output_csv="analysis.csv"
        ... )
    """
    from jiwer import wer  # Imported here to avoid the dependency when unused

    records = []
    for i, (ref, pred) in enumerate(zip(references, predictions)):
        # Skip empty samples
        if not ref.strip() or not pred.strip():
            continue

        # Get normalized versions if provided
        norm_ref = normalized_references[i] if normalized_references else ref
        norm_pred = normalized_predictions[i] if normalized_predictions else pred

        # Collect per-sample metrics
        metrics = {
            'index': i,
            'reference': ref,
            'prediction': pred,
            'normalized_reference': norm_ref,
            'normalized_prediction': norm_pred,
            'ref_length': len(ref.split()),
            'pred_length': len(pred.split()),
            'length_difference': len(pred.split()) - len(ref.split())
        }

        # Calculate WER metrics
        try:
            metrics['wer'] = wer(ref, pred) * 100
            metrics['normalized_wer'] = wer(norm_ref, norm_pred) * 100
        except Exception as e:
            print(f"WER calculation failed for sample {i}: {e}")
            metrics.update({'wer': np.nan, 'normalized_wer': np.nan})

        # Calculate cosine similarities for original and normalized texts
        for prefix, text1, text2 in [
            ('', ref, pred),
            ('normalized_', norm_ref, norm_pred)
        ]:
            try:
                vectorizer = CountVectorizer(
                    analyzer='char_wb',
                    ngram_range=(2, 3)
                )
                vectors = vectorizer.fit_transform([text1, text2])
                similarity = cosine_similarity(vectors[0:1], vectors[1:2])[0][0] * 100
                metrics[f'{prefix}similarity'] = similarity
            except Exception as e:
                print(f"Similarity calculation failed for sample {i}: {e}")
                metrics[f'{prefix}similarity'] = np.nan

        records.append(metrics)
    # Create DataFrame
    df = pd.DataFrame(records)

    # Save to CSV if requested
    if output_csv:
        try:
            df.to_csv(output_csv, index=False)
            print(f"Analysis saved to {output_csv}")
        except Exception as e:
            print(f"Failed to save CSV: {e}")

    return df
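

# Minimal usage sketch showing how the two functions above could be called
# together. The sample transcripts are hypothetical placeholders, not
# evaluation data from this repository; create_wer_analysis_dataframe
# additionally requires the jiwer package.
if __name__ == "__main__":
    sample_refs = ["hello world", "good morning"]
    sample_preds = ["hello word", "good evening"]

    # Corpus-level similarity-based error rate
    error_rate = cosine_sim_wer(sample_refs, sample_preds)
    print(f"Cosine-similarity error rate: {error_rate:.2f}%")

    # Per-sample WER and similarity breakdown, written to wer_analysis.csv
    analysis_df = create_wer_analysis_dataframe(sample_refs, sample_preds)
    print(analysis_df[['index', 'wer', 'similarity']])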