marcosremar2 committed
Commit b9d0632
1 Parent(s): 87f6bd6
Files changed (2)
  1. app.py +163 -232
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,250 +1,181 @@
- import os
- import sys
  import gradio as gr
- import whisper
- from huggingface_hub import snapshot_download
  import torch
- import subprocess
- import transformers
-
- # --- Aggressively update/install transformers and huggingface_hub BEFORE importing them ---
- print('Attempting to upgrade pip, transformers, and huggingface_hub...')
  try:
- print('Upgrading pip...')
- subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-U', 'pip'])
- print('Upgrading transformers and huggingface_hub...')
- subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-U', 'transformers', 'huggingface_hub'])
- print('Attempting to install transformers from main branch for latest features...')
- subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'git+https://github.com/huggingface/transformers.git'])
- print('Pip, Transformers, and huggingface_hub update/install process completed.')
- except subprocess.CalledProcessError as e:
- print(f'ERROR: Failed to upgrade/install packages: {e}')
- print('Continuing with potentially older versions. This might lead to model loading issues.')
  except Exception as e:
- print(f'An unexpected error occurred during package upgrades: {e}')
-
- # --- Now, import from transformers ---
  try:
- from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
- print('Successfully imported AutoModelForCausalLM, AutoTokenizer, AutoConfig from transformers.')
- except ImportError as e:
- print(f'CRITICAL ERROR: Failed to import from transformers after attempting upgrades: {e}')
- print('The application might not work correctly. Please check the environment and dependencies.')
- # As a last resort, define dummy classes if import fails, so the rest of the script doesn't crash immediately
- class AutoModelForCausalLM: pass
- class AutoTokenizer: pass
- class AutoConfig: pass
  except Exception as e:
- print(f'An unexpected error occurred during transformers import: {e}')
-
- # --- Configuration ---
- WHISPER_MODEL_SIZE = 'small' # Using smallest model for faster processing in testing
- SPEECH_ENCODER_PATH = 'models/speech_encoder'
- MODEL_NAME = 'LLaMA-Omni2-0.5B'
- MODEL_PATH = f'models/{MODEL_NAME}'
- HF_REPO = f'ICTNLP/{MODEL_NAME}'
-
- # --- Print diagnostics ---
- print('===== Application Startup =====')
- print('Python:', sys.version)
- print('Torch version:', torch.__version__)
- print(f'CUDA available: {torch.cuda.is_available()}')
- if torch.cuda.is_available():
- print(f'CUDA device: {torch.cuda.get_device_name(0)}')
- print(f'CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB')
-
- # --- Main models ---
- whisper_model = None
- llama_model = None
- tokenizer = None
-
- def load_whisper_model():
- '''Load Whisper model for speech recognition'''
- global whisper_model
- print(f'Loading Whisper {WHISPER_MODEL_SIZE} model...')
-
- # Create directory if it doesn't exist
- os.makedirs(SPEECH_ENCODER_PATH, exist_ok=True)
-
- # Load the model (will download if not present)
- whisper_model = whisper.load_model(WHISPER_MODEL_SIZE, download_root=SPEECH_ENCODER_PATH)
- print(f'Whisper {WHISPER_MODEL_SIZE} model loaded successfully!')
- return whisper_model
-
- def load_llama_model():
- '''Load LLaMA-Omni2 model'''
- global llama_model, tokenizer
- print(f'Attempting to load LLaMA-Omni2 model: {HF_REPO}')
-
- # Ensure local model directory exists for downloads
- os.makedirs(MODEL_PATH, exist_ok=True)
-
- # Download model files if they aren't already present locally
- # Check for a common file like config.json to decide if download is needed
- if not os.path.exists(os.path.join(MODEL_PATH, 'config.json')):
- print(f'Local model files not found. Downloading from Hugging Face Hub: {HF_REPO} to {MODEL_PATH}')
- try:
- snapshot_download(
- repo_id=HF_REPO,
- local_dir=MODEL_PATH,
- local_dir_use_symlinks=False,
- resume_download=True,
- )
- print('Model download complete.')
- except Exception as e:
- print(f'ERROR during model download: {e}')
- pass # Allow to proceed to loading attempt, which will then fail more descriptively
-
  try:
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
- torch_dtype = torch.float16 if device == 'cuda' else torch.float32
- print(f'Target device: {device}, dtype: {torch_dtype}')
-
- print(f'Attempt 1: Loading tokenizer and model directly from Hub identifier: {HF_REPO} with trust_remote_code=True')
- try:
- tokenizer = AutoTokenizer.from_pretrained(
- HF_REPO,
- trust_remote_code=True
- )
- print('Tokenizer loaded successfully from Hub identifier.')
-
- config = AutoConfig.from_pretrained(
- HF_REPO,
- trust_remote_code=True
- )
- print('Config loaded successfully from Hub identifier.')
-
- llama_model = AutoModelForCausalLM.from_pretrained(
- HF_REPO,
- config=config, # Pass the loaded config
- torch_dtype=torch_dtype,
- device_map=device, # device_map handles moving parts of the model to CPU if OOM on GPU
- trust_remote_code=True
- )
- print(f'LLaMA-Omni2 model loaded successfully directly from Hub: {HF_REPO}')
- return llama_model
- except Exception as e1:
- print(f'Error in Attempt 1 (direct Hub load for {HF_REPO}): {e1}')
- print('This often means the model requires a specific transformers version or has complex remote code.')
-
- print(f'Attempt 2: Loading tokenizer and model from local path: {MODEL_PATH} with trust_remote_code=True (fallback)')
- try:
- tokenizer = AutoTokenizer.from_pretrained(
- MODEL_PATH, # Fallback to local path
- trust_remote_code=True
- )
- print('Tokenizer loaded successfully from local path.')
-
- config = AutoConfig.from_pretrained(
- MODEL_PATH,
- trust_remote_code=True
- )
- print('Config loaded successfully from local path.')
-
- llama_model = AutoModelForCausalLM.from_pretrained(
- MODEL_PATH, # Fallback to local path
- config=config,
- torch_dtype=torch_dtype,
- device_map=device,
- trust_remote_code=True
- )
- print(f'LLaMA-Omni2 model loaded successfully from local path: {MODEL_PATH}')
- return llama_model
- except Exception as e2:
- print(f'Error in Attempt 2 (local path load for {MODEL_PATH}): {e2}')
-
- print('All attempts to load the LLaMA-Omni2 model failed.')
- raise RuntimeError('Failed to load LLaMA-Omni2 model after multiple attempts.')
-
- except Exception as e_outer:
- print(f'CRITICAL ERROR loading LLaMA-Omni2 model: {e_outer}')
- print('Falling back: Text generation will not be available.')
- llama_model = None # Ensure llama_model is None if loading fails
- tokenizer = None # Ensure tokenizer is None
- return None
-
- def transcribe_audio(audio_path):
- '''Transcribe audio using Whisper'''
- global whisper_model
-
- if whisper_model is None:
- whisper_model = load_whisper_model()
-
- try:
- result = whisper_model.transcribe(audio_path)
- return result['text']
  except Exception as e:
- return f'Error transcribing audio: {e}'
-
- def generate_text(input_text):
- '''Generate text using LLaMA-Omni2'''
- global llama_model, tokenizer
-
- if llama_model is None or tokenizer is None:
- load_llama_model()
-
  try:
- # If model loading failed, just return a placeholder response
- if llama_model is None:
- return f'Model could not be loaded. Input was: {input_text}'
-
- device = next(llama_model.parameters()).device
- inputs = tokenizer(input_text, return_tensors='pt').to(device)
-
- outputs = llama_model.generate(
- inputs.input_ids,
- max_length=100,
- num_return_sequences=1,
- do_sample=True,
- temperature=0.7,
- )
-
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
  except Exception as e:
- return f'Error generating text: {e}'
-
- def speech_to_text_to_speech(audio_path):
- '''Pipeline: Speech -> Text -> Response'''
- # First transcribe the audio
- transcription = transcribe_audio(audio_path)
-
- # Then generate a response
- response = generate_text(transcription)
-
- return transcription, response
-
- # --- Gradio Interface for Hugging Face Spaces ---
- def create_demo():
- with gr.Blocks(title='LLaMA-Omni2 Demo on Hugging Face Spaces') as demo:
- gr.Markdown('# LLaMA-Omni2 Demo')
- gr.Markdown('This demo uses the smallest Whisper model and LLaMA-Omni2-0.5B for testing purposes.')
-
- with gr.Tab('Text Generation'):
- with gr.Row():
- text_input = gr.Textbox(label='Input Text', placeholder='Enter text here...')
- text_output = gr.Textbox(label='Generated Response')
-
- text_button = gr.Button('Generate Response')
- text_button.click(generate_text, inputs=text_input, outputs=text_output)
-
- with gr.Tab('Speech-to-Text'):
- audio_input = gr.Audio(type='filepath', label='Upload or Record Audio')
- transcription_output = gr.Textbox(label='Transcription')
- response_output = gr.Textbox(label='Generated Response')
-
- transcribe_button = gr.Button('Transcribe and Respond')
- transcribe_button.click(speech_to_text_to_speech,
- inputs=audio_input,
- outputs=[transcription_output, response_output])
-
- gr.Markdown('### Note: The first run will download models if needed, which may take some time.')
-
- return demo
-
- # --- Main entry point ---
- if __name__ == '__main__':
- print('Starting LLaMA-Omni2 Interface for Hugging Face Spaces...')
-
- # Create and launch the Gradio interface
- demo = create_demo()
- demo.launch(server_name='0.0.0.0', server_port=7860, share=True) # share=True for Hugging Face Spaces
  import gradio as gr
  import torch
+ from transformers import pipeline
+ import os # os is imported but not used. Consider removing if not needed.
+
+ # --- Model Configuration ---
+ whisper_model_id = "openai/whisper-tiny"
+ # Using gpt2 as a placeholder due to LLaMA-Omni2-0.5B's complex setup needs.
+ # LLaMA-Omni2-0.5B (ICTNLP/LLaMA-Omni2-0.5B) is a speech-language model
+ # requiring specific dependencies (e.g., CosyVoice) and often its own serving infrastructure.
+ # It's not typically loaded via a simple transformers.pipeline for text generation alone.
+ text_generation_model_id = "gpt2"
+
+ # --- Device Configuration ---
+ if torch.cuda.is_available():
+ device_for_pipelines = 0 # Use the first GPU for Hugging Face pipelines
+ torch_device = "cuda:0" # PyTorch device string
+ # For models that support it and where precision is not critical, float16 can save memory/speed up.
+ # However, Whisper models are often more robust with float32 for pipeline usage unless memory is very constrained.
+ # GPT-2 also generally runs fine on float32 and doesn't strictly need float16 for basic use.
+ dtype_for_pipelines = torch.float16 # or torch.float32 depending on model/GPU
+ else:
+ device_for_pipelines = -1 # Use CPU for Hugging Face pipelines
+ torch_device = "cpu"
+ dtype_for_pipelines = torch.float32
+
+ print(f"Using device: {torch_device} for model loading.")
+ print(f"Pipelines will use device_id: {device_for_pipelines} and dtype: {dtype_for_pipelines}")
+
+ # --- Load Speech-to-Text (ASR) Pipeline ---
+ asr_pipeline_instance = None
  try:
+ print(f"Loading ASR model: {whisper_model_id}...")
+ asr_pipeline_instance = pipeline(
+ "automatic-speech-recognition",
+ model=whisper_model_id,
+ torch_dtype=dtype_for_pipelines, # Using specified dtype
+ device=device_for_pipelines
+ )
+ print(f"ASR model ({whisper_model_id}) loaded successfully.")
  except Exception as e:
+ print(f"Error loading ASR model ({whisper_model_id}): {e}")
+ asr_pipeline_instance = None # Explicitly set to None on failure
+
+ # --- Load Text Generation Pipeline ---
+ text_gen_pipeline_instance = None
  try:
+ print(f"Loading text generation model: {text_generation_model_id}...")
+ text_gen_pipeline_instance = pipeline(
+ "text-generation",
+ model=text_generation_model_id,
+ torch_dtype=dtype_for_pipelines, # Using specified dtype
+ device=device_for_pipelines
+ )
+ print(f"Text generation model ({text_generation_model_id}) loaded successfully.")
  except Exception as e:
+ print(f"Error loading text generation model ({text_generation_model_id}): {e}")
+ text_gen_pipeline_instance = None # Explicitly set to None on failure
+
+ # --- Core Functions ---
+ def transcribe_audio_input(audio_filepath):
+ if not asr_pipeline_instance:
+ return "ASR model not available. Please check startup logs.", ""
+ if audio_filepath is None:
+ return "No audio file provided for transcription.", ""
  try:
+ print(f"Transcribing: {audio_filepath}")
+ # Add chunk_length_s for handling longer audio files robustly
+ result = asr_pipeline_instance(audio_filepath, chunk_length_s=30)
+ transcribed_text = result["text"]
+ print(f"Transcription: '{transcribed_text}'")
+ return transcribed_text, transcribed_text # Return for UI and next step
  except Exception as e:
+ print(f"Transcription error: {e}")
+ return f"Error during transcription: {str(e)}", ""
+
+ def generate_text_response(prompt_text):
+ if not text_gen_pipeline_instance:
+ return f"Text generation model ({text_generation_model_id}) not available. Check logs."
+ if not prompt_text or not prompt_text.strip():
+ return "Prompt is empty. Please provide text for generation."
  try:
+ print(f"Generating response for prompt (first 100 chars): '{prompt_text[:100]}...'")
+ # max_new_tokens is generally preferred over max_length for more control
+ generated_outputs = text_gen_pipeline_instance(prompt_text, max_new_tokens=100, num_return_sequences=1)
+ response_text = generated_outputs[0]["generated_text"]
+ print(f"Generated response: '{response_text}'")
+ return response_text
  except Exception as e:
+ print(f"Text generation error: {e}")
+ return f"Error during text generation: {str(e)}"
+
+ def combined_pipeline_process(audio_filepath):
+ if audio_filepath is None:
+ return "No audio input.", "No audio input."
+
+ transcribed_text, _ = transcribe_audio_input(audio_filepath)
+
+ if not asr_pipeline_instance or "Error during transcription" in transcribed_text or not transcribed_text.strip():
+ error_msg_for_generation = "Cannot generate response: Transcription failed or was empty."
+ if not asr_pipeline_instance:
+ error_msg_for_generation = "Cannot generate response: ASR model not loaded."
+ return transcribed_text, error_msg_for_generation
+
+ if not text_gen_pipeline_instance:
+ return transcribed_text, f"Cannot generate response: Text generation model ({text_generation_model_id}) not loaded."
+
+ final_response = generate_text_response(transcribed_text)
+ return transcribed_text, final_response
+
+ # --- Gradio Interface Definition ---
+ with gr.Blocks(theme=gr.themes.Soft(), title="Audio to Text Generation") as app_interface:
+ gr.Markdown(
+ """
+ # Speech-to-Text and Text Generation Demo
+
+ This application uses **OpenAI Whisper Tiny** for speech recognition and **GPT-2** (as a stand-in for more complex models like LLaMA-Omni2) for text generation.
+ You can upload an audio file, have it transcribed, and then use that transcription as a prompt to generate further text.
+
+ **Note on LLaMA-Omni2-0.5B:** The `ICTNLP/LLaMA-Omni2-0.5B` model is a sophisticated speech-language model designed for real-time spoken chat, generating both text and speech. It requires a specific setup environment (including its own speech synthesis like CosyVoice and potentially a dedicated serving mechanism). It's not plug-and-play with a simple `transformers.pipeline` in the same way standard ASR or text-only LLMs are. Therefore, GPT-2 is used here to demonstrate the Gradio app structure.
+ """
+ )
+
+ with gr.Tab("Full Pipeline: Audio -> Transcription -> Generation"):
+ gr.Markdown("### Step 1: Upload Audio -> Step 2: Transcribe -> Step 3: Generate Text")
+ input_audio_pipeline = gr.Audio(type="filepath", label="Upload Your Audio File (.wav, .mp3)")
+ submit_button_full = gr.Button("Run Full Process", variant="primary")
+ output_transcription_pipeline = gr.Textbox(label="Transcribed Text (from Whisper)", lines=5)
+ output_generation_pipeline = gr.Textbox(label=f"Generated Text (from {text_generation_model_id})", lines=7)
+
+ submit_button_full.click(
+ fn=combined_pipeline_process,
+ inputs=[input_audio_pipeline],
+ outputs=[output_transcription_pipeline, output_generation_pipeline]
+ )
+
+ with gr.Tab("Test Speech-to-Text (Whisper Tiny)"):
+ gr.Markdown("### Transcribe audio to text using Whisper Tiny.")
+ input_audio_asr = gr.Audio(type="filepath", label="Upload Audio for ASR")
+ submit_button_asr = gr.Button("Transcribe Audio", variant="secondary")
+ output_transcription_asr = gr.Textbox(label="Transcription Result", lines=10)
+
+ def asr_only_ui(audio_file):
+ if audio_file is None: return "Please upload an audio file."
+ # The transcribe_audio_input returns two values; we only need the first for this UI.
+ transcription, _ = transcribe_audio_input(audio_file)
+ return transcription
+
+ submit_button_asr.click(
+ fn=asr_only_ui,
+ inputs=[input_audio_asr],
+ outputs=[output_transcription_asr]
+ )
+
+ with gr.Tab(f"Test Text Generation ({text_generation_model_id})"):
+ gr.Markdown(f"### Generate text from a prompt using {text_generation_model_id}.")
+ input_text_prompt_gen = gr.Textbox(label="Your Text Prompt", placeholder="Enter text here...", lines=5)
+ submit_button_gen = gr.Button("Generate Text", variant="secondary")
+ output_generation_gen = gr.Textbox(label="Generated Text Result", lines=10)
+
+ submit_button_gen.click(
+ fn=generate_text_response,
+ inputs=[input_text_prompt_gen],
+ outputs=[output_generation_gen]
+ )
+
+ gr.Markdown("--- ")
+ gr.Markdown("### Model Loading Status (at application start):")
+ asr_load_status = "Successfully Loaded" if asr_pipeline_instance else "Failed to Load (check console logs)"
+ text_gen_load_status = "Successfully Loaded" if text_gen_pipeline_instance else "Failed to Load (check console logs)"
+ gr.Markdown(f"* **Whisper Model ({whisper_model_id}):** `{asr_load_status}`")
+ gr.Markdown(f"* **Text Generation Model ({text_generation_model_id}):** `{text_gen_load_status}`")
+
+ # --- Launch the Gradio App ---
+ if __name__ == "__main__":
+ print("Attempting to launch Gradio application...")
+ # share=True is good for Hugging Face Spaces. For local, it's optional.
+ # For persistent public link when running locally (requires internet & can have security implications):
+ # app_interface.launch(share=True)
+ app_interface.launch()
+ print("Gradio application launched. Check your browser or console for the URL.")
requirements.txt CHANGED
@@ -16,5 +16,6 @@ shortuuid
  pydub
  ffmpeg-python
  huggingface_hub # For downloading models from HF Hub
+ soundfile # To handle audio files if not using gr.Audio input directly for some reason

  # fairseq and flash-attn are removed, expected to be handled by LLaMA-Omni2's setup via `pip install -e .` in Dockerfile
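
The new `soundfile` entry only comes into play if audio is decoded outside `gr.Audio`, as its inline comment notes. A minimal sketch of that path, assuming a hypothetical local `sample.wav` and the same Whisper model id as app.py:

```python
# Sketch: feeding soundfile-decoded audio to a Whisper pipeline instead of a file path.
import soundfile as sf
from transformers import pipeline

asr = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")

data, samplerate = sf.read("sample.wav")  # numpy waveform plus sample rate
if data.ndim > 1:                         # downmix stereo to mono before transcription
    data = data.mean(axis=1)

print(asr({"raw": data, "sampling_rate": samplerate})["text"])
```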