vukosi committed on
Commit
e1a0d4b
·
verified ·
1 Parent(s): df987a3

Updated now with language codes

Browse files
Files changed (1) hide show
  1. app.py +97 -42
app.py CHANGED
@@ -28,13 +28,27 @@ def load_translation_models():
28
  print("Loading English to Siswati model...")
29
  en_ss_tokenizer = AutoTokenizer.from_pretrained("dsfsi/en-ss-m2m100-combo")
30
  en_ss_model = AutoModelForSeq2SeqLM.from_pretrained("dsfsi/en-ss-m2m100-combo")
31
- en_ss_pipeline = pipeline("translation", model=en_ss_model, tokenizer=en_ss_tokenizer)
 
 
 
 
 
 
 
32
 
33
  # Siswati to English
34
  print("Loading Siswati to English model...")
35
  ss_en_tokenizer = AutoTokenizer.from_pretrained("dsfsi/ss-en-m2m100-combo")
36
  ss_en_model = AutoModelForSeq2SeqLM.from_pretrained("dsfsi/ss-en-m2m100-combo")
37
- ss_en_pipeline = pipeline("translation", model=ss_en_model, tokenizer=ss_en_tokenizer)
 
 
 
 
 
 
 
38
 
39
  # Cache the models
40
  _model_cache['en_ss_pipeline'] = en_ss_pipeline
@@ -56,6 +70,77 @@ def get_translators():
56
 
57
  return _model_cache['en_ss_pipeline'], _model_cache['ss_en_pipeline']
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  def analyze_siswati_features(text):
60
  """Analyze Siswati-specific linguistic features"""
61
  features = {}
@@ -117,32 +202,17 @@ def translate_text(text, direction):
117
  start_time = time.time()
118
 
119
  try:
120
- # Get translators (will load if not cached)
121
- en_ss_translator, ss_en_translator = get_translators()
 
 
 
 
122
 
123
- # Perform translation
124
  if direction == "English → Siswati":
125
- if en_ss_translator is None:
126
- return "Translation model not loaded. Please try again.", "Model loading failed.", create_empty_metrics_table()
127
-
128
- result = en_ss_translator(text, max_length=512)
129
- translated_text = result[0]['translation_text']
130
-
131
- # Analyze source (English) and target (Siswati)
132
- source_metrics = calculate_linguistic_metrics(text)
133
- target_metrics = calculate_linguistic_metrics(translated_text)
134
  siswati_features = analyze_siswati_features(translated_text)
135
-
136
- else: # Siswati → English
137
- if ss_en_translator is None:
138
- return "Translation model not loaded. Please try again.", "Model loading failed.", create_empty_metrics_table()
139
-
140
- result = ss_en_translator(text, max_length=512)
141
- translated_text = result[0]['translation_text']
142
-
143
- # Analyze source (Siswati) and target (English)
144
- source_metrics = calculate_linguistic_metrics(text)
145
- target_metrics = calculate_linguistic_metrics(translated_text)
146
  siswati_features = analyze_siswati_features(text)
147
 
148
  processing_time = time.time() - start_time
@@ -279,23 +349,9 @@ def secure_file_processing(file_obj, direction):
279
  if len(text) > 1000:
280
  text = text[:1000] + "..."
281
 
282
- # Get translators for batch processing
283
- en_ss_translator, ss_en_translator = get_translators()
284
-
285
- # Perform translation based on direction
286
  try:
287
- if direction == "English → Siswati":
288
- if en_ss_translator is None:
289
- translated = "Model not available"
290
- else:
291
- result = en_ss_translator(text, max_length=512)
292
- translated = result[0]['translation_text']
293
- else: # Siswati → English
294
- if ss_en_translator is None:
295
- translated = "Model not available"
296
- else:
297
- result = ss_en_translator(text, max_length=512)
298
- translated = result[0]['translation_text']
299
  except Exception as e:
300
  translated = f"Translation error: {str(e)}"
301
 
@@ -356,7 +412,6 @@ def create_gradio_interface():
356
  # Header Section
357
  gr.HTML("""
358
  <div class="main-header">
359
- <img src="https://www.dsfsi.co.za/images/logo_transparent_expanded.png" width="400" alt="DSFSI Logo" style="margin-bottom: 1rem;">
360
  <h1>🔬 Siswati-English Linguistic Translation Tool</h1>
361
  <p style="font-size: 1.1em; color: #666; max-width: 800px; margin: 0 auto;">
362
  Advanced AI-powered translation system with comprehensive linguistic analysis features,
 
28
  print("Loading English to Siswati model...")
29
  en_ss_tokenizer = AutoTokenizer.from_pretrained("dsfsi/en-ss-m2m100-combo")
30
  en_ss_model = AutoModelForSeq2SeqLM.from_pretrained("dsfsi/en-ss-m2m100-combo")
31
+ # Fix: Add src_lang and tgt_lang parameters
32
+ en_ss_pipeline = pipeline(
33
+ "translation",
34
+ model=en_ss_model,
35
+ tokenizer=en_ss_tokenizer,
36
+ src_lang="en",
37
+ tgt_lang="ss"
38
+ )
39
 
40
  # Siswati to English
41
  print("Loading Siswati to English model...")
42
  ss_en_tokenizer = AutoTokenizer.from_pretrained("dsfsi/ss-en-m2m100-combo")
43
  ss_en_model = AutoModelForSeq2SeqLM.from_pretrained("dsfsi/ss-en-m2m100-combo")
44
+ # Fix: Add src_lang and tgt_lang parameters
45
+ ss_en_pipeline = pipeline(
46
+ "translation",
47
+ model=ss_en_model,
48
+ tokenizer=ss_en_tokenizer,
49
+ src_lang="ss",
50
+ tgt_lang="en"
51
+ )
52
 
53
  # Cache the models
54
  _model_cache['en_ss_pipeline'] = en_ss_pipeline
 
70
 
71
  return _model_cache['en_ss_pipeline'], _model_cache['ss_en_pipeline']
72
 
73
def translate_with_fallback(text, direction):
    """Translate *text*, falling back to direct generation if the pipeline fails.

    The cached translation pipeline for the requested direction is tried
    first. If the pipeline call raises, the model and tokenizer already held
    by that pipeline are used directly, with the source language set on the
    tokenizer and the target language forced as the first generated token.

    Args:
        text: The text to translate.
        direction: ``"English → Siswati"`` selects English to Siswati; any
            other value is treated as Siswati to English.

    Returns:
        The translated text as a string.

    Raises:
        Exception: If the model for the direction is not loaded or both the
            pipeline and the direct approach fail. The message is prefixed
            with ``"Translation failed: "`` and the original error is chained.
    """
    if direction == "English → Siswati":
        src_lang, tgt_lang, model_label = "en", "ss", "English to Siswati"
    else:  # Siswati → English
        src_lang, tgt_lang, model_label = "ss", "en", "Siswati to English"

    try:
        # Get translators (cached pipelines; either may be None)
        en_ss_translator, ss_en_translator = get_translators()
        translator = en_ss_translator if src_lang == "en" else ss_en_translator

        if translator is None:
            raise RuntimeError(f"{model_label} model not loaded")

        # Try with pipeline first
        try:
            result = translator(text, max_length=512)
            return result[0]['translation_text']
        except Exception as pipeline_error:
            print(f"Pipeline failed, trying direct model approach: {pipeline_error}")

        # Fallback: drive the model directly. Reuse the objects the pipeline
        # already holds instead of calling from_pretrained() again, which
        # would re-read the whole checkpoint on every failed request.
        tokenizer = translator.tokenizer
        model = translator.model

        # Set the source language token (same value the pipeline was built with)
        tokenizer.src_lang = src_lang
        encoded = tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
        # Keep inputs on the same device as the model
        encoded = encoded.to(model.device)

        # Force target language token
        forced_bos_token_id = tokenizer.get_lang_id(tgt_lang)

        with torch.no_grad():
            generated_tokens = model.generate(
                **encoded,
                forced_bos_token_id=forced_bos_token_id,
                max_length=512,
            )

        return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

    except Exception as e:
        # Chain the cause so the original traceback is not lost
        raise Exception(f"Translation failed: {str(e)}") from e
143
+
144
  def analyze_siswati_features(text):
145
  """Analyze Siswati-specific linguistic features"""
146
  features = {}
 
202
  start_time = time.time()
203
 
204
  try:
205
+ # Perform translation using the fallback method
206
+ translated_text = translate_with_fallback(text, direction)
207
+
208
+ # Analyze source and target text
209
+ source_metrics = calculate_linguistic_metrics(text)
210
+ target_metrics = calculate_linguistic_metrics(translated_text)
211
 
212
+ # Analyze Siswati features based on direction
213
  if direction == "English → Siswati":
 
 
 
 
 
 
 
 
 
214
  siswati_features = analyze_siswati_features(translated_text)
215
+ else:
 
 
 
 
 
 
 
 
 
 
216
  siswati_features = analyze_siswati_features(text)
217
 
218
  processing_time = time.time() - start_time
 
349
  if len(text) > 1000:
350
  text = text[:1000] + "..."
351
 
352
+ # Perform translation using the fallback method
 
 
 
353
  try:
354
+ translated = translate_with_fallback(text, direction)
 
 
 
 
 
 
 
 
 
 
 
355
  except Exception as e:
356
  translated = f"Translation error: {str(e)}"
357
 
 
412
  # Header Section
413
  gr.HTML("""
414
  <div class="main-header">
 
415
  <h1>🔬 Siswati-English Linguistic Translation Tool</h1>
416
  <p style="font-size: 1.1em; color: #666; max-width: 800px; margin: 0 auto;">
417
  Advanced AI-powered translation system with comprehensive linguistic analysis features,