# llama-omni / app.py
import gradio as gr
import torch
from transformers import pipeline
# --- Model Configuration ---
whisper_model_id = "openai/whisper-tiny"
# Using gpt2 as a placeholder due to LLaMA-Omni2-0.5B's complex setup needs.
# LLaMA-Omni2-0.5B (ICTNLP/LLaMA-Omni2-0.5B) is a speech-language model
# requiring specific dependencies (e.g., CosyVoice) and often its own serving infrastructure.
# It's not typically loaded via a simple transformers.pipeline for text generation alone.
text_generation_model_id = "gpt2"
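# Hedged note: any Hub text-generation checkpoint can be swapped in by changing the ID
# above. The alternatives below are untested suggestions, not part of the original app:
# text_generation_model_id = "distilgpt2"             # smaller/faster GPT-2 variant
# text_generation_model_id = "EleutherAI/pythia-160m"  # another small causal LM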
# --- Device Configuration ---
if torch.cuda.is_available():
    device_for_pipelines = 0  # Use the first GPU for Hugging Face pipelines
    torch_device = "cuda:0"  # PyTorch device string
    # For models that support it and where precision is not critical, float16 can save memory and speed things up.
    # However, Whisper models are often more robust with float32 for pipeline usage unless memory is very constrained.
    # GPT-2 also generally runs fine on float32 and doesn't strictly need float16 for basic use.
    dtype_for_pipelines = torch.float16  # or torch.float32, depending on model/GPU
else:
    device_for_pipelines = -1  # Use CPU for Hugging Face pipelines
    torch_device = "cpu"
    dtype_for_pipelines = torch.float32
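# Hedged override: if a model misbehaves in half precision on GPU, full precision can be
# forced regardless of device by uncommenting the line below (assumption, not original code):
# dtype_for_pipelines = torch.float32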
print(f"Using device: {torch_device} for model loading.")
print(f"Pipelines will use device_id: {device_for_pipelines} and dtype: {dtype_for_pipelines}")
# --- Load Speech-to-Text (ASR) Pipeline ---
asr_pipeline_instance = None
try:
    print(f"Loading ASR model: {whisper_model_id}...")
    asr_pipeline_instance = pipeline(
        "automatic-speech-recognition",
        model=whisper_model_id,
        torch_dtype=dtype_for_pipelines,  # Using specified dtype
        device=device_for_pipelines,
    )
    print(f"ASR model ({whisper_model_id}) loaded successfully.")
except Exception as e:
    print(f"Error loading ASR model ({whisper_model_id}): {e}")
    asr_pipeline_instance = None  # Explicitly set to None on failure
# --- Load Text Generation Pipeline ---
text_gen_pipeline_instance = None
try:
    print(f"Loading text generation model: {text_generation_model_id}...")
    text_gen_pipeline_instance = pipeline(
        "text-generation",
        model=text_generation_model_id,
        torch_dtype=dtype_for_pipelines,  # Using specified dtype
        device=device_for_pipelines,
    )
    print(f"Text generation model ({text_generation_model_id}) loaded successfully.")
except Exception as e:
    print(f"Error loading text generation model ({text_generation_model_id}): {e}")
    text_gen_pipeline_instance = None  # Explicitly set to None on failure
# --- Core Functions ---
def transcribe_audio_input(audio_filepath):
    if not asr_pipeline_instance:
        return "ASR model not available. Please check startup logs.", ""
    if audio_filepath is None:
        return "No audio file provided for transcription.", ""
    try:
        print(f"Transcribing: {audio_filepath}")
        # chunk_length_s lets the pipeline handle longer audio files robustly
        result = asr_pipeline_instance(audio_filepath, chunk_length_s=30)
        transcribed_text = result["text"]
        print(f"Transcription: '{transcribed_text}'")
        return transcribed_text, transcribed_text  # Return for UI and next step
    except Exception as e:
        print(f"Transcription error: {e}")
        return f"Error during transcription: {str(e)}", ""
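# Minimal usage sketch (assumes a local audio file such as "sample.wav" exists; the
# filename is a hypothetical example):
#   text, _ = transcribe_audio_input("sample.wav")
#   print(text)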
def generate_text_response(prompt_text):
    if not text_gen_pipeline_instance:
        return f"Text generation model ({text_generation_model_id}) not available. Check logs."
    if not prompt_text or not prompt_text.strip():
        return "Prompt is empty. Please provide text for generation."
    try:
        print(f"Generating response for prompt (first 100 chars): '{prompt_text[:100]}...'")
        # max_new_tokens is generally preferred over max_length for more control
        generated_outputs = text_gen_pipeline_instance(prompt_text, max_new_tokens=100, num_return_sequences=1)
        response_text = generated_outputs[0]["generated_text"]
        print(f"Generated response: '{response_text}'")
        return response_text
    except Exception as e:
        print(f"Text generation error: {e}")
        return f"Error during text generation: {str(e)}"
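# Minimal usage sketch (the prompt is arbitrary example text):
#   print(generate_text_response("The quick brown fox"))
# Note: by default the text-generation pipeline's "generated_text" includes the prompt,
# so the returned string echoes the prompt before the newly generated tokens.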
def combined_pipeline_process(audio_filepath):
    if audio_filepath is None:
        return "No audio input.", "No audio input."
    transcribed_text, _ = transcribe_audio_input(audio_filepath)
    if not asr_pipeline_instance or "Error during transcription" in transcribed_text or not transcribed_text.strip():
        error_msg_for_generation = "Cannot generate response: Transcription failed or was empty."
        if not asr_pipeline_instance:
            error_msg_for_generation = "Cannot generate response: ASR model not loaded."
        return transcribed_text, error_msg_for_generation
    if not text_gen_pipeline_instance:
        return transcribed_text, f"Cannot generate response: Text generation model ({text_generation_model_id}) not loaded."
    final_response = generate_text_response(transcribed_text)
    return transcribed_text, final_response
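# Minimal usage sketch (assumes "sample.wav" exists locally; hypothetical filename):
#   transcription, reply = combined_pipeline_process("sample.wav")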
# --- Gradio Interface Definition ---
with gr.Blocks(theme=gr.themes.Soft(), title="Audio to Text Generation") as app_interface:
    gr.Markdown(
        """
# Speech-to-Text and Text Generation Demo
This application uses **OpenAI Whisper Tiny** for speech recognition and **GPT-2** (as a stand-in for more complex models like LLaMA-Omni2) for text generation.
You can upload an audio file, have it transcribed, and then use that transcription as a prompt to generate further text.

**Note on LLaMA-Omni2-0.5B:** The `ICTNLP/LLaMA-Omni2-0.5B` model is a sophisticated speech-language model designed for real-time spoken chat, generating both text and speech. It requires a specific setup environment (including its own speech synthesis like CosyVoice and potentially a dedicated serving mechanism). It's not plug-and-play with a simple `transformers.pipeline` in the same way standard ASR or text-only LLMs are. Therefore, GPT-2 is used here to demonstrate the Gradio app structure.
"""
    )
with gr.Tab("Full Pipeline: Audio -> Transcription -> Generation"):
gr.Markdown("### Step 1: Upload Audio -> Step 2: Transcribe -> Step 3: Generate Text")
input_audio_pipeline = gr.Audio(type="filepath", label="Upload Your Audio File (.wav, .mp3)")
submit_button_full = gr.Button("Run Full Process", variant="primary")
output_transcription_pipeline = gr.Textbox(label="Transcribed Text (from Whisper)", lines=5)
output_generation_pipeline = gr.Textbox(label=f"Generated Text (from {text_generation_model_id})", lines=7)
submit_button_full.click(
fn=combined_pipeline_process,
inputs=[input_audio_pipeline],
outputs=[output_transcription_pipeline, output_generation_pipeline]
)
with gr.Tab("Test Speech-to-Text (Whisper Tiny)"):
gr.Markdown("### Transcribe audio to text using Whisper Tiny.")
input_audio_asr = gr.Audio(type="filepath", label="Upload Audio for ASR")
submit_button_asr = gr.Button("Transcribe Audio", variant="secondary")
output_transcription_asr = gr.Textbox(label="Transcription Result", lines=10)
def asr_only_ui(audio_file):
if audio_file is None: return "Please upload an audio file."
# The transcribe_audio_input returns two values; we only need the first for this UI.
transcription, _ = transcribe_audio_input(audio_file)
return transcription
submit_button_asr.click(
fn=asr_only_ui,
inputs=[input_audio_asr],
outputs=[output_transcription_asr]
)
with gr.Tab(f"Test Text Generation ({text_generation_model_id})"):
gr.Markdown(f"### Generate text from a prompt using {text_generation_model_id}.")
input_text_prompt_gen = gr.Textbox(label="Your Text Prompt", placeholder="Enter text here...", lines=5)
submit_button_gen = gr.Button("Generate Text", variant="secondary")
output_generation_gen = gr.Textbox(label="Generated Text Result", lines=10)
submit_button_gen.click(
fn=generate_text_response,
inputs=[input_text_prompt_gen],
outputs=[output_generation_gen]
)
gr.Markdown("--- ")
gr.Markdown("### Model Loading Status (at application start):")
asr_load_status = "Successfully Loaded" if asr_pipeline_instance else "Failed to Load (check console logs)"
text_gen_load_status = "Successfully Loaded" if text_gen_pipeline_instance else "Failed to Load (check console logs)"
gr.Markdown(f"* **Whisper Model ({whisper_model_id}):** `{asr_load_status}`")
gr.Markdown(f"* **Text Generation Model ({text_generation_model_id}):** `{text_gen_load_status}`")
# --- Launch the Gradio App ---
if __name__ == "__main__":
    print("Attempting to launch Gradio application...")
    # share=True creates a temporary public link; it is not needed on Hugging Face Spaces,
    # which host the app directly. For a public link when running locally (requires
    # internet access and has security implications):
    # app_interface.launch(share=True)
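    # A hedged alternative for containerized deployments (standard Gradio launch kwargs,
    # not used by default in this app):
    # app_interface.launch(server_name="0.0.0.0", server_port=7860)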
    app_interface.launch()
    print("Gradio application launched. Check your browser or console for the URL.")