marcosremar2 committed
Commit: d478b16
Parent(s): b9d0632
Files changed (3):
  1. __pycache__/app.cpython-313.pyc (+0 -0)
  2. app.py (+195 -51)
  3. requirements.txt (+2 -0)
__pycache__/app.cpython-313.pyc ADDED
Binary file (14 kB).
 
app.py CHANGED
@@ -1,26 +1,48 @@
import gradio as gr
import torch
- from transformers import pipeline
- import os # os is imported but not used. Consider removing if not needed.
+ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
+ import os
+ import warnings
+ import importlib
+ import sys
+ import subprocess
+
+ # Check if we can import LLaMA-Omni2's modules
+ try_native_modules = True
+ native_llama_omni_available = False
+ native_modules_error = None
+
+ if try_native_modules:
+     try:
+         # Try importing LLaMA-Omni2 specific modules using subprocess to avoid crashing if imports fail
+         print("Checking for LLaMA-Omni2 native modules...")
+         module_check_result = subprocess.run(
+             [sys.executable, "-c", "import llama_omni2; print('LLaMA-Omni2 modules found!')"],
+             capture_output=True,
+             text=True
+         )
+         if "LLaMA-Omni2 modules found!" in module_check_result.stdout:
+             print("LLaMA-Omni2 native modules are available!")
+             native_llama_omni_available = True
+         else:
+             print(f"LLaMA-Omni2 native modules not found: {module_check_result.stderr}")
+             native_modules_error = module_check_result.stderr
+     except Exception as e:
+         print(f"Error checking for LLaMA-Omni2 native modules: {e}")
+         native_modules_error = str(e)

# --- Model Configuration ---
whisper_model_id = "openai/whisper-tiny"
- # Using gpt2 as a placeholder due to LLaMA-Omni2-0.5B's complex setup needs.
- # LLaMA-Omni2-0.5B (ICTNLP/LLaMA-Omni2-0.5B) is a speech-language model
- # requiring specific dependencies (e.g., CosyVoice) and often its own serving infrastructure.
- # It's not typically loaded via a simple transformers.pipeline for text generation alone.
- text_generation_model_id = "gpt2"
+ llama_omni_model_id = "ICTNLP/LLaMA-Omni2-0.5B" # Primary model we'll try to load
+ fallback_model_id = "gpt2" # Fallback if LLaMA-Omni2 fails to load

# --- Device Configuration ---
if torch.cuda.is_available():
    device_for_pipelines = 0 # Use the first GPU for Hugging Face pipelines
    torch_device = "cuda:0" # PyTorch device string
-     # For models that support it and where precision is not critical, float16 can save memory/speed up.
-     # However, Whisper models are often more robust with float32 for pipeline usage unless memory is very constrained.
-     # GPT-2 also generally runs fine on float32 and doesn't strictly need float16 for basic use.
-     dtype_for_pipelines = torch.float16 # or torch.float32 depending on model/GPU
+     dtype_for_pipelines = torch.float16
else:
-     device_for_pipelines = -1 # Use CPU for Hugging Face pipelines
+     device_for_pipelines = -1 # Use CPU for Hugging Face pipelines
    torch_device = "cpu"
    dtype_for_pipelines = torch.float32

@@ -34,28 +56,78 @@ try:
    asr_pipeline_instance = pipeline(
        "automatic-speech-recognition",
        model=whisper_model_id,
-         torch_dtype=dtype_for_pipelines, # Using specified dtype
+         torch_dtype=dtype_for_pipelines,
        device=device_for_pipelines
    )
    print(f"ASR model ({whisper_model_id}) loaded successfully.")
except Exception as e:
    print(f"Error loading ASR model ({whisper_model_id}): {e}")
-     asr_pipeline_instance = None # Explicitly set to None on failure
+     asr_pipeline_instance = None

- # --- Load Text Generation Pipeline ---
+ # --- Load Text Generation Model ---
text_gen_pipeline_instance = None
- try:
-     print(f"Loading text generation model: {text_generation_model_id}...")
-     text_gen_pipeline_instance = pipeline(
-         "text-generation",
-         model=text_generation_model_id,
-         torch_dtype=dtype_for_pipelines, # Using specified dtype
-         device=device_for_pipelines
-     )
-     print(f"Text generation model ({text_generation_model_id}) loaded successfully.")
- except Exception as e:
-     print(f"Error loading text generation model ({text_generation_model_id}): {e}")
-     text_gen_pipeline_instance = None # Explicitly set to None on failure
+ text_generation_model_id = None # Will be set to the model that successfully loads
+ llama_omni_native_module = None # Will hold the native LLaMA-Omni2 module if loaded
+
+ # Try native LLaMA-Omni2 module first if available
+ if native_llama_omni_available:
+     try:
+         print("Attempting to load LLaMA-Omni2 using native modules...")
+         # Import the required modules
+         import llama_omni2
+         from llama_omni2.model import Model as LLamaOmniModel
+
+         # Load the model
+         llama_omni_native_module = LLamaOmniModel.from_pretrained(llama_omni_model_id)
+         text_generation_model_id = llama_omni_model_id
+         print(f"LLaMA-Omni2 native module loaded successfully: {type(llama_omni_native_module)}")
+     except Exception as e:
+         print(f"Error loading native LLaMA-Omni2 module: {e}")
+         llama_omni_native_module = None
+
+ # If native module failed, try loading using transformers
+ if llama_omni_native_module is None and text_generation_model_id is None:
+     try:
+         print(f"Attempting to load LLaMA-Omni2 using transformers: {llama_omni_model_id}...")
+         # LLaMA models often require specific loading configurations
+         tokenizer = AutoTokenizer.from_pretrained(llama_omni_model_id, trust_remote_code=True)
+         model = AutoModelForCausalLM.from_pretrained(
+             llama_omni_model_id,
+             torch_dtype=dtype_for_pipelines,
+             trust_remote_code=True,
+             device_map="auto" if torch.cuda.is_available() else None
+         )
+
+         text_gen_pipeline_instance = pipeline(
+             "text-generation",
+             model=model,
+             tokenizer=tokenizer,
+             torch_dtype=dtype_for_pipelines,
+             device=device_for_pipelines if not torch.cuda.is_available() else None
+         )
+         text_generation_model_id = llama_omni_model_id
+         print(f"LLaMA-Omni2 model ({llama_omni_model_id}) loaded successfully via transformers.")
+
+     except Exception as e:
+         warnings.warn(f"Error loading LLaMA-Omni2 model: {e}\nFalling back to {fallback_model_id}")
+         print(f"Error loading LLaMA-Omni2 model via transformers: {e}")
+         print(f"Falling back to {fallback_model_id}")
+
+ # Fall back to GPT-2 if LLaMA-Omni2 fails to load both ways
+ if text_generation_model_id is None:
+     try:
+         print(f"Loading fallback text generation model: {fallback_model_id}...")
+         text_gen_pipeline_instance = pipeline(
+             "text-generation",
+             model=fallback_model_id,
+             torch_dtype=dtype_for_pipelines,
+             device=device_for_pipelines
+         )
+         text_generation_model_id = fallback_model_id
+         print(f"Fallback model ({fallback_model_id}) loaded successfully.")
+     except Exception as e:
+         print(f"Error loading fallback model ({fallback_model_id}): {e}")
+         text_gen_pipeline_instance = None

# --- Core Functions ---
def transcribe_audio_input(audio_filepath):
@@ -65,24 +137,56 @@ def transcribe_audio_input(audio_filepath):
        return "No audio file provided for transcription.", ""
    try:
        print(f"Transcribing: {audio_filepath}")
-         # Add chunk_length_s for handling longer audio files robustly
        result = asr_pipeline_instance(audio_filepath, chunk_length_s=30)
        transcribed_text = result["text"]
        print(f"Transcription: '{transcribed_text}'")
-         return transcribed_text, transcribed_text # Return for UI and next step
+         return transcribed_text, transcribed_text
    except Exception as e:
        print(f"Transcription error: {e}")
        return f"Error during transcription: {str(e)}", ""

def generate_text_response(prompt_text):
+     # If we have a native LLaMA-Omni2 module, use it
+     if llama_omni_native_module is not None:
+         if not prompt_text or not prompt_text.strip():
+             return "Prompt is empty. Please provide text for generation."
+         try:
+             print(f"Generating response with native LLaMA-Omni2 for prompt: '{prompt_text[:100]}...'")
+             # Using the native module's interface for text generation
+             response = llama_omni_native_module.generate(prompt_text, max_length=150)
+             print(f"Generated response: '{response}'")
+             return response
+         except Exception as e:
+             print(f"Error using native LLaMA-Omni2 generation: {e}")
+             return f"Error during native LLaMA-Omni2 text generation: {str(e)}"
+
+     # Otherwise use the transformers pipeline
    if not text_gen_pipeline_instance:
-         return f"Text generation model ({text_generation_model_id}) not available. Check logs."
+         return f"Text generation model not available. Check logs."
    if not prompt_text or not prompt_text.strip():
        return "Prompt is empty. Please provide text for generation."
    try:
        print(f"Generating response for prompt (first 100 chars): '{prompt_text[:100]}...'")
-         # max_new_tokens is generally preferred over max_length for more control
-         generated_outputs = text_gen_pipeline_instance(prompt_text, max_new_tokens=100, num_return_sequences=1)
+
+         # Different generation parameters based on model
+         if text_generation_model_id == llama_omni_model_id:
+             # Parameters optimized for LLaMA-Omni2
+             generated_outputs = text_gen_pipeline_instance(
+                 prompt_text,
+                 max_new_tokens=150,
+                 do_sample=True,
+                 temperature=0.7,
+                 top_p=0.9,
+                 num_return_sequences=1
+             )
+         else:
+             # Parameters for fallback model
+             generated_outputs = text_gen_pipeline_instance(
+                 prompt_text,
+                 max_new_tokens=100,
+                 num_return_sequences=1
+             )
+
        response_text = generated_outputs[0]["generated_text"]
        print(f"Generated response: '{response_text}'")
        return response_text
@@ -102,22 +206,38 @@ def combined_pipeline_process(audio_filepath):
        error_msg_for_generation = "Cannot generate response: ASR model not loaded."
        return transcribed_text, error_msg_for_generation

-     if not text_gen_pipeline_instance:
-         return transcribed_text, f"Cannot generate response: Text generation model ({text_generation_model_id}) not loaded."
+     if not text_gen_pipeline_instance and llama_omni_native_module is None:
+         return transcribed_text, f"Cannot generate response: No text generation model available."

    final_response = generate_text_response(transcribed_text)
    return transcribed_text, final_response

+ # Determine model status for UI
+ if llama_omni_native_module is not None:
+     llama_model_status = "Native LLaMA-Omni2 module loaded successfully"
+     using_model = "LLaMA-Omni2-0.5B (native modules)"
+ elif text_generation_model_id == llama_omni_model_id:
+     llama_model_status = "LLaMA-Omni2 loaded via transformers"
+     using_model = "LLaMA-Omni2-0.5B (via transformers)"
+ elif text_generation_model_id == fallback_model_id:
+     llama_model_status = "Failed to load - Using GPT-2 as fallback"
+     using_model = "GPT-2 (fallback model)"
+ else:
+     llama_model_status = "Failed to load any text generation model"
+     using_model = "No model available"
+
# --- Gradio Interface Definition ---
- with gr.Blocks(theme=gr.themes.Soft(), title="Audio to Text Generation") as app_interface:
+ with gr.Blocks(theme=gr.themes.Soft(), title="Whisper + LLaMA-Omni2 Demo") as app_interface:
    gr.Markdown(
-         """
+         f"""
        # Speech-to-Text and Text Generation Demo

-         This application uses **OpenAI Whisper Tiny** for speech recognition and **GPT-2** (as a stand-in for more complex models like LLaMA-Omni2) for text generation.
-         You can upload an audio file, have it transcribed, and then use that transcription as a prompt to generate further text.
+         This application uses **OpenAI Whisper Tiny** for speech recognition and attempts to use **LLaMA-Omni2-0.5B** for text generation.
+         If LLaMA-Omni2 cannot be loaded, it falls back to GPT-2.
+
+         **Currently using:** {using_model}

-         **Note on LLaMA-Omni2-0.5B:** The `ICTNLP/LLaMA-Omni2-0.5B` model is a sophisticated speech-language model designed for real-time spoken chat, generating both text and speech. It requires a specific setup environment (including its own speech synthesis like CosyVoice and potentially a dedicated serving mechanism). It's not plug-and-play with a simple `transformers.pipeline` in the same way standard ASR or text-only LLMs are. Therefore, GPT-2 is used here to demonstrate the Gradio app structure.
+         Upload an audio file to transcribe it. The transcribed text will then be used as a prompt for the text generation model.
        """
    )

@@ -126,7 +246,8 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Audio to Text Generation") as app_
        input_audio_pipeline = gr.Audio(type="filepath", label="Upload Your Audio File (.wav, .mp3)")
        submit_button_full = gr.Button("Run Full Process", variant="primary")
        output_transcription_pipeline = gr.Textbox(label="Transcribed Text (from Whisper)", lines=5)
-         output_generation_pipeline = gr.Textbox(label=f"Generated Text (from {text_generation_model_id})", lines=7)
+         model_label = f"Generated Text (from {using_model})"
+         output_generation_pipeline = gr.Textbox(label=model_label, lines=7)

        submit_button_full.click(
            fn=combined_pipeline_process,
@@ -142,7 +263,6 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Audio to Text Generation") as app_

        def asr_only_ui(audio_file):
            if audio_file is None: return "Please upload an audio file."
-             # The transcribe_audio_input returns two values; we only need the first for this UI.
            transcription, _ = transcribe_audio_input(audio_file)
            return transcription

@@ -152,8 +272,9 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Audio to Text Generation") as app_
            outputs=[output_transcription_asr]
        )

-     with gr.Tab(f"Test Text Generation ({text_generation_model_id})"):
-         gr.Markdown(f"### Generate text from a prompt using {text_generation_model_id}.")
+     with gr.Tab(f"Test Text Generation"):
+         model_name_gen = using_model
+         gr.Markdown(f"### Generate text from a prompt using {model_name_gen}.")
        input_text_prompt_gen = gr.Textbox(label="Your Text Prompt", placeholder="Enter text here...", lines=5)
        submit_button_gen = gr.Button("Generate Text", variant="secondary")
        output_generation_gen = gr.Textbox(label="Generated Text Result", lines=10)
@@ -167,15 +288,38 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Audio to Text Generation") as app_
    gr.Markdown("--- ")
    gr.Markdown("### Model Loading Status (at application start):")
    asr_load_status = "Successfully Loaded" if asr_pipeline_instance else "Failed to Load (check console logs)"
-     text_gen_load_status = "Successfully Loaded" if text_gen_pipeline_instance else "Failed to Load (check console logs)"
+
    gr.Markdown(f"* **Whisper Model ({whisper_model_id}):** `{asr_load_status}`")
-     gr.Markdown(f"* **Text Generation Model ({text_generation_model_id}):** `{text_gen_load_status}`")
+     gr.Markdown(f"* **LLaMA-Omni2 Model ({llama_omni_model_id}):** `{llama_model_status}`")
+
+     if native_llama_omni_available:
+         gr.Markdown("* **LLaMA-Omni2 Native Modules:** `Available`")
+     else:
+         native_error = f": {native_modules_error}" if native_modules_error else ""
+         gr.Markdown(f"* **LLaMA-Omni2 Native Modules:** `Not Available{native_error}`")
+
+     if using_model.startswith("GPT-2"):
+         gr.Markdown(
+             """
+             **Note about LLaMA-Omni2-0.5B:** This model has complex dependencies and requires a specific setup environment.
+             The system attempted to load it but fell back to GPT-2. For full functionality with LLaMA-Omni2, you should:
+
+             1. Clone the [LLaMA-Omni2 repository](https://github.com/ictnlp/LLaMA-Omni2)
+             2. Install the required dependencies including CosyVoice 2
+             3. Download the Whisper-large-v3 model and flow-matching model and vocoder of CosyVoice 2
+             4. Set up the controller, model worker, and web server as described in the repository
+
+             Note that LLaMA-Omni2 is designed for generating both text and speech responses simultaneously.
+             For the full experience with speech synthesis, you need the complete setup.
+             """
+         )

# --- Launch the Gradio App ---
if __name__ == "__main__":
-     print("Attempting to launch Gradio application...")
-     # share=True is good for Hugging Face Spaces. For local, it's optional.
-     # For persistent public link when running locally (requires internet & can have security implications):
-     # app_interface.launch(share=True)
-     app_interface.launch()
-     print("Gradio application launched. Check your browser or console for the URL.")
+     print("Launching Gradio demo...")
+     try:
+         app_interface.launch(share=True)
+     except Exception as e:
+         print(f"Error launching with share=True: {e}")
+         print("Trying to launch without sharing...")
+         app_interface.launch()
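
The loading logic added above boils down to a three-step fallback: try the native llama_omni2 package, then the ICTNLP/LLaMA-Omni2-0.5B checkpoint through transformers, and finally gpt2. A condensed sketch of that ordering, assuming only that transformers and torch are installed; the llama_omni2 package and its Model.from_pretrained interface are the commit's own assumption and may not be importable in a given environment:

import torch
from transformers import pipeline

def load_text_generator(primary_id="ICTNLP/LLaMA-Omni2-0.5B", fallback_id="gpt2"):
    # 1) Native package, if it happens to be installed (assumed API, per the commit).
    try:
        from llama_omni2.model import Model as LLamaOmniModel  # optional dependency
        return "native", LLamaOmniModel.from_pretrained(primary_id)
    except Exception as exc:
        print(f"Native llama_omni2 load failed: {exc}")
    # 2) Hub checkpoint through transformers, allowing custom modeling code.
    try:
        gen = pipeline(
            "text-generation",
            model=primary_id,
            trust_remote_code=True,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        )
        return "transformers", gen
    except Exception as exc:
        print(f"transformers load failed: {exc}")
    # 3) Small, known-good fallback model.
    return "fallback", pipeline("text-generation", model=fallback_id)

The UI strings (llama_model_status, using_model) are derived from which of these three branches succeeded, which is why the status block at the bottom of the app can report the active backend.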
requirements.txt CHANGED
@@ -17,5 +17,7 @@ pydub
ffmpeg-python
huggingface_hub # For downloading models from HF Hub
soundfile # To handle audio files if not using gr.Audio input directly for some reason
+ safetensors
+ ai2-olmo # In case LLaMA-Omni2 uses olmo under the hood for the LLM part

# fairseq and flash-attn are removed, expected to be handled by LLaMA-Omni2's setup via `pip install -e .` in Dockerfile
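
The two new entries are speculative (ai2-olmo is added only "in case" it is needed), so it can be useful to check at startup which of them actually resolves before relying on it. A minimal sketch using only the standard library; the distribution-to-import-name mapping is an assumption for illustration (ai2-olmo is assumed here to import as olmo, which may not match the installed package):

import importlib.util

# Distribution name -> import name (assumed mapping, for illustration only).
OPTIONAL_DEPS = {
    "safetensors": "safetensors",
    "ai2-olmo": "olmo",
}

for dist, module_name in OPTIONAL_DEPS.items():
    found = importlib.util.find_spec(module_name) is not None
    print(f"{dist}: {'available' if found else 'missing'}")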