# llama-omni / app.py
import gradio as gr
import torch
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import warnings
import sys
import subprocess
# Check if we can import LLaMA-Omni2's modules
try_native_modules = True
native_llama_omni_available = False
native_modules_error = None
if try_native_modules:
try:
# Try importing LLaMA-Omni2 specific modules using subprocess to avoid crashing if imports fail
print("Checking for LLaMA-Omni2 native modules...")
module_check_result = subprocess.run(
[sys.executable, "-c", "import llama_omni2; print('LLaMA-Omni2 modules found!')"],
capture_output=True,
text=True
)
if "LLaMA-Omni2 modules found!" in module_check_result.stdout:
print("LLaMA-Omni2 native modules are available!")
native_llama_omni_available = True
else:
print(f"LLaMA-Omni2 native modules not found: {module_check_result.stderr}")
native_modules_error = module_check_result.stderr
except Exception as e:
print(f"Error checking for LLaMA-Omni2 native modules: {e}")
native_modules_error = str(e)
# --- Model Configuration ---
whisper_model_id = "openai/whisper-tiny"
llama_omni_model_id = "ICTNLP/LLaMA-Omni2-0.5B" # Primary model we'll try to load
fallback_model_id = "gpt2" # Fallback if LLaMA-Omni2 fails to load
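# whisper-tiny is the smallest Whisper checkpoint, so ASR stays quick to load even on CPU;
# gpt2 is small and has no extra dependencies, which makes it a safe fallback.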
# --- Device Configuration ---
if torch.cuda.is_available():
device_for_pipelines = 0 # Use the first GPU for Hugging Face pipelines
torch_device = "cuda:0" # PyTorch device string
dtype_for_pipelines = torch.float16
else:
device_for_pipelines = -1 # Use CPU for Hugging Face pipelines
torch_device = "cpu"
dtype_for_pipelines = torch.float32
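# float16 halves memory use and speeds up inference on GPU; CPU inference keeps float32,
# which the standard CPU kernels expect.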
print(f"Using device: {torch_device} for model loading.")
print(f"Pipelines will use device_id: {device_for_pipelines} and dtype: {dtype_for_pipelines}")
# --- Load Speech-to-Text (ASR) Pipeline ---
asr_pipeline_instance = None
try:
print(f"Loading ASR model: {whisper_model_id}...")
asr_pipeline_instance = pipeline(
"automatic-speech-recognition",
model=whisper_model_id,
torch_dtype=dtype_for_pipelines,
device=device_for_pipelines
)
print(f"ASR model ({whisper_model_id}) loaded successfully.")
except Exception as e:
print(f"Error loading ASR model ({whisper_model_id}): {e}")
asr_pipeline_instance = None
# --- Load Text Generation Model ---
text_gen_pipeline_instance = None
text_generation_model_id = None # Will be set to the model that successfully loads
llama_omni_native_module = None # Will hold the native LLaMA-Omni2 module if loaded
# Try native LLaMA-Omni2 module first if available
if native_llama_omni_available:
try:
print("Attempting to load LLaMA-Omni2 using native modules...")
# Import the required modules
import llama_omni2
from llama_omni2.model import Model as LLamaOmniModel
# Load the model
llama_omni_native_module = LLamaOmniModel.from_pretrained(llama_omni_model_id)
text_generation_model_id = llama_omni_model_id
print(f"LLaMA-Omni2 native module loaded successfully: {type(llama_omni_native_module)}")
except Exception as e:
print(f"Error loading native LLaMA-Omni2 module: {e}")
llama_omni_native_module = None
# If native module failed, try loading using transformers
if llama_omni_native_module is None and text_generation_model_id is None:
try:
print(f"Attempting to load LLaMA-Omni2 using transformers: {llama_omni_model_id}...")
        # LLaMA-style checkpoints typically need trust_remote_code plus an explicit dtype and device placement
tokenizer = AutoTokenizer.from_pretrained(llama_omni_model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
llama_omni_model_id,
torch_dtype=dtype_for_pipelines,
trust_remote_code=True,
device_map="auto" if torch.cuda.is_available() else None
)
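        # device_map="auto" lets accelerate place the weights on the available GPU(s),
        # so the pipeline below only receives an explicit device when running on CPU.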
text_gen_pipeline_instance = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
torch_dtype=dtype_for_pipelines,
device=device_for_pipelines if not torch.cuda.is_available() else None
)
text_generation_model_id = llama_omni_model_id
print(f"LLaMA-Omni2 model ({llama_omni_model_id}) loaded successfully via transformers.")
except Exception as e:
warnings.warn(f"Error loading LLaMA-Omni2 model: {e}\nFalling back to {fallback_model_id}")
print(f"Error loading LLaMA-Omni2 model via transformers: {e}")
print(f"Falling back to {fallback_model_id}")
# Fall back to GPT-2 if LLaMA-Omni2 fails to load both ways
if text_generation_model_id is None:
try:
print(f"Loading fallback text generation model: {fallback_model_id}...")
text_gen_pipeline_instance = pipeline(
"text-generation",
model=fallback_model_id,
torch_dtype=dtype_for_pipelines,
device=device_for_pipelines
)
text_generation_model_id = fallback_model_id
print(f"Fallback model ({fallback_model_id}) loaded successfully.")
except Exception as e:
print(f"Error loading fallback model ({fallback_model_id}): {e}")
text_gen_pipeline_instance = None
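# From here on, text generation is backed by exactly one of: the native LLaMA-Omni2 module,
# a transformers pipeline for LLaMA-Omni2, the GPT-2 fallback pipeline, or nothing at all;
# the status section of the UI reports which one was loaded.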
# --- Core Functions ---
def transcribe_audio_input(audio_filepath):
if not asr_pipeline_instance:
return "ASR model not available. Please check startup logs.", ""
if audio_filepath is None:
return "No audio file provided for transcription.", ""
try:
print(f"Transcribing: {audio_filepath}")
result = asr_pipeline_instance(audio_filepath, chunk_length_s=30)
transcribed_text = result["text"]
print(f"Transcription: '{transcribed_text}'")
return transcribed_text, transcribed_text
except Exception as e:
print(f"Transcription error: {e}")
return f"Error during transcription: {str(e)}", ""
def generate_text_response(prompt_text):
# If we have a native LLaMA-Omni2 module, use it
if llama_omni_native_module is not None:
if not prompt_text or not prompt_text.strip():
return "Prompt is empty. Please provide text for generation."
try:
print(f"Generating response with native LLaMA-Omni2 for prompt: '{prompt_text[:100]}...'")
# Using the native module's interface for text generation
response = llama_omni_native_module.generate(prompt_text, max_length=150)
print(f"Generated response: '{response}'")
return response
except Exception as e:
print(f"Error using native LLaMA-Omni2 generation: {e}")
return f"Error during native LLaMA-Omni2 text generation: {str(e)}"
# Otherwise use the transformers pipeline
if not text_gen_pipeline_instance:
return f"Text generation model not available. Check logs."
if not prompt_text or not prompt_text.strip():
return "Prompt is empty. Please provide text for generation."
try:
print(f"Generating response for prompt (first 100 chars): '{prompt_text[:100]}...'")
# Different generation parameters based on model
if text_generation_model_id == llama_omni_model_id:
# Parameters optimized for LLaMA-Omni2
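            # temperature=0.7 with top_p=0.9 gives moderately diverse sampling rather than greedy decoding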
generated_outputs = text_gen_pipeline_instance(
prompt_text,
max_new_tokens=150,
do_sample=True,
temperature=0.7,
top_p=0.9,
num_return_sequences=1
)
else:
# Parameters for fallback model
generated_outputs = text_gen_pipeline_instance(
prompt_text,
max_new_tokens=100,
num_return_sequences=1
)
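        # The text-generation pipeline returns the prompt followed by the continuation
        # in "generated_text"; the combined string is shown as-is.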
response_text = generated_outputs[0]["generated_text"]
print(f"Generated response: '{response_text}'")
return response_text
except Exception as e:
print(f"Text generation error: {e}")
return f"Error during text generation: {str(e)}"
def combined_pipeline_process(audio_filepath):
if audio_filepath is None:
return "No audio input.", "No audio input."
transcribed_text, _ = transcribe_audio_input(audio_filepath)
if not asr_pipeline_instance or "Error during transcription" in transcribed_text or not transcribed_text.strip():
error_msg_for_generation = "Cannot generate response: Transcription failed or was empty."
if not asr_pipeline_instance:
error_msg_for_generation = "Cannot generate response: ASR model not loaded."
return transcribed_text, error_msg_for_generation
if not text_gen_pipeline_instance and llama_omni_native_module is None:
        return transcribed_text, "Cannot generate response: no text generation model is available."
final_response = generate_text_response(transcribed_text)
return transcribed_text, final_response
# Determine model status for UI
if llama_omni_native_module is not None:
llama_model_status = "Native LLaMA-Omni2 module loaded successfully"
using_model = "LLaMA-Omni2-0.5B (native modules)"
elif text_generation_model_id == llama_omni_model_id:
llama_model_status = "LLaMA-Omni2 loaded via transformers"
using_model = "LLaMA-Omni2-0.5B (via transformers)"
elif text_generation_model_id == fallback_model_id:
llama_model_status = "Failed to load - Using GPT-2 as fallback"
using_model = "GPT-2 (fallback model)"
else:
llama_model_status = "Failed to load any text generation model"
using_model = "No model available"
# --- Gradio Interface Definition ---
with gr.Blocks(theme=gr.themes.Soft(), title="Whisper + LLaMA-Omni2 Demo") as app_interface:
gr.Markdown(
f"""
# Speech-to-Text and Text Generation Demo
This application uses **OpenAI Whisper Tiny** for speech recognition and attempts to use **LLaMA-Omni2-0.5B** for text generation.
If LLaMA-Omni2 cannot be loaded, it falls back to GPT-2.
**Currently using:** {using_model}
Upload an audio file to transcribe it. The transcribed text will then be used as a prompt for the text generation model.
"""
)
with gr.Tab("Full Pipeline: Audio -> Transcription -> Generation"):
gr.Markdown("### Step 1: Upload Audio -> Step 2: Transcribe -> Step 3: Generate Text")
input_audio_pipeline = gr.Audio(type="filepath", label="Upload Your Audio File (.wav, .mp3)")
submit_button_full = gr.Button("Run Full Process", variant="primary")
output_transcription_pipeline = gr.Textbox(label="Transcribed Text (from Whisper)", lines=5)
model_label = f"Generated Text (from {using_model})"
output_generation_pipeline = gr.Textbox(label=model_label, lines=7)
submit_button_full.click(
fn=combined_pipeline_process,
inputs=[input_audio_pipeline],
outputs=[output_transcription_pipeline, output_generation_pipeline]
)
with gr.Tab("Test Speech-to-Text (Whisper Tiny)"):
gr.Markdown("### Transcribe audio to text using Whisper Tiny.")
input_audio_asr = gr.Audio(type="filepath", label="Upload Audio for ASR")
submit_button_asr = gr.Button("Transcribe Audio", variant="secondary")
output_transcription_asr = gr.Textbox(label="Transcription Result", lines=10)
def asr_only_ui(audio_file):
if audio_file is None: return "Please upload an audio file."
transcription, _ = transcribe_audio_input(audio_file)
return transcription
submit_button_asr.click(
fn=asr_only_ui,
inputs=[input_audio_asr],
outputs=[output_transcription_asr]
)
    with gr.Tab("Test Text Generation"):
model_name_gen = using_model
gr.Markdown(f"### Generate text from a prompt using {model_name_gen}.")
input_text_prompt_gen = gr.Textbox(label="Your Text Prompt", placeholder="Enter text here...", lines=5)
submit_button_gen = gr.Button("Generate Text", variant="secondary")
output_generation_gen = gr.Textbox(label="Generated Text Result", lines=10)
submit_button_gen.click(
fn=generate_text_response,
inputs=[input_text_prompt_gen],
outputs=[output_generation_gen]
)
gr.Markdown("--- ")
gr.Markdown("### Model Loading Status (at application start):")
asr_load_status = "Successfully Loaded" if asr_pipeline_instance else "Failed to Load (check console logs)"
gr.Markdown(f"* **Whisper Model ({whisper_model_id}):** `{asr_load_status}`")
gr.Markdown(f"* **LLaMA-Omni2 Model ({llama_omni_model_id}):** `{llama_model_status}`")
if native_llama_omni_available:
gr.Markdown("* **LLaMA-Omni2 Native Modules:** `Available`")
else:
native_error = f": {native_modules_error}" if native_modules_error else ""
gr.Markdown(f"* **LLaMA-Omni2 Native Modules:** `Not Available{native_error}`")
if using_model.startswith("GPT-2"):
gr.Markdown(
"""
**Note about LLaMA-Omni2-0.5B:** This model has complex dependencies and requires a specific setup environment.
The system attempted to load it but fell back to GPT-2. For full functionality with LLaMA-Omni2, you should:
1. Clone the [LLaMA-Omni2 repository](https://github.com/ictnlp/LLaMA-Omni2)
2. Install the required dependencies including CosyVoice 2
3. Download the Whisper-large-v3 model and flow-matching model and vocoder of CosyVoice 2
4. Set up the controller, model worker, and web server as described in the repository
Note that LLaMA-Omni2 is designed for generating both text and speech responses simultaneously.
For the full experience with speech synthesis, you need the complete setup.
"""
)
# --- Launch the Gradio App ---
if __name__ == "__main__":
print("Launching Gradio demo...")
try:
app_interface.launch(share=True)
except Exception as e:
print(f"Error launching with share=True: {e}")
print("Trying to launch without sharing...")
app_interface.launch()