yo
app.py
CHANGED
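The hunks below start at line 15 of app.py, so the module's import block is not part of the change. The code in the diff uses logging, torch, gradio (as gr) and the transformers classes M2M100Tokenizer, M2M100ForConditionalGeneration, Wav2Vec2ForCTC and AutoProcessor, so the top of the file presumably contains imports along these lines (a sketch inferred from the symbols used in the hunks, not the file's verbatim header):

import logging

import gradio as gr
import torch
from transformers import (
    AutoProcessor,                     # processor for the wav2vec2 ASR model
    M2M100ForConditionalGeneration,    # translation model
    M2M100Tokenizer,                   # translation tokenizer
    Wav2Vec2ForCTC,                    # ASR model
)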
@@ -15,51 +15,38 @@ logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
-
-checkpoint_dir = "facebook/final_m2m100"
-# Initialize translation model
+# Update the model loading section
try:
-
+    # Try to load custom model
+    checkpoint_dir = "bishaltwr/final_m2m100"
+    logging.info(f"Attempting to load custom M2M100 from {checkpoint_dir}")
    tokenizer = M2M100Tokenizer.from_pretrained(checkpoint_dir)
-    logging.info("M2M100 tokenizer loaded successfully")
-
    model_m2m = M2M100ForConditionalGeneration.from_pretrained(checkpoint_dir)
-    logging.info("M2M100 model loaded successfully")
-
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    logging.info(f"Using device: {device}")
-
-    model_m2m.to(device)
-    m2m_available = True
-    logging.info("M2M100 model ready for use")
+    logging.info("Custom M2M100 model loaded successfully")
except Exception as e:
-    logging.error(f"Error loading M2M100 model: {e}")
-
-
+    logging.error(f"Error loading custom M2M100 model: {e}")
+    # Fall back to official model
+    checkpoint_dir = "facebook/m2m100_418M"
+    logging.info(f"Attempting to load official M2M100 from {checkpoint_dir}")
+    tokenizer = M2M100Tokenizer.from_pretrained(checkpoint_dir)
+    model_m2m = M2M100ForConditionalGeneration.from_pretrained(checkpoint_dir)
+    logging.info("Official M2M100 model loaded successfully")
+
+# Set device after model loading
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+logging.info(f"Using device: {device}")
+model_m2m.to(device)

# Initialize ASR model
model_id = "bishaltwr/wav2vec2-large-mms-1b-nepali"
-
-
-    model_asr = Wav2Vec2ForCTC.from_pretrained(model_id, ignore_mismatched_sizes=True)
-    asr_available = True
-except Exception as e:
-    logging.error(f"Error loading ASR model: {e}")
-    asr_available = False
+processor = AutoProcessor.from_pretrained(model_id)
+model_asr = Wav2Vec2ForCTC.from_pretrained(model_id, ignore_mismatched_sizes=True)

# Initialize X-Transformer model
-
-    from inference import translate as xtranslate
-    xtransformer_available = True
-except Exception as e:
-    logging.error(f"Error loading XTransformer model: {e}")
-    xtransformer_available = False
+from inference import translate as xtranslate

def m2m_translate(text, source_lang, target_lang):
    """Translation using M2M100 model"""
-    if not m2m_available:
-        return "M2M100 model not available"
-
    tokenizer.src_lang = source_lang
    inputs = tokenizer(text, return_tensors="pt").to(device)
    translated_tokens = model_m2m.generate(
@@ -71,9 +58,6 @@ def m2m_translate(text, source_lang, target_lang):

def transcribe_audio(audio_path, language="npi"):
    """Transcribe audio using ASR model"""
-    if not asr_available:
-        return "ASR model not available"
-
    import librosa
    audio, sr = librosa.load(audio_path, sr=16000)
    processor.tokenizer.set_target_lang(language)
@@ -123,69 +107,55 @@ def translate_text(text, model_choice, source_lang=None, target_lang=None):
        target_lang = "ne" if source_lang == "en" else "en"

    # Choose the translation model
-    if model_choice == "XTransformer"
+    if model_choice == "XTransformer":
        return xtranslate(text)
-    elif model_choice == "M2M100"
+    elif model_choice == "M2M100":
        return m2m_translate(text, source_lang=source_lang, target_lang=target_lang)
    else:
        return "Selected model is not available"

# Set up the Gradio interface
with gr.Blocks(title="Nepali-English Translator") as demo:
-    gr.Markdown("# Nepali-English
+    gr.Markdown("# Nepali-English Translator")
    gr.Markdown("Translate between Nepali and English, transcribe audio, and convert text to speech.")
-
-
-    with gr.
-    … (≈29 more removed layout lines not captured)
-        with gr.Column():
-            translation_output = gr.Textbox(label="Translation Output", lines=5)
-            tts_button = gr.Button("Convert to Speech")
-            audio_output = gr.Audio(label="Audio Output")
+    gr.Markdown("Aakash Budhathoki, Apekshya Subedi, Bishal Tiwari, Kebin Malla. - Kantipur Engineering College.")
+
+    with gr.Column():
+        gr.Markdown("### Speech to Text")
+        audio_input = gr.Audio(label="Upload or Record Audio", type="filepath")
+        asr_language = gr.Radio(
+            choices=["eng", "npi"],
+            value="npi",
+            label="Speech Language"
+        )
+        transcribe_button = gr.Button("Transcribe")
+        transcription_output = gr.Textbox(label="Transcription Output", lines=3)
+
+        gr.Markdown("### Text Translation")
+        model_choice = gr.Dropdown(
+            choices=["XTransformer", "M2M100"],
+            value="M2M100",
+            label="Translation Model"
+        )
+        source_lang = gr.Dropdown(
+            choices=["Auto-detect", "en", "ne"],
+            value="Auto-detect",
+            label="Source Language"
+        )
+        target_lang = gr.Dropdown(
+            choices=["Auto-select", "en", "ne"],
+            value="Auto-select",
+            label="Target Language"
+        )
+        translate_button = gr.Button("Translate")
+        translation_output = gr.Textbox(label="Translation Output", lines=5)

-
-
-
-            audio_input = gr.Audio(label="Upload or Record Audio", type="filepath")
-            asr_language = gr.Radio(
-                choices=["eng", "npi"],
-                value="npi",
-                label="Speech Language"
-            )
-            transcribe_button = gr.Button("Transcribe")
-            transcription_output = gr.Textbox(label="Transcription Output", lines=3)
+        gr.Markdown("### Text to Speech")
+        tts_button = gr.Button("Convert to Speech")
+        audio_output = gr.Audio(label="Audio Output")

    # Define event handlers
    def process_translation(text, model, src_lang, tgt_lang):
-        logging.info(f"Processing translation: text={text}, model={model}, src_lang={src_lang}, tgt_lang={tgt_lang}")
        if src_lang == "Auto-detect":
            src_lang = None
        if tgt_lang == "Auto-select":
@@ -193,19 +163,23 @@ with gr.Blocks(title="Nepali-English Translator") as demo:
        return translate_text(text, model, src_lang, tgt_lang)

    def process_tts(text):
-        logging.info(f"Processing TTS: text={text}")
        return text_to_speech(text)

    def process_transcription(audio_path, language):
-        logging.info(f"Processing transcription: audio_path={audio_path}, language={language}")
        if not audio_path:
            return "Please upload or record audio"
        return transcribe_audio(audio_path, language)

    # Connect the components
+    transcribe_button.click(
+        process_transcription,
+        inputs=[audio_input, asr_language],
+        outputs=transcription_output
+    )
+
    translate_button.click(
        process_translation,
-        inputs=[
+        inputs=[transcription_output, model_choice, source_lang, target_lang],
        outputs=translation_output
    )

@@ -214,40 +188,7 @@ with gr.Blocks(title="Nepali-English Translator") as demo:
        inputs=translation_output,
        outputs=audio_output
    )
-
-    transcribe_button.click(
-        process_transcription,
-        inputs=[audio_input, asr_language],
-        outputs=transcription_output
-    )
-
-    # Explicitly define API endpoints
-    process_translation_api = gr.Interface(
-        fn=process_translation,
-        inputs=[gr.Textbox(label="text"), gr.Radio(label="model"), gr.Dropdown(label="src_lang"), gr.Dropdown(label="tgt_lang")],
-        outputs=gr.Textbox(label="translation_output"),
-        api_name="process_translation"
-    )
-
-    process_tts_api = gr.Interface(
-        fn=process_tts,
-        inputs=gr.Textbox(label="text"),
-        outputs=gr.Audio(label="audio_output"),
-        api_name="process_tts"
-    )
-
-    process_transcription_api = gr.Interface(
-        fn=process_transcription,
-        inputs=[gr.Audio(label="audio_path"), gr.Radio(label="language")],
-        outputs=gr.Textbox(label="transcription_output"),
-        api_name="process_transcription"
-    )
-
-    # Add API endpoints to the app
-    process_translation_api.render()
-    process_tts_api.render()
-    process_transcription_api.render()

# Launch the app
if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
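The first hunk ends mid-statement, right after translated_tokens = model_m2m.generate(, so the generation arguments are not visible in this diff. For reference, a minimal sketch of the usual M2M100 generate-and-decode pattern is shown below; the variable names (tokenizer, model_m2m, device) follow the diff, but the forced_bos_token_id argument and the decoding step are the standard transformers idiom, not necessarily the exact code in app.py.

def m2m_translate(text, source_lang, target_lang):
    """Translation using M2M100 model (sketch of the part cut off by the hunk)."""
    tokenizer.src_lang = source_lang
    inputs = tokenizer(text, return_tensors="pt").to(device)
    translated_tokens = model_m2m.generate(
        **inputs,
        # Force the decoder to start in the target language, e.g. "en" or "ne".
        forced_bos_token_id=tokenizer.get_lang_id(target_lang),
    )
    # Turn the generated token ids back into a plain string.
    return tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]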