techindia2025 committed on
Commit d7ab2f5 · verified · 1 Parent(s): 804953d

Update app.py
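This commit drops the hand-rolled `text </s> src_lang` input formatting and the regex-based output cleanup, and instead routes pre- and post-processing through the official IndicTransToolkit `IndicProcessor`. Below is a minimal, self-contained sketch of the flow the new `translate_text` follows (CPU-only and English→Hindi for illustration; the package, model IDs, and API calls are the ones used in the diff, while variable names and the sample sentence are illustrative only):

```python
# Minimal sketch of the IndicProcessor-based pipeline adopted in this commit.
# Assumes `pip install IndicTransToolkit transformers torch` and access to the HF Hub.
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from IndicTransToolkit.processor import IndicProcessor

MODEL_ID = "ai4bharat/indictrans2-en-indic-1B"  # English -> Indic checkpoint
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID, trust_remote_code=True)
ip = IndicProcessor(inference=True)

sentences = ["When I was young, I used to go to the park every day."]

# Preprocess: IndicProcessor adds the source/target language tags IndicTrans2 expects.
batch = ip.preprocess_batch(sentences, src_lang="eng_Latn", tgt_lang="hin_Deva")
inputs = tokenizer(batch, truncation=True, padding="longest", return_tensors="pt")

# Generate with beam search, mirroring the settings used in app.py.
with torch.no_grad():
    generated = model.generate(
        **inputs,
        max_length=128,
        num_beams=5,
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
    )

decoded = tokenizer.batch_decode(
    generated, skip_special_tokens=True, clean_up_tokenization_spaces=True
)

# Postprocess: strips the language tags and detokenizes into the target script.
translations = ip.postprocess_batch(decoded, lang="hin_Deva")
print(translations[0])
```

In the Space itself the same steps run inside the `@spaces.GPU`-decorated `translate_text`, which additionally selects the en→indic or indic→en model based on the chosen languages and moves it to the GPU for the call.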

Files changed (1)
  1. app.py +46 -82
app.py CHANGED
@@ -2,18 +2,18 @@ import spaces
 import gradio as gr
 import torch
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
-import re

 # Model configurations
 INDIC_EN_MODEL = "ai4bharat/indictrans2-indic-en-1B"
 EN_INDIC_MODEL = "ai4bharat/indictrans2-en-indic-1B"

-# Load tokenizers and models on CPU
-print("Loading IndicTrans2 tokenizers...")
 indic_en_tokenizer = AutoTokenizer.from_pretrained(INDIC_EN_MODEL, trust_remote_code=True)
 en_indic_tokenizer = AutoTokenizer.from_pretrained(EN_INDIC_MODEL, trust_remote_code=True)

-print("Loading IndicTrans2 models on CPU...")
 indic_en_model = AutoModelForSeq2SeqLM.from_pretrained(
     INDIC_EN_MODEL,
     trust_remote_code=True,
@@ -28,7 +28,10 @@ en_indic_model = AutoModelForSeq2SeqLM.from_pretrained(
     device_map="cpu"
 )

-# Language mappings for IndicTrans2
 LANGUAGE_CODES = {
     "Assamese": "asm_Beng",
     "Bengali": "ben_Beng",
@@ -58,22 +61,9 @@ LANGUAGE_CODES = {
     "English": "eng_Latn"
 }

-def format_input_for_indictrans2(text, src_lang, tgt_lang, direction):
-    """Format input text according to IndicTrans2 requirements"""
-    text = text.strip()
-
-    if direction == "en_to_indic":
-        # For English to Indic: format as "text </s> src_lang"
-        formatted_input = f"{text} </s> {src_lang}"
-    else: # indic_to_en
-        # For Indic to English: format as "text </s> src_lang"
-        formatted_input = f"{text} </s> {src_lang}"
-
-    return formatted_input
-
 @spaces.GPU(duration=120)
 def translate_text(input_text, source_lang, target_lang, max_length):
-    """Translate text using IndicTrans2 models"""

     if not input_text.strip():
         return "Please enter text to translate."
@@ -85,101 +75,86 @@ def translate_text(input_text, source_lang, target_lang, max_length):
         src_code = LANGUAGE_CODES[source_lang]
         tgt_code = LANGUAGE_CODES[target_lang]

-        # Determine direction and model
         if source_lang == "English" and target_lang != "English":
-            # English to Indic translation
             model_gpu = en_indic_model.to(device)
             tokenizer = en_indic_tokenizer
             direction = "en_to_indic"
         elif source_lang != "English" and target_lang == "English":
-            # Indic to English translation
             model_gpu = indic_en_model.to(device)
             tokenizer = indic_en_tokenizer
             direction = "indic_to_en"
         else:
             return "Please select English as either source or target language (not both)."

-        # Format input properly for IndicTrans2
-        formatted_input = format_input_for_indictrans2(
-            input_text, src_code, tgt_code, direction
         )

-        # Tokenize with proper settings
         inputs = tokenizer(
-            formatted_input,
-            return_tensors="pt",
-            padding=True,
             truncation=True,
-            max_length=256,
-            return_token_type_ids=False
         ).to(device)
 
-        # Remove any unwanted keys
-        if 'token_type_ids' in inputs:
-            del inputs['token_type_ids']
-
-        # Set up generation parameters based on direction
-        if direction == "en_to_indic":
-            # Get target language token for decoder start
-            tgt_lang_token = tokenizer.convert_tokens_to_ids(tgt_code)
-        else:
-            # For Indic to English, use English token
-            tgt_lang_token = tokenizer.convert_tokens_to_ids("eng_Latn")
-
         # Generate translation
         with torch.no_grad():
             generated_tokens = model_gpu.generate(
-                input_ids=inputs['input_ids'],
-                attention_mask=inputs['attention_mask'],
-                decoder_start_token_id=tgt_lang_token if tgt_lang_token != tokenizer.unk_token_id else None,
                 max_length=max_length,
-                min_length=1,
                 num_beams=5,
                 num_return_sequences=1,
                 early_stopping=True,
-                do_sample=False,
                 pad_token_id=tokenizer.pad_token_id,
-                eos_token_id=tokenizer.eos_token_id,
-                use_cache=True
             )

-        # Decode output
-        translated_text = tokenizer.decode(
-            generated_tokens[0],
             skip_special_tokens=True,
-            clean_up_tokenization_spaces=True
         )

-        # Clean up the output
-        # Remove language tags and unwanted tokens
-        cleaned_output = re.sub(r'<.*?>', '', translated_text)
-        cleaned_output = cleaned_output.strip()

         # Move model back to CPU
         model_gpu.cpu()
         torch.cuda.empty_cache()

-        return cleaned_output if cleaned_output else "Translation failed. Please try again."

     except Exception as e:
-        # Clean up GPU memory in case of error
         if 'model_gpu' in locals():
             model_gpu.cpu()
             torch.cuda.empty_cache()
         return f"Error during translation: {str(e)}"

 # Create Gradio interface
-with gr.Blocks(title="IndicTrans2 Translator", theme=gr.themes.Soft()) as demo:
     gr.Markdown("""
     # 🇮🇳 IndicTrans2 - Official AI4Bharat Translator

     High-quality neural machine translation between English and 22 Indian languages.

     **Supported Languages**: Assamese, Bengali, Bodo, Dogri, Gujarati, Hindi, Kannada, Kashmiri,
     Konkani, Maithili, Malayalam, Manipuri, Marathi, Nepali, Odia, Punjabi, Sanskrit, Santali,
     Sindhi, Tamil, Telugu, Urdu.
-
-    **Note**: Select English as either source OR target language (not both).
     """)

     with gr.Row():
@@ -197,8 +172,6 @@ with gr.Blocks(title="IndicTrans2 Translator", theme=gr.themes.Soft()) as demo:
             label="Source Language"
         )

-        swap_btn = gr.Button("⇄", size="sm")
-
         target_lang = gr.Dropdown(
             choices=list(LANGUAGE_CODES.keys()),
             value="Hindi",
@@ -222,17 +195,17 @@ with gr.Blocks(title="IndicTrans2 Translator", theme=gr.themes.Soft()) as demo:
             interactive=False
         )

-        clear_btn = gr.Button("Clear All", variant="secondary")

-    # Examples that work with the corrected format
-    gr.Markdown("### 💡 Example Translations:")

     examples = [
-        ["Hello, how are you?", "English", "Hindi", 64],
-        ["Good morning, everyone!", "English", "Bengali", 64],
-        ["आपका नाम क्या है?", "Hindi", "English", 64],
-        ["আপনি কেমন আছেন?", "Bengali", "English", 64],
-        ["Technology is amazing.", "English", "Tamil", 96]
     ]

     gr.Examples(
@@ -243,9 +216,6 @@ with gr.Blocks(title="IndicTrans2 Translator", theme=gr.themes.Soft()) as demo:
     )

     # Event handlers
-    def swap_languages(src, tgt):
-        return tgt, src
-
     def clear_all():
         return "", ""

@@ -255,12 +225,6 @@ with gr.Blocks(title="IndicTrans2 Translator", theme=gr.themes.Soft()) as demo:
         outputs=output_text
     )

-    swap_btn.click(
-        swap_languages,
-        inputs=[source_lang, target_lang],
-        outputs=[source_lang, target_lang]
-    )
-
     clear_btn.click(
         clear_all,
         outputs=[input_text, output_text]
 
@@ -2,18 +2,18 @@ import spaces
 import gradio as gr
 import torch
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+from IndicTransToolkit.processor import IndicProcessor

 # Model configurations
 INDIC_EN_MODEL = "ai4bharat/indictrans2-indic-en-1B"
 EN_INDIC_MODEL = "ai4bharat/indictrans2-en-indic-1B"

+print("Loading IndicTrans2 models...")
+# Load tokenizers
 indic_en_tokenizer = AutoTokenizer.from_pretrained(INDIC_EN_MODEL, trust_remote_code=True)
 en_indic_tokenizer = AutoTokenizer.from_pretrained(EN_INDIC_MODEL, trust_remote_code=True)

+# Load models on CPU
 indic_en_model = AutoModelForSeq2SeqLM.from_pretrained(
     INDIC_EN_MODEL,
     trust_remote_code=True,
@@ -28,7 +28,10 @@ en_indic_model = AutoModelForSeq2SeqLM.from_pretrained(
     device_map="cpu"
 )

+# Initialize IndicProcessor (CRUCIAL for proper preprocessing)
+ip = IndicProcessor(inference=True)
+
+# Language mappings (exact codes from official documentation)
 LANGUAGE_CODES = {
     "Assamese": "asm_Beng",
     "Bengali": "ben_Beng",
@@ -58,22 +61,9 @@ LANGUAGE_CODES = {
     "English": "eng_Latn"
 }

 @spaces.GPU(duration=120)
 def translate_text(input_text, source_lang, target_lang, max_length):
+    """Translate using IndicTrans2 with proper preprocessing"""

     if not input_text.strip():
         return "Please enter text to translate."
@@ -85,101 +75,86 @@ def translate_text(input_text, source_lang, target_lang, max_length):
         src_code = LANGUAGE_CODES[source_lang]
         tgt_code = LANGUAGE_CODES[target_lang]
 
+        # Determine direction and select appropriate model/tokenizer
         if source_lang == "English" and target_lang != "English":
+            # English to Indic
             model_gpu = en_indic_model.to(device)
             tokenizer = en_indic_tokenizer
             direction = "en_to_indic"
         elif source_lang != "English" and target_lang == "English":
+            # Indic to English
             model_gpu = indic_en_model.to(device)
             tokenizer = indic_en_tokenizer
             direction = "indic_to_en"
         else:
             return "Please select English as either source or target language (not both)."

+        # CRUCIAL: Use IndicProcessor for proper preprocessing
+        input_sentences = [input_text.strip()]
+
+        # Preprocess using IndicProcessor (this handles the proper formatting)
+        batch = ip.preprocess_batch(
+            input_sentences,
+            src_lang=src_code,
+            tgt_lang=tgt_code,
         )

+        # Tokenize the preprocessed batch
         inputs = tokenizer(
+            batch,
             truncation=True,
+            padding="longest",
+            return_tensors="pt",
+            return_attention_mask=True,
         ).to(device)
 
         # Generate translation
         with torch.no_grad():
             generated_tokens = model_gpu.generate(
+                **inputs,
+                use_cache=True,
+                min_length=0,
                 max_length=max_length,
                 num_beams=5,
                 num_return_sequences=1,
                 early_stopping=True,
                 pad_token_id=tokenizer.pad_token_id,
+                eos_token_id=tokenizer.eos_token_id
             )

+        # Decode generated tokens
+        generated_tokens = tokenizer.batch_decode(
+            generated_tokens,
             skip_special_tokens=True,
+            clean_up_tokenization_spaces=True,
         )

+        # CRUCIAL: Postprocess using IndicProcessor
+        translations = ip.postprocess_batch(generated_tokens, lang=tgt_code)

         # Move model back to CPU
         model_gpu.cpu()
         torch.cuda.empty_cache()

+        return translations[0] if translations else "Translation failed."

     except Exception as e:
         if 'model_gpu' in locals():
             model_gpu.cpu()
             torch.cuda.empty_cache()
         return f"Error during translation: {str(e)}"
 
 # Create Gradio interface
+with gr.Blocks(title="IndicTrans2 Official Translator", theme=gr.themes.Soft()) as demo:
     gr.Markdown("""
     # 🇮🇳 IndicTrans2 - Official AI4Bharat Translator

     High-quality neural machine translation between English and 22 Indian languages.
+    Uses official IndicTransToolkit for proper preprocessing.

     **Supported Languages**: Assamese, Bengali, Bodo, Dogri, Gujarati, Hindi, Kannada, Kashmiri,
     Konkani, Maithili, Malayalam, Manipuri, Marathi, Nepali, Odia, Punjabi, Sanskrit, Santali,
     Sindhi, Tamil, Telugu, Urdu.
     """)

     with gr.Row():
@@ -197,8 +172,6 @@ with gr.Blocks(title="IndicTrans2 Translator", theme=gr.themes.Soft()) as demo:
             label="Source Language"
         )

         target_lang = gr.Dropdown(
             choices=list(LANGUAGE_CODES.keys()),
             value="Hindi",
@@ -222,17 +195,17 @@ with gr.Blocks(title="IndicTrans2 Translator", theme=gr.themes.Soft()) as demo:
             interactive=False
         )

+        clear_btn = gr.Button("Clear", variant="secondary")

+    # Examples from official documentation
+    gr.Markdown("### 💡 Official Examples:")

     examples = [
+        ["When I was young, I used to go to the park every day.", "English", "Hindi", 128],
+        ["We watched a new movie last week, which was very inspiring.", "English", "Bengali", 128],
+        ["जब मैं छोटा था, मैं हर रोज़ पार्क जाता था।", "Hindi", "English", 128],
+        ["हमने पिछले सप्ताह एक नई फिल्म देखी जो कि बहुत प्रेरणादायक थी।", "Hindi", "English", 128],
+        ["Technology is changing our world rapidly.", "English", "Tamil", 128]
     ]

     gr.Examples(
@@ -243,9 +216,6 @@ with gr.Blocks(title="IndicTrans2 Translator", theme=gr.themes.Soft()) as demo:
     )

     # Event handlers
     def clear_all():
         return "", ""

@@ -255,12 +225,6 @@ with gr.Blocks(title="IndicTrans2 Translator", theme=gr.themes.Soft()) as demo:
         outputs=output_text
     )

     clear_btn.click(
         clear_all,
         outputs=[input_text, output_text]