quality-lens / blaser_sonar_space.py
Tristan Yu
Upload 7 files
cf3775c verified
raw
history blame
5.31 kB
#!/usr/bin/env python3
"""
BLASER 2.0-QE Implementation using sonar-space package
This implementation should give accurate scores matching the official results
"""
import torch
from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline
from sonar.models.blaser.loader import load_blaser_model
class BLASEREvaluator:
    """Score translations with BLASER 2.0-QE on top of SONAR text embeddings.

    Typical use: construct, call :meth:`initialize` once to load the models,
    then call :meth:`evaluate` for each source/translation pair.
    """

    def __init__(self) -> None:
        """Create an evaluator with no models loaded yet."""
        # SONAR TextToEmbeddingModelPipeline; set by initialize().
        self.text_embedder = None
        # BLASER 2.0-QE model (switched to eval mode); set by initialize().
        self.blaser_model = None
        # True only after both models have been loaded successfully.
        self.initialized = False

    def initialize(self) -> bool:
        """Load the SONAR text embedder and the BLASER 2.0-QE model.

        Calling this again after a successful initialization is a no-op, so
        the models are not downloaded/loaded a second time.

        Returns:
            True if both models are loaded and ready, False if loading failed.
            On failure the evaluator is left uninitialized with no
            half-loaded model kept on the instance.
        """
        if self.initialized:
            return True

        print("🚀 Initializing BLASER 2.0-QE...")
        print("This may take a few minutes on first run as models are downloaded...")
        try:
            # Initialize text embedder with SONAR
            print("📝 Loading SONAR text embedder...")
            text_embedder = TextToEmbeddingModelPipeline(
                encoder="text_sonar_basic_encoder",
                tokenizer="text_sonar_basic_encoder",
            )
            # Load BLASER model and switch it to eval mode for scoring.
            print("🎯 Loading BLASER 2.0-QE model...")
            blaser_model = load_blaser_model("blaser_2_0_qe").eval()
        except Exception as e:  # script boundary: report and signal failure
            print(f"❌ Initialization failed: {e}")
            print("Try setting FAIRSEQ2_EXTENSION_TRACE=1 for more details")
            # Do not keep a partially loaded pipeline around.
            self.text_embedder = None
            self.blaser_model = None
            self.initialized = False
            return False

        # Publish both models together only once both loaded.
        self.text_embedder = text_embedder
        self.blaser_model = blaser_model
        self.initialized = True
        print("✅ BLASER 2.0-QE initialized successfully!")
        return True

    def evaluate(self, source_text: str, translation_text: str,
                 source_lang: str = "fra_Latn", target_lang: str = "eng_Latn") -> float:
        """Evaluate translation quality using BLASER 2.0-QE.

        The QE variant is reference-free: only the source and the machine
        translation are embedded and scored.

        Args:
            source_text: Source text.
            translation_text: Machine translation of ``source_text``.
            source_lang: Language code of the source, e.g. "fra_Latn"
                (default: "fra_Latn").
            target_lang: Language code of the translation, e.g. "eng_Latn"
                (default: "eng_Latn").

        Returns:
            The BLASER score as a Python number (higher is better).

        Raises:
            RuntimeError: If :meth:`initialize` has not completed successfully.
        """
        if not self.initialized:
            raise RuntimeError("BLASER not initialized. Call initialize() first.")

        print("\n📊 Evaluating translation:")
        print(f" Source ({source_lang}): {source_text}")
        print(f" Translation ({target_lang}): {translation_text}")

        # Generate embeddings using SONAR; each text is embedded with its own
        # language code.
        print("🔄 Generating embeddings...")
        src_embs = self.text_embedder.predict([source_text], source_lang=source_lang)
        mt_embs = self.text_embedder.predict([translation_text], source_lang=target_lang)

        # Get BLASER score (single pair, so .item() yields a scalar).
        print("🔄 Computing BLASER score...")
        with torch.inference_mode():
            score = self.blaser_model(src=src_embs, mt=mt_embs).item()

        print(f"✨ BLASER score: {score:.3f}")
        return score
# Example source/translation pairs scored by main(); each language pair is
# evaluated in both directions.
EXAMPLE_CASES = (
    # French-English pair
    {
        "source": "Le chat s'assit sur le tapis.",
        "translation": "The cat sat down on the carpet.",
        "source_lang": "fra_Latn",
        "target_lang": "eng_Latn",
        "name": "French → English",
    },
    {
        "source": "The cat sat down on the carpet.",
        "translation": "Le chat s'assit sur le tapis.",
        "source_lang": "eng_Latn",
        "target_lang": "fra_Latn",
        "name": "English → French",
    },
    # English-English pair
    {
        "source": "The dog is running.",
        "translation": "The dog runs.",
        "source_lang": "eng_Latn",
        "target_lang": "eng_Latn",
        "name": "English → English (present continuous → simple)",
    },
    {
        "source": "The dog runs.",
        "translation": "The dog is running.",
        "source_lang": "eng_Latn",
        "target_lang": "eng_Latn",
        "name": "English → English (simple → present continuous)",
    },
    # Spanish-English pair
    {
        "source": "El gato está sentado en la alfombra.",
        "translation": "The cat is sitting on the carpet.",
        "source_lang": "spa_Latn",
        "target_lang": "eng_Latn",
        "name": "Spanish → English",
    },
    {
        "source": "The cat is sitting on the carpet.",
        "translation": "El gato está sentado en la alfombra.",
        "source_lang": "eng_Latn",
        "target_lang": "spa_Latn",
        "name": "English → Spanish",
    },
)


def main() -> int:
    """Example usage: score every pair in EXAMPLE_CASES and print the results.

    Returns:
        Process exit status: 0 on success, 1 if the models failed to load.
    """
    evaluator = BLASEREvaluator()
    if not evaluator.initialize():
        print("Failed to initialize BLASER")
        return 1

    print("\n=== Running BLASER evaluations in both directions ===\n")
    for case in EXAMPLE_CASES:
        print(f"\n🔄 Testing: {case['name']}")
        score = evaluator.evaluate(
            case["source"],
            case["translation"],
            source_lang=case["source_lang"],
            target_lang=case["target_lang"],
        )
        print(f"📈 Final score: {score:.3f}")
        print(" " + "=" * 50)
    return 0


if __name__ == "__main__":
    # Propagate main()'s status so a failed model load exits non-zero.
    raise SystemExit(main())