Updated now with langauge codes
Browse files
app.py
CHANGED
@@ -28,13 +28,27 @@ def load_translation_models():
|
|
28 |
print("Loading English to Siswati model...")
|
29 |
en_ss_tokenizer = AutoTokenizer.from_pretrained("dsfsi/en-ss-m2m100-combo")
|
30 |
en_ss_model = AutoModelForSeq2SeqLM.from_pretrained("dsfsi/en-ss-m2m100-combo")
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
# Siswati to English
|
34 |
print("Loading Siswati to English model...")
|
35 |
ss_en_tokenizer = AutoTokenizer.from_pretrained("dsfsi/ss-en-m2m100-combo")
|
36 |
ss_en_model = AutoModelForSeq2SeqLM.from_pretrained("dsfsi/ss-en-m2m100-combo")
|
37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
# Cache the models
|
40 |
_model_cache['en_ss_pipeline'] = en_ss_pipeline
|
@@ -56,6 +70,77 @@ def get_translators():
|
|
56 |
|
57 |
return _model_cache['en_ss_pipeline'], _model_cache['ss_en_pipeline']
|
58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
def analyze_siswati_features(text):
|
60 |
"""Analyze Siswati-specific linguistic features"""
|
61 |
features = {}
|
@@ -117,32 +202,17 @@ def translate_text(text, direction):
|
|
117 |
start_time = time.time()
|
118 |
|
119 |
try:
|
120 |
-
#
|
121 |
-
|
|
|
|
|
|
|
|
|
122 |
|
123 |
-
#
|
124 |
if direction == "English → Siswati":
|
125 |
-
if en_ss_translator is None:
|
126 |
-
return "Translation model not loaded. Please try again.", "Model loading failed.", create_empty_metrics_table()
|
127 |
-
|
128 |
-
result = en_ss_translator(text, max_length=512)
|
129 |
-
translated_text = result[0]['translation_text']
|
130 |
-
|
131 |
-
# Analyze source (English) and target (Siswati)
|
132 |
-
source_metrics = calculate_linguistic_metrics(text)
|
133 |
-
target_metrics = calculate_linguistic_metrics(translated_text)
|
134 |
siswati_features = analyze_siswati_features(translated_text)
|
135 |
-
|
136 |
-
else: # Siswati → English
|
137 |
-
if ss_en_translator is None:
|
138 |
-
return "Translation model not loaded. Please try again.", "Model loading failed.", create_empty_metrics_table()
|
139 |
-
|
140 |
-
result = ss_en_translator(text, max_length=512)
|
141 |
-
translated_text = result[0]['translation_text']
|
142 |
-
|
143 |
-
# Analyze source (Siswati) and target (English)
|
144 |
-
source_metrics = calculate_linguistic_metrics(text)
|
145 |
-
target_metrics = calculate_linguistic_metrics(translated_text)
|
146 |
siswati_features = analyze_siswati_features(text)
|
147 |
|
148 |
processing_time = time.time() - start_time
|
@@ -279,23 +349,9 @@ def secure_file_processing(file_obj, direction):
|
|
279 |
if len(text) > 1000:
|
280 |
text = text[:1000] + "..."
|
281 |
|
282 |
-
#
|
283 |
-
en_ss_translator, ss_en_translator = get_translators()
|
284 |
-
|
285 |
-
# Perform translation based on direction
|
286 |
try:
|
287 |
-
|
288 |
-
if en_ss_translator is None:
|
289 |
-
translated = "Model not available"
|
290 |
-
else:
|
291 |
-
result = en_ss_translator(text, max_length=512)
|
292 |
-
translated = result[0]['translation_text']
|
293 |
-
else: # Siswati → English
|
294 |
-
if ss_en_translator is None:
|
295 |
-
translated = "Model not available"
|
296 |
-
else:
|
297 |
-
result = ss_en_translator(text, max_length=512)
|
298 |
-
translated = result[0]['translation_text']
|
299 |
except Exception as e:
|
300 |
translated = f"Translation error: {str(e)}"
|
301 |
|
@@ -356,7 +412,6 @@ def create_gradio_interface():
|
|
356 |
# Header Section
|
357 |
gr.HTML("""
|
358 |
<div class="main-header">
|
359 |
-
<img src="https://www.dsfsi.co.za/images/logo_transparent_expanded.png" width="400" alt="DSFSI Logo" style="margin-bottom: 1rem;">
|
360 |
<h1>🔬 Siswati-English Linguistic Translation Tool</h1>
|
361 |
<p style="font-size: 1.1em; color: #666; max-width: 800px; margin: 0 auto;">
|
362 |
Advanced AI-powered translation system with comprehensive linguistic analysis features,
|
|
|
28 |
print("Loading English to Siswati model...")
|
29 |
en_ss_tokenizer = AutoTokenizer.from_pretrained("dsfsi/en-ss-m2m100-combo")
|
30 |
en_ss_model = AutoModelForSeq2SeqLM.from_pretrained("dsfsi/en-ss-m2m100-combo")
|
31 |
+
# Fix: Add src_lang and tgt_lang parameters
|
32 |
+
en_ss_pipeline = pipeline(
|
33 |
+
"translation",
|
34 |
+
model=en_ss_model,
|
35 |
+
tokenizer=en_ss_tokenizer,
|
36 |
+
src_lang="en",
|
37 |
+
tgt_lang="ss"
|
38 |
+
)
|
39 |
|
40 |
# Siswati to English
|
41 |
print("Loading Siswati to English model...")
|
42 |
ss_en_tokenizer = AutoTokenizer.from_pretrained("dsfsi/ss-en-m2m100-combo")
|
43 |
ss_en_model = AutoModelForSeq2SeqLM.from_pretrained("dsfsi/ss-en-m2m100-combo")
|
44 |
+
# Fix: Add src_lang and tgt_lang parameters
|
45 |
+
ss_en_pipeline = pipeline(
|
46 |
+
"translation",
|
47 |
+
model=ss_en_model,
|
48 |
+
tokenizer=ss_en_tokenizer,
|
49 |
+
src_lang="ss",
|
50 |
+
tgt_lang="en"
|
51 |
+
)
|
52 |
|
53 |
# Cache the models
|
54 |
_model_cache['en_ss_pipeline'] = en_ss_pipeline
|
|
|
70 |
|
71 |
return _model_cache['en_ss_pipeline'], _model_cache['ss_en_pipeline']
|
72 |
|
73 |
+
def translate_with_fallback(text, direction):
|
74 |
+
"""Translation function with fallback method if pipeline fails"""
|
75 |
+
try:
|
76 |
+
# Get translators
|
77 |
+
en_ss_translator, ss_en_translator = get_translators()
|
78 |
+
|
79 |
+
if direction == "English → Siswati":
|
80 |
+
if en_ss_translator is None:
|
81 |
+
raise Exception("English to Siswati model not loaded")
|
82 |
+
|
83 |
+
# Try with pipeline first
|
84 |
+
try:
|
85 |
+
result = en_ss_translator(text, max_length=512)
|
86 |
+
return result[0]['translation_text']
|
87 |
+
except Exception as pipeline_error:
|
88 |
+
print(f"Pipeline failed, trying direct model approach: {pipeline_error}")
|
89 |
+
|
90 |
+
# Fallback: Use model directly
|
91 |
+
tokenizer = AutoTokenizer.from_pretrained("dsfsi/en-ss-m2m100-combo")
|
92 |
+
model = AutoModelForSeq2SeqLM.from_pretrained("dsfsi/en-ss-m2m100-combo")
|
93 |
+
|
94 |
+
# Set language tokens
|
95 |
+
tokenizer.src_lang = "en"
|
96 |
+
encoded = tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
|
97 |
+
|
98 |
+
# Force target language token
|
99 |
+
forced_bos_token_id = tokenizer.get_lang_id("ss")
|
100 |
+
|
101 |
+
with torch.no_grad():
|
102 |
+
generated_tokens = model.generate(
|
103 |
+
**encoded,
|
104 |
+
forced_bos_token_id=forced_bos_token_id,
|
105 |
+
max_length=512
|
106 |
+
)
|
107 |
+
|
108 |
+
return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
|
109 |
+
|
110 |
+
else: # Siswati → English
|
111 |
+
if ss_en_translator is None:
|
112 |
+
raise Exception("Siswati to English model not loaded")
|
113 |
+
|
114 |
+
# Try with pipeline first
|
115 |
+
try:
|
116 |
+
result = ss_en_translator(text, max_length=512)
|
117 |
+
return result[0]['translation_text']
|
118 |
+
except Exception as pipeline_error:
|
119 |
+
print(f"Pipeline failed, trying direct model approach: {pipeline_error}")
|
120 |
+
|
121 |
+
# Fallback: Use model directly
|
122 |
+
tokenizer = AutoTokenizer.from_pretrained("dsfsi/ss-en-m2m100-combo")
|
123 |
+
model = AutoModelForSeq2SeqLM.from_pretrained("dsfsi/ss-en-m2m100-combo")
|
124 |
+
|
125 |
+
# Set language tokens
|
126 |
+
tokenizer.src_lang = "ss"
|
127 |
+
encoded = tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
|
128 |
+
|
129 |
+
# Force target language token
|
130 |
+
forced_bos_token_id = tokenizer.get_lang_id("en")
|
131 |
+
|
132 |
+
with torch.no_grad():
|
133 |
+
generated_tokens = model.generate(
|
134 |
+
**encoded,
|
135 |
+
forced_bos_token_id=forced_bos_token_id,
|
136 |
+
max_length=512
|
137 |
+
)
|
138 |
+
|
139 |
+
return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
|
140 |
+
|
141 |
+
except Exception as e:
|
142 |
+
raise Exception(f"Translation failed: {str(e)}")
|
143 |
+
|
144 |
def analyze_siswati_features(text):
|
145 |
"""Analyze Siswati-specific linguistic features"""
|
146 |
features = {}
|
|
|
202 |
start_time = time.time()
|
203 |
|
204 |
try:
|
205 |
+
# Perform translation using the fallback method
|
206 |
+
translated_text = translate_with_fallback(text, direction)
|
207 |
+
|
208 |
+
# Analyze source and target text
|
209 |
+
source_metrics = calculate_linguistic_metrics(text)
|
210 |
+
target_metrics = calculate_linguistic_metrics(translated_text)
|
211 |
|
212 |
+
# Analyze Siswati features based on direction
|
213 |
if direction == "English → Siswati":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
214 |
siswati_features = analyze_siswati_features(translated_text)
|
215 |
+
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
216 |
siswati_features = analyze_siswati_features(text)
|
217 |
|
218 |
processing_time = time.time() - start_time
|
|
|
349 |
if len(text) > 1000:
|
350 |
text = text[:1000] + "..."
|
351 |
|
352 |
+
# Perform translation using the fallback method
|
|
|
|
|
|
|
353 |
try:
|
354 |
+
translated = translate_with_fallback(text, direction)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
355 |
except Exception as e:
|
356 |
translated = f"Translation error: {str(e)}"
|
357 |
|
|
|
412 |
# Header Section
|
413 |
gr.HTML("""
|
414 |
<div class="main-header">
|
|
|
415 |
<h1>🔬 Siswati-English Linguistic Translation Tool</h1>
|
416 |
<p style="font-size: 1.1em; color: #666; max-width: 800px; margin: 0 auto;">
|
417 |
Advanced AI-powered translation system with comprehensive linguistic analysis features,
|