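# NOTE(review): the three lines below look like page-header residue from a web
# view of this file (uploader avatar caption, commit message, commit id) rather
# than program text; they are kept as comments so the module can be parsed.
# vukosi's picture
# Update app.py
# 65e0f55 verified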
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import pandas as pd
import time
import re
from datetime import datetime
import json
import tempfile
import os
import uuid
# Process-wide store for the loaded translation pipelines, keyed by
# 'en_ss_pipeline' and 'ss_en_pipeline'. It starts empty and is filled by
# load_translation_models().
_model_cache: dict = {}
def load_translation_models():
    """Build both translation pipelines, reusing cached ones when available.

    Returns:
        A ``(en_ss_pipeline, ss_en_pipeline)`` tuple. When anything goes wrong
        while loading, the error is printed and ``(None, None)`` is returned;
        nothing is cached in that case.
    """
    cache_keys = ('en_ss_pipeline', 'ss_en_pipeline')

    # Fast path: both pipelines were built by an earlier call.
    if all(key in _model_cache for key in cache_keys):
        return _model_cache['en_ss_pipeline'], _model_cache['ss_en_pipeline']

    # (cache key, progress message, model repository, source lang, target lang)
    model_specs = (
        ('en_ss_pipeline', "Loading English to Siswati model...",
         "dsfsi/en-ss-m2m100-combo", "en", "ss"),
        ('ss_en_pipeline', "Loading Siswati to English model...",
         "dsfsi/ss-en-m2m100-combo", "ss", "en"),
    )

    try:
        print("Loading translation models...")
        loaded = {}
        for cache_key, progress_message, repo_id, source_code, target_code in model_specs:
            print(progress_message)
            tokenizer = AutoTokenizer.from_pretrained(repo_id)
            model = AutoModelForSeq2SeqLM.from_pretrained(repo_id)
            # The pipeline is told the language pair explicitly.
            loaded[cache_key] = pipeline(
                "translation",
                model=model,
                tokenizer=tokenizer,
                src_lang=source_code,
                tgt_lang=target_code
            )

        # Publish to the cache only once both pipelines exist.
        _model_cache.update(loaded)
        print("Models loaded successfully!")
        return loaded['en_ss_pipeline'], loaded['ss_en_pipeline']
    except Exception as e:
        print(f"Error loading models: {e}")
        return None, None
def get_translators():
    """Return ``(en_ss_pipeline, ss_en_pipeline)``, loading them on first use.

    If either pipeline is missing from the cache, the result of
    ``load_translation_models()`` is returned instead (which may be
    ``(None, None)`` when loading fails).
    """
    required_keys = ('en_ss_pipeline', 'ss_en_pipeline')
    if any(key not in _model_cache for key in required_keys):
        return load_translation_models()
    return _model_cache['en_ss_pipeline'], _model_cache['ss_en_pipeline']
def _translate_directly(translator, text, src_lang, tgt_lang):
    """Translate *text* with the pipeline's own model, bypassing the pipeline call.

    The tokenizer and model already held by *translator* are reused, so the
    fallback does not load a second copy of the weights.
    """
    tokenizer = translator.tokenizer
    model = translator.model

    # Set language tokens
    tokenizer.src_lang = src_lang
    encoded = tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
    # The inputs have to sit on the same device as the model weights.
    encoded = encoded.to(model.device)

    # Force target language token
    forced_bos_token_id = tokenizer.get_lang_id(tgt_lang)
    with torch.no_grad():
        generated_tokens = model.generate(
            **encoded,
            forced_bos_token_id=forced_bos_token_id,
            max_length=512
        )
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]


def translate_with_fallback(text, direction):
    """Translate *text*, falling back to direct generation if the pipeline fails.

    Args:
        text: The text to translate.
        direction: The direction label chosen in the UI. The English-to-Siswati
            label selects that model; every other value is treated as
            Siswati-to-English.

    Returns:
        The translated text.

    Raises:
        Exception: With the message ``"Translation failed: <reason>"`` when the
            required model is not loaded or both translation attempts fail.
            The underlying error is attached as ``__cause__``.
    """
    try:
        en_ss_translator, ss_en_translator = get_translators()

        if direction == "English β†’ Siswati":
            translator, src_lang, tgt_lang = en_ss_translator, "en", "ss"
            not_loaded_message = "English to Siswati model not loaded"
        else:  # Siswati to English
            translator, src_lang, tgt_lang = ss_en_translator, "ss", "en"
            not_loaded_message = "Siswati to English model not loaded"

        if translator is None:
            raise Exception(not_loaded_message)

        # Try with pipeline first
        try:
            result = translator(text, max_length=512)
            return result[0]['translation_text']
        except Exception as pipeline_error:
            print(f"Pipeline failed, trying direct model approach: {pipeline_error}")
            return _translate_directly(translator, text, src_lang, tgt_lang)
    except Exception as e:
        raise Exception(f"Translation failed: {str(e)}") from e
# Letters counted as click consonants (c, q, x).
_CLICK_LETTERS = re.compile(r'[cqx]')

# The vowels a, e, i, o, u carrying an acute or grave accent. They are written
# as escapes so the pattern cannot be damaged by a file-encoding round trip.
_TONE_MARKED_VOWELS = re.compile(
    '[\u00e1\u00e0\u00e9\u00e8\u00ed\u00ec\u00f3\u00f2\u00fa\u00f9]'
)


def analyze_siswati_features(text):
    """Count surface features of *text* that are of interest for Siswati.

    Args:
        text: The text to inspect.

    Returns:
        A dict with:
            ``click_consonants``: number of c/q/x letters (case-insensitive).
            ``tone_markers``: number of vowels with an acute or grave accent
                (case-insensitive).
            ``potential_agglutination``: number of whitespace-separated words
                longer than 10 characters.
            ``long_words``: up to the first five of those long words.
    """
    lowered = text.lower()
    long_words = [word for word in text.split() if len(word) > 10]

    return {
        'click_consonants': len(_CLICK_LETTERS.findall(lowered)),
        'tone_markers': len(_TONE_MARKED_VOWELS.findall(lowered)),
        'potential_agglutination': len(long_words),
        'long_words': long_words[:5],  # Show first 5 examples
    }
def calculate_linguistic_metrics(text):
    """Compute size and vocabulary statistics for *text*.

    Words are whitespace-separated tokens; sentences are the non-blank pieces
    left after splitting on runs of ``.``, ``!`` and ``?``.

    Returns:
        An empty dict for blank input, otherwise a dict with ``char_count``,
        ``word_count``, ``sentence_count``, ``lexical_diversity`` (unique
        words / words), ``avg_word_length`` and ``unique_words``.
    """
    if not text.strip():
        return {}

    tokens = text.split()
    total_words = len(tokens)
    vocabulary = set(tokens)
    sentences = [piece for piece in re.split(r'[.!?]+', text) if piece.strip()]

    if total_words > 0:
        diversity = len(vocabulary) / total_words
        mean_length = sum(len(token) for token in tokens) / total_words
    else:
        diversity = 0
        mean_length = 0

    metrics = {}
    metrics['char_count'] = len(text)
    metrics['word_count'] = total_words
    metrics['sentence_count'] = len(sentences)
    metrics['lexical_diversity'] = diversity
    metrics['avg_word_length'] = mean_length
    metrics['unique_words'] = len(vocabulary)
    return metrics
def create_empty_metrics_table():
    """Return the all-zero metrics table used when there is nothing to report."""
    metric_names = ['Words', 'Characters', 'Sentences', 'Unique Words', 'Avg Word Length', 'Lexical Diversity']
    # Counts are integers; the two averages are pre-formatted strings.
    zero_values = [0, 0, 0, 0, '0.0', '0.000']

    table = {'Metric': metric_names}
    for column_name in ('Source Text', 'Target Text'):
        table[column_name] = list(zero_values)
    return pd.DataFrame(table)
def translate_text(text, direction):
    """Translate *text* and produce the accompanying analysis.

    Args:
        text: The text entered by the user.
        direction: The translation direction label chosen in the UI.

    Returns:
        A ``(translation, analysis_report, metrics_table)`` tuple. For blank
        input, or when anything fails, the first two items are messages and
        the table is the empty metrics table.
    """
    if not text.strip():
        return "Please enter text to translate.", "No analysis available.", create_empty_metrics_table()

    started_at = time.time()
    try:
        # Perform translation using the fallback method
        translation = translate_with_fallback(text, direction)

        input_metrics = calculate_linguistic_metrics(text)
        output_metrics = calculate_linguistic_metrics(translation)

        # The Siswati side is the output when translating from English,
        # otherwise it is the input.
        siswati_side = translation if direction == "English β†’ Siswati" else text
        features = analyze_siswati_features(siswati_side)

        elapsed = time.time() - started_at

        report = create_analysis_report(
            input_metrics, output_metrics, features,
            elapsed, direction
        )
        table = create_metrics_table(input_metrics, output_metrics, elapsed)
    except Exception as e:
        return f"Translation error: {str(e)}", f"Analysis failed: {str(e)}", create_empty_metrics_table()

    return translation, report, table
def create_analysis_report(source_metrics, target_metrics, siswati_features, processing_time, direction):
    """Render the Markdown analysis report for one translation.

    Args:
        source_metrics: Metrics dict for the input text; missing keys count as 0.
        target_metrics: Metrics dict for the translated text; missing keys count as 0.
        siswati_features: Feature dict with optional ``click_consonants``,
            ``tone_markers``, ``potential_agglutination`` and ``long_words``.
        processing_time: Elapsed time in seconds.
        direction: The translation direction label, shown as-is.

    Returns:
        The report as a Markdown string.
    """
    def cells(key, spec=""):
        # Source and target value of one metric, as two adjacent table cells.
        return f"{source_metrics.get(key, 0):{spec}} | {target_metrics.get(key, 0):{spec}}"

    def ratio(key, floor=1):
        # Target/source ratio. The denominator is clamped to `floor` so a
        # missing or zero source value cannot cause a division by zero.
        return target_metrics.get(key, 0) / max(source_metrics.get(key, floor), floor)

    lines = [
        "",
        "## 📊 Linguistic Analysis Report",
        "### Translation Details",
        f"- **Direction**: {direction}",
        f"- **Processing Time**: {processing_time:.2f} seconds",
        "### Text Complexity Metrics",
        "| Metric | Source | Target | Ratio |",
        "|--------|--------|--------|-------|",
        f"| Word Count | {cells('word_count')} | {ratio('word_count'):.2f} |",
        f"| Character Count | {cells('char_count')} | {ratio('char_count'):.2f} |",
        f"| Sentence Count | {cells('sentence_count')} | {ratio('sentence_count'):.2f} |",
        f"| Avg Word Length | {cells('avg_word_length', '.1f')} | {ratio('avg_word_length'):.2f} |",
        f"| Lexical Diversity | {cells('lexical_diversity', '.3f')} | {ratio('lexical_diversity', 0.001):.2f} |",
        "### Siswati-Specific Features",
        f"- **Click Consonants**: {siswati_features.get('click_consonants', 0)} detected",
        f"- **Tone Markers**: {siswati_features.get('tone_markers', 0)} detected",
        f"- **Potential Agglutination**: {siswati_features.get('potential_agglutination', 0)} words longer than 10 characters",
        "",
    ]
    report = "\n".join(lines)

    if siswati_features.get('long_words'):
        report += f"- **Long Word Examples**: {', '.join(siswati_features['long_words'])}\n"
    return report
def create_metrics_table(source_metrics, target_metrics, processing_time):
    """Build the side-by-side metrics table for source and target text.

    Args:
        source_metrics: Metrics dict for the input text; missing keys count as 0.
        target_metrics: Metrics dict for the translated text; missing keys count as 0.
        processing_time: Accepted but not used by the table.

    Returns:
        A DataFrame with the columns ``Metric``, ``Source Text`` and
        ``Target Text``.
    """
    def as_column(metrics):
        # Counts stay numeric; the two averages are rendered as fixed-point text.
        counts = [metrics.get(key, 0) for key in ('word_count', 'char_count', 'sentence_count', 'unique_words')]
        averages = [
            f"{metrics.get('avg_word_length', 0):.1f}",
            f"{metrics.get('lexical_diversity', 0):.3f}",
        ]
        return counts + averages

    row_labels = ['Words', 'Characters', 'Sentences', 'Unique Words', 'Avg Word Length', 'Lexical Diversity']
    return pd.DataFrame({
        'Metric': row_labels,
        'Source Text': as_column(source_metrics),
        'Target Text': as_column(target_metrics),
    })
# Upload limits. The byte cap mirrors the "Max 5MB" wording on the upload widget.
_MAX_UPLOAD_BYTES = 5 * 1024 * 1024
_MAX_BATCH_ENTRIES = 10
_MAX_ENTRY_CHARS = 1000
_PREVIEW_CHARS = 100


def _shorten_for_display(value, limit=_PREVIEW_CHARS):
    """Cut *value* to *limit* characters, marking the cut with '...'."""
    return value[:limit] + '...' if len(value) > limit else value


def secure_file_processing(file_obj, direction):
    """Translate the entries of an uploaded ``.txt`` or ``.csv`` file.

    The upload is copied into a private temporary directory, read from there,
    and the directory is removed again before returning.

    Args:
        file_obj: The upload as handed over by the file widget: either a path
            (``str`` / ``os.PathLike``) or an object exposing the path as
            ``.name``. ``None`` means nothing was uploaded.
        direction: The translation direction label chosen in the UI.

    Returns:
        A ``(summary, results)`` tuple. ``summary`` is a status or error
        message; ``results`` is a DataFrame with the columns ``Index``,
        ``Original`` and ``Translation`` (empty on error). At most 10 entries
        are translated, each cut to 1000 characters, and the table shows at
        most 100 characters per cell.
    """
    import shutil  # function-scope import: the module header does not import shutil

    if file_obj is None:
        return "Please upload a file.", pd.DataFrame()

    # Create a unique temporary directory name for this processing session
    session_id = str(uuid.uuid4())
    temp_dir = None

    try:
        # The upload widget is configured with type="filepath", which delivers
        # a plain path string; objects carrying the path in `.name` are still
        # accepted.
        if isinstance(file_obj, (str, os.PathLike)):
            source_path = os.fspath(file_obj)
        else:
            source_path = file_obj.name

        # Get file extension and validate
        file_ext = os.path.splitext(source_path)[1].lower()
        if file_ext not in ('.txt', '.csv'):
            return "Only .txt and .csv files are supported.", pd.DataFrame()

        # Refuse oversized uploads before copying or parsing anything.
        if os.path.getsize(source_path) > _MAX_UPLOAD_BYTES:
            return "The uploaded file exceeds the 5MB size limit.", pd.DataFrame()

        # Copy uploaded file to a private location
        temp_dir = tempfile.mkdtemp(prefix=f"translation_{session_id}_")
        temp_file_path = os.path.join(temp_dir, f"upload_{session_id}{file_ext}")
        shutil.copy2(source_path, temp_file_path)

        # Process file based on type
        if file_ext == '.csv':
            try:
                df = pd.read_csv(temp_file_path)
                if df.empty:
                    return "The uploaded CSV file is empty.", pd.DataFrame()
                # Assume first column contains text to translate
                texts = df.iloc[:, 0].dropna().astype(str).tolist()
            except Exception as e:
                return f"Error reading CSV file: {str(e)}", pd.DataFrame()
        else:  # .txt file
            try:
                # utf-8-sig also reads plain UTF-8 and drops a leading byte
                # order mark, which would otherwise end up in the first entry.
                with open(temp_file_path, 'r', encoding='utf-8-sig') as f:
                    content = f.read()
                texts = [line.strip() for line in content.split('\n') if line.strip()]
            except Exception as e:
                return f"Error reading text file: {str(e)}", pd.DataFrame()

        if not texts:
            return "No text found in the uploaded file.", pd.DataFrame()

        # Limit batch size for performance and security
        warning_msg = ""
        if len(texts) > _MAX_BATCH_ENTRIES:
            texts = texts[:_MAX_BATCH_ENTRIES]
            warning_msg = f"Processing limited to first {_MAX_BATCH_ENTRIES} entries for security and performance reasons."

        # Process translations
        results = []
        for index, text in enumerate(texts, start=1):
            if not text.strip():
                continue

            # Limit individual text length for security
            if len(text) > _MAX_ENTRY_CHARS:
                text = text[:_MAX_ENTRY_CHARS] + "..."

            # A failure on one entry is reported in its row instead of
            # aborting the whole batch.
            try:
                translated = translate_with_fallback(text, direction)
            except Exception as e:
                translated = f"Translation error: {str(e)}"

            results.append({
                'Index': index,
                'Original': _shorten_for_display(text),
                'Translation': _shorten_for_display(translated)
            })

        if not results:
            return "No valid text entries found to translate.", pd.DataFrame()

        summary = f"Successfully processed {len(results)} text entries."
        if warning_msg:
            summary = f"{summary} {warning_msg}"
        return summary, pd.DataFrame(results)

    except Exception as e:
        return f"Error processing file: {str(e)}", pd.DataFrame()

    finally:
        # Clean up temporary files and directory
        if temp_dir and os.path.exists(temp_dir):
            try:
                shutil.rmtree(temp_dir)
            except Exception as e:
                print(f"Warning: Could not clean up temporary directory: {e}")
# Clickable examples for the translation tab. Each row is
# [direction label, input text]; the labels are the same strings the direction
# dropdown offers.
_EXAMPLE_SENTENCES = (
    ("English β†’ Siswati", (
        "Hello, how are you today?",
        "The weather is beautiful this morning.",
        "I am learning Siswati language.",
        "Thank you for your help.",
    )),
    ("Siswati β†’ English", (
        "Sawubona, unjani namuhla?",
        "Siyabonga ngekusita kwakho.",
        "Lolu luhle kakhulu.",
        "Ngiyakuthanda.",
    )),
)

TRANSLATION_EXAMPLES = [
    [direction_label, sentence]
    for direction_label, sentences in _EXAMPLE_SENTENCES
    for sentence in sentences
]
def create_gradio_interface():
    """Assemble the Gradio Blocks UI for the translation tool.

    The layout is a header, three tabs (single-text translation with analysis,
    batch file translation, standalone text analysis), a language background
    section and a footer with model, privacy and citation information.

    Returns:
        The constructed ``gr.Blocks`` app. It is not launched here.
    """
    # These values are handed to translate_text / secure_file_processing and
    # compared verbatim by the translation code, so they must stay spelled
    # exactly like this.
    direction_choices = ["English β†’ Siswati", "Siswati β†’ English"]

    with gr.Blocks(
        title="🔬 Siswati-English Linguistic Translation Tool",
        theme=gr.themes.Soft(),
        css="""
.gradio-container {font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;}
.main-header {text-align: center; padding: 2rem 0;}
.dsfsi-logo {text-align: center; margin-bottom: 1rem;}
.dsfsi-logo img {max-width: 300px; height: auto;}
.metric-table {font-size: 0.9em;}
.feature-highlight {background: linear-gradient(90deg, #667eea 0%, #764ba2 100%); color: white; padding: 1rem; border-radius: 10px; margin: 1rem 0;}
"""
    ) as demo:
        # Header Section with DSFSI Logo
        gr.HTML("""
<div class="dsfsi-logo">
<img src="https://www.dsfsi.co.za/images/logo_transparent_expanded.png" alt="DSFSI Logo" />
</div>
<div class="main-header">
<h1>🔬 Siswati-English Linguistic Translation Tool</h1>
<p style="font-size: 1.1em; color: #666; max-width: 800px; margin: 0 auto;">
Advanced AI-powered translation system with comprehensive linguistic analysis features,
designed specifically for linguists, researchers, and language documentation projects.
</p>
</div>
""")

        # Main Content Tabs
        with gr.Tabs():
            # Single Translation Tab
            with gr.Tab("🌐 Translation & Analysis"):
                gr.Markdown("""
### Real-time Translation with Linguistic Analysis
Translate between English and Siswati while getting detailed linguistic insights including morphological complexity, lexical diversity, and Siswati-specific features.
""")
                with gr.Row():
                    with gr.Column(scale=1):
                        direction = gr.Dropdown(
                            choices=direction_choices,
                            label="Translation Direction",
                            value=direction_choices[0]
                        )
                        input_text = gr.Textbox(
                            label="Input Text",
                            placeholder="Enter text to translate...",
                            lines=4,
                            max_lines=10
                        )
                        translate_btn = gr.Button("🔄 Translate & Analyze", variant="primary", size="lg")
                    with gr.Column(scale=1):
                        output_text = gr.Textbox(
                            label="Translation",
                            lines=4,
                            interactive=False
                        )

                # Examples Section
                gr.Markdown("### 📚 Example Translations")
                gr.Examples(
                    examples=TRANSLATION_EXAMPLES,
                    inputs=[direction, input_text],
                    label="Click an example to try it:"
                )

                # Analysis Results
                with gr.Accordion("📊 Detailed Linguistic Analysis", open=False):
                    analysis_output = gr.Markdown(label="Analysis Report")
                with gr.Accordion("📈 Metrics Table", open=False):
                    metrics_table = gr.Dataframe(
                        label="Comparative Metrics",
                        headers=["Metric", "Source Text", "Target Text"],
                        interactive=False
                    )

                # Connect translation function
                translate_btn.click(
                    fn=translate_text,
                    inputs=[input_text, direction],
                    outputs=[output_text, analysis_output, metrics_table]
                )

            # Batch Processing Tab
            with gr.Tab("📁 Batch Processing"):
                gr.Markdown("""
### Secure Corpus Analysis & Batch Translation
Upload text files or CSV files for batch translation and corpus analysis. Files are processed securely and temporarily.
**Security Features:**
- Files are processed in isolated temporary directories
- No file persistence or history
- Automatic cleanup after processing
- Limited to first 10 entries for performance
""")
                with gr.Row():
                    with gr.Column():
                        batch_direction = gr.Dropdown(
                            choices=direction_choices,
                            label="Translation Direction",
                            value=direction_choices[0]
                        )
                        file_upload = gr.File(
                            label="Upload File (Max 5MB)",
                            file_types=[".txt", ".csv"],
                            type="filepath",
                            file_count="single"
                        )
                        batch_btn = gr.Button("🔄 Process Batch", variant="primary")
                        gr.Markdown("""
**Supported formats:**
- `.txt` files: One text per line
- `.csv` files: Text in first column
- **Security limits**: Max 10 entries, 1000 chars per text
- **Privacy**: Files are automatically deleted after processing
""")
                    with gr.Column():
                        batch_summary = gr.Textbox(
                            label="Processing Summary",
                            lines=3,
                            interactive=False
                        )
                        batch_results = gr.Dataframe(
                            label="Translation Results",
                            interactive=False,
                            wrap=True
                        )

                batch_btn.click(
                    fn=secure_file_processing,
                    inputs=[file_upload, batch_direction],
                    outputs=[batch_summary, batch_results]
                )

            # Research Tools Tab
            with gr.Tab("🔬 Research Tools"):
                gr.Markdown("""
### Advanced Linguistic Analysis Tools
Explore detailed linguistic features without data persistence.
""")
                with gr.Row():
                    with gr.Column():
                        research_text = gr.Textbox(
                            label="Text for Analysis",
                            lines=6,
                            placeholder="Enter Siswati or English text for detailed analysis...",
                            max_lines=15
                        )
                        analyze_btn = gr.Button("🔍 Analyze Text", variant="primary")
                    with gr.Column():
                        research_output = gr.JSON(
                            label="Detailed Analysis Results"
                        )

                def detailed_analysis(text):
                    """Return metrics and Siswati features for *text* as a JSON-ready dict.

                    Blank input yields an empty dict. Text longer than 2000
                    characters is cut to 2000 and suffixed with '...' before
                    analysis, and ``text_length`` reports the length of the
                    text that was actually analysed.
                    """
                    if not text.strip():
                        return {}

                    # Limit text length for security
                    if len(text) > 2000:
                        text = text[:2000] + "..."

                    metrics = calculate_linguistic_metrics(text)
                    siswati_features = analyze_siswati_features(text)

                    # Only derived statistics are returned, not the text itself.
                    return {
                        "basic_metrics": metrics,
                        "siswati_features": siswati_features,
                        "text_length": len(text),
                        "analysis_completed": True
                    }

                analyze_btn.click(
                    fn=detailed_analysis,
                    inputs=research_text,
                    outputs=research_output
                )

        # Language Information
        gr.Markdown("""
### 🗣️ About Siswati Language
**Siswati** (also known as **Swati** or **Swazi**) is a Bantu language spoken by approximately 2.3 million people, primarily in:
- 🇸🇿 **Eswatini** (Kingdom of Eswatini) - Official language
- 🇿🇦 **South Africa** - One of 11 official languages
**Key Linguistic Features:**
- **Language Family**: Niger-Congo → Bantu → Southeast Bantu
- **Script**: Latin alphabet
- **Characteristics**: Agglutinative morphology, click consonants, tonal
- **ISO Code**: ss (ISO 639-1), ssw (ISO 639-3)
""")

        # Footer Section
        gr.Markdown("""
---
### 📚 Model Information & Citation
**Models Used:**
- **English → Siswati**: [`dsfsi/en-ss-m2m100-combo`](https://huggingface.co/dsfsi/en-ss-m2m100-combo)
- **Siswati → English**: [`dsfsi/ss-en-m2m100-combo`](https://huggingface.co/dsfsi/ss-en-m2m100-combo)
Both models are based on Meta's M2M100 architecture, fine-tuned specifically for Siswati-English translation pairs by the **Data Science for Social Impact Research Group**.
**Training Data**: Models trained on the Vuk'uzenzele and ZA-gov-multilingual South African corpora.
### 🔒 Privacy & Security
- No conversation history is stored
- Uploaded files are automatically deleted after processing
- All processing happens in isolated temporary environments
- No user data persistence
### 🙏 Acknowledgments
We thank **Thapelo Sindanie** and **Unarine Netshifhefhe** for their contributions to this work.
### 📖 Citation
```bibtex
@inproceedings{lastrucci2023preparing,
title={Preparing the Vuk'uzenzele and ZA-gov-multilingual South African multilingual corpora},
author={Lastrucci, Richard and Rajab, Jenalea and Shingange, Matimba and Njini, Daniel and Marivate, Vukosi},
booktitle={Proceedings of the Fourth workshop on Resources for African Indigenous Languages (RAIL 2023)},
pages={18--25},
year={2023}
}
```
**Links**:
- [DSFSI](https://www.dsfsi.co.za/)
- [En→Ss Model](https://huggingface.co/dsfsi/en-ss-m2m100-combo) | [Ss→En Model](https://huggingface.co/dsfsi/ss-en-m2m100-combo)
- [Vuk'uzenzele Data](https://github.com/dsfsi/vukuzenzele-nlp) | [ZA-gov Data](https://github.com/dsfsi/gov-za-multilingual)
- [Research Feedback](https://docs.google.com/forms/d/e/1FAIpQLSf7S36dyAUPx2egmXbFpnTBuzoRulhL5Elu-N1eoMhaO7v10w/viewform)
---
**Built with ❤️ for the African NLP community**
""")

    return demo
# Script entry point: build the interface and start serving it.
if __name__ == "__main__":
    # Settings handed unchanged to demo.launch().
    launch_options = {
        "share": True,
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
    }
    demo = create_gradio_interface()
    demo.launch(**launch_options)