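"""Gradio demo: transcribe speech with OpenAI Whisper Tiny, then feed the transcription
to LLaMA-Omni2-0.5B for text generation, falling back to GPT-2 if LLaMA-Omni2 cannot be loaded."""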
import gradio as gr
import torch
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import os
import warnings
import importlib
import sys
import subprocess
# Check if we can import LLaMA-Omni2's modules
try_native_modules = True
native_llama_omni_available = False
native_modules_error = None

if try_native_modules:
    try:
        # Try importing LLaMA-Omni2-specific modules in a subprocess to avoid crashing if the import fails
        print("Checking for LLaMA-Omni2 native modules...")
        module_check_result = subprocess.run(
            [sys.executable, "-c", "import llama_omni2; print('LLaMA-Omni2 modules found!')"],
            capture_output=True,
            text=True
        )
        if "LLaMA-Omni2 modules found!" in module_check_result.stdout:
            print("LLaMA-Omni2 native modules are available!")
            native_llama_omni_available = True
        else:
            print(f"LLaMA-Omni2 native modules not found: {module_check_result.stderr}")
            native_modules_error = module_check_result.stderr
    except Exception as e:
        print(f"Error checking for LLaMA-Omni2 native modules: {e}")
        native_modules_error = str(e)
# --- Model Configuration ---
whisper_model_id = "openai/whisper-tiny"
llama_omni_model_id = "ICTNLP/LLaMA-Omni2-0.5B"  # Primary model we'll try to load
fallback_model_id = "gpt2"  # Fallback if LLaMA-Omni2 fails to load

# --- Device Configuration ---
if torch.cuda.is_available():
    device_for_pipelines = 0       # Use the first GPU for Hugging Face pipelines
    torch_device = "cuda:0"        # PyTorch device string
    dtype_for_pipelines = torch.float16
else:
    device_for_pipelines = -1      # Use CPU for Hugging Face pipelines
    torch_device = "cpu"
    dtype_for_pipelines = torch.float32

print(f"Using device: {torch_device} for model loading.")
print(f"Pipelines will use device_id: {device_for_pipelines} and dtype: {dtype_for_pipelines}")
# --- Load Speech-to-Text (ASR) Pipeline ---
asr_pipeline_instance = None
try:
    print(f"Loading ASR model: {whisper_model_id}...")
    asr_pipeline_instance = pipeline(
        "automatic-speech-recognition",
        model=whisper_model_id,
        torch_dtype=dtype_for_pipelines,
        device=device_for_pipelines
    )
    print(f"ASR model ({whisper_model_id}) loaded successfully.")
except Exception as e:
    print(f"Error loading ASR model ({whisper_model_id}): {e}")
    asr_pipeline_instance = None
# --- Load Text Generation Model ---
text_gen_pipeline_instance = None
text_generation_model_id = None   # Will be set to the model that successfully loads
llama_omni_native_module = None   # Will hold the native LLaMA-Omni2 module if loaded

# Try the native LLaMA-Omni2 module first if it is available
if native_llama_omni_available:
    try:
        print("Attempting to load LLaMA-Omni2 using native modules...")
        # Import the required modules
        import llama_omni2
        from llama_omni2.model import Model as LLamaOmniModel

        # Load the model
        llama_omni_native_module = LLamaOmniModel.from_pretrained(llama_omni_model_id)
        text_generation_model_id = llama_omni_model_id
        print(f"LLaMA-Omni2 native module loaded successfully: {type(llama_omni_native_module)}")
    except Exception as e:
        print(f"Error loading native LLaMA-Omni2 module: {e}")
        llama_omni_native_module = None
# If the native module failed, try loading via transformers
if llama_omni_native_module is None and text_generation_model_id is None:
    try:
        print(f"Attempting to load LLaMA-Omni2 using transformers: {llama_omni_model_id}...")
        # LLaMA models often require specific loading configurations
        tokenizer = AutoTokenizer.from_pretrained(llama_omni_model_id, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            llama_omni_model_id,
            torch_dtype=dtype_for_pipelines,
            trust_remote_code=True,
            device_map="auto" if torch.cuda.is_available() else None
        )
        text_gen_pipeline_instance = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            torch_dtype=dtype_for_pipelines,
            # When device_map="auto" already placed the model, the pipeline must not also be given a device
            device=device_for_pipelines if not torch.cuda.is_available() else None
        )
        text_generation_model_id = llama_omni_model_id
        print(f"LLaMA-Omni2 model ({llama_omni_model_id}) loaded successfully via transformers.")
    except Exception as e:
        warnings.warn(f"Error loading LLaMA-Omni2 model: {e}\nFalling back to {fallback_model_id}")
        print(f"Error loading LLaMA-Omni2 model via transformers: {e}")
        print(f"Falling back to {fallback_model_id}")
# Fall back to GPT-2 if LLaMA-Omni2 fails to load both ways
if text_generation_model_id is None:
    try:
        print(f"Loading fallback text generation model: {fallback_model_id}...")
        text_gen_pipeline_instance = pipeline(
            "text-generation",
            model=fallback_model_id,
            torch_dtype=dtype_for_pipelines,
            device=device_for_pipelines
        )
        text_generation_model_id = fallback_model_id
        print(f"Fallback model ({fallback_model_id}) loaded successfully.")
    except Exception as e:
        print(f"Error loading fallback model ({fallback_model_id}): {e}")
        text_gen_pipeline_instance = None
# --- Core Functions ---
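# The helpers below return plain strings for the Gradio textboxes. transcribe_audio_input
# returns a (display_text, prompt_text) pair; on failure the first element carries the
# error message and the second is left empty so no generation is attempted on it.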
def transcribe_audio_input(audio_filepath):
    if not asr_pipeline_instance:
        return "ASR model not available. Please check startup logs.", ""
    if audio_filepath is None:
        return "No audio file provided for transcription.", ""
    try:
        print(f"Transcribing: {audio_filepath}")
        result = asr_pipeline_instance(audio_filepath, chunk_length_s=30)
        transcribed_text = result["text"]
        print(f"Transcription: '{transcribed_text}'")
        return transcribed_text, transcribed_text
    except Exception as e:
        print(f"Transcription error: {e}")
        return f"Error during transcription: {str(e)}", ""
def generate_text_response(prompt_text):
    # If we have a native LLaMA-Omni2 module, use it
    if llama_omni_native_module is not None:
        if not prompt_text or not prompt_text.strip():
            return "Prompt is empty. Please provide text for generation."
        try:
            print(f"Generating response with native LLaMA-Omni2 for prompt: '{prompt_text[:100]}...'")
            # Using the native module's interface for text generation
            response = llama_omni_native_module.generate(prompt_text, max_length=150)
            print(f"Generated response: '{response}'")
            return response
        except Exception as e:
            print(f"Error using native LLaMA-Omni2 generation: {e}")
            return f"Error during native LLaMA-Omni2 text generation: {str(e)}"

    # Otherwise use the transformers pipeline
    if not text_gen_pipeline_instance:
        return "Text generation model not available. Check logs."
    if not prompt_text or not prompt_text.strip():
        return "Prompt is empty. Please provide text for generation."
    try:
        print(f"Generating response for prompt (first 100 chars): '{prompt_text[:100]}...'")
        # Different generation parameters based on model
        if text_generation_model_id == llama_omni_model_id:
            # Parameters optimized for LLaMA-Omni2
            generated_outputs = text_gen_pipeline_instance(
                prompt_text,
                max_new_tokens=150,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                num_return_sequences=1
            )
        else:
            # Parameters for the fallback model
            generated_outputs = text_gen_pipeline_instance(
                prompt_text,
                max_new_tokens=100,
                num_return_sequences=1
            )
        response_text = generated_outputs[0]["generated_text"]
        print(f"Generated response: '{response_text}'")
        return response_text
    except Exception as e:
        print(f"Text generation error: {e}")
        return f"Error during text generation: {str(e)}"
def combined_pipeline_process(audio_filepath):
    if audio_filepath is None:
        return "No audio input.", "No audio input."
    transcribed_text, _ = transcribe_audio_input(audio_filepath)
    if not asr_pipeline_instance or "Error during transcription" in transcribed_text or not transcribed_text.strip():
        error_msg_for_generation = "Cannot generate response: Transcription failed or was empty."
        if not asr_pipeline_instance:
            error_msg_for_generation = "Cannot generate response: ASR model not loaded."
        return transcribed_text, error_msg_for_generation
    if not text_gen_pipeline_instance and llama_omni_native_module is None:
        return transcribed_text, "Cannot generate response: No text generation model available."
    final_response = generate_text_response(transcribed_text)
    return transcribed_text, final_response
# Determine model status for the UI
if llama_omni_native_module is not None:
    llama_model_status = "Native LLaMA-Omni2 module loaded successfully"
    using_model = "LLaMA-Omni2-0.5B (native modules)"
elif text_generation_model_id == llama_omni_model_id:
    llama_model_status = "LLaMA-Omni2 loaded via transformers"
    using_model = "LLaMA-Omni2-0.5B (via transformers)"
elif text_generation_model_id == fallback_model_id:
    llama_model_status = "Failed to load - Using GPT-2 as fallback"
    using_model = "GPT-2 (fallback model)"
else:
    llama_model_status = "Failed to load any text generation model"
    using_model = "No model available"
# --- Gradio Interface Definition ---
with gr.Blocks(theme=gr.themes.Soft(), title="Whisper + LLaMA-Omni2 Demo") as app_interface:
    gr.Markdown(
        f"""
        # Speech-to-Text and Text Generation Demo
        This application uses **OpenAI Whisper Tiny** for speech recognition and attempts to use **LLaMA-Omni2-0.5B** for text generation.
        If LLaMA-Omni2 cannot be loaded, it falls back to GPT-2.

        **Currently using:** {using_model}

        Upload an audio file to transcribe it. The transcribed text will then be used as a prompt for the text generation model.
        """
    )

    with gr.Tab("Full Pipeline: Audio -> Transcription -> Generation"):
        gr.Markdown("### Step 1: Upload Audio -> Step 2: Transcribe -> Step 3: Generate Text")
        input_audio_pipeline = gr.Audio(type="filepath", label="Upload Your Audio File (.wav, .mp3)")
        submit_button_full = gr.Button("Run Full Process", variant="primary")
        output_transcription_pipeline = gr.Textbox(label="Transcribed Text (from Whisper)", lines=5)
        model_label = f"Generated Text (from {using_model})"
        output_generation_pipeline = gr.Textbox(label=model_label, lines=7)
        submit_button_full.click(
            fn=combined_pipeline_process,
            inputs=[input_audio_pipeline],
            outputs=[output_transcription_pipeline, output_generation_pipeline]
        )
    with gr.Tab("Test Speech-to-Text (Whisper Tiny)"):
        gr.Markdown("### Transcribe audio to text using Whisper Tiny.")
        input_audio_asr = gr.Audio(type="filepath", label="Upload Audio for ASR")
        submit_button_asr = gr.Button("Transcribe Audio", variant="secondary")
        output_transcription_asr = gr.Textbox(label="Transcription Result", lines=10)

        def asr_only_ui(audio_file):
            if audio_file is None:
                return "Please upload an audio file."
            transcription, _ = transcribe_audio_input(audio_file)
            return transcription

        submit_button_asr.click(
            fn=asr_only_ui,
            inputs=[input_audio_asr],
            outputs=[output_transcription_asr]
        )
    with gr.Tab("Test Text Generation"):
        gr.Markdown(f"### Generate text from a prompt using {using_model}.")
        input_text_prompt_gen = gr.Textbox(label="Your Text Prompt", placeholder="Enter text here...", lines=5)
        submit_button_gen = gr.Button("Generate Text", variant="secondary")
        output_generation_gen = gr.Textbox(label="Generated Text Result", lines=10)
        submit_button_gen.click(
            fn=generate_text_response,
            inputs=[input_text_prompt_gen],
            outputs=[output_generation_gen]
        )
    gr.Markdown("---")
    gr.Markdown("### Model Loading Status (at application start):")
    asr_load_status = "Successfully Loaded" if asr_pipeline_instance else "Failed to Load (check console logs)"
    gr.Markdown(f"* **Whisper Model ({whisper_model_id}):** `{asr_load_status}`")
    gr.Markdown(f"* **LLaMA-Omni2 Model ({llama_omni_model_id}):** `{llama_model_status}`")
    if native_llama_omni_available:
        gr.Markdown("* **LLaMA-Omni2 Native Modules:** `Available`")
    else:
        native_error = f": {native_modules_error}" if native_modules_error else ""
        gr.Markdown(f"* **LLaMA-Omni2 Native Modules:** `Not Available{native_error}`")

    if using_model.startswith("GPT-2"):
        gr.Markdown(
            """
            **Note about LLaMA-Omni2-0.5B:** This model has complex dependencies and requires a specific environment.
            The system attempted to load it but fell back to GPT-2. For full functionality with LLaMA-Omni2, you should:
            1. Clone the [LLaMA-Omni2 repository](https://github.com/ictnlp/LLaMA-Omni2)
            2. Install the required dependencies, including CosyVoice 2
            3. Download the Whisper-large-v3 model and the flow-matching model and vocoder of CosyVoice 2
            4. Set up the controller, model worker, and web server as described in the repository

            Note that LLaMA-Omni2 is designed to generate both text and speech responses simultaneously.
            For the full experience with speech synthesis, you need the complete setup.
            """
        )
# --- Launch the Gradio App ---
if __name__ == "__main__":
    print("Launching Gradio demo...")
    try:
        app_interface.launch(share=True)
    except Exception as e:
        print(f"Error launching with share=True: {e}")
        print("Trying to launch without sharing...")
        app_interface.launch()
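# Minimal sketch of how to run this demo locally (assumes the file is saved as app.py,
# the Hugging Face Spaces convention; exact dependency versions are not pinned here,
# and ffmpeg must be installed for the ASR pipeline to decode audio files):
#   pip install gradio torch transformers
#   python app.py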