import gradio as gr
import torch
from transformers import pipeline
import os  # os is imported but not used. Consider removing if not needed.

# --- Model Configuration ---
whisper_model_id = "openai/whisper-tiny"

# Using gpt2 as a placeholder due to LLaMA-Omni2-0.5B's complex setup needs.
# LLaMA-Omni2-0.5B (ICTNLP/LLaMA-Omni2-0.5B) is a speech-language model
# requiring specific dependencies (e.g., CosyVoice) and often its own serving infrastructure.
# It's not typically loaded via a simple transformers.pipeline for text generation alone.
text_generation_model_id = "gpt2"
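# Illustrative aside (not part of the original app): because the placeholder is loaded through the
# generic "text-generation" pipeline below, any pipeline-compatible causal LM checkpoint can be
# swapped in by changing only this identifier, e.g.:
# text_generation_model_id = "distilgpt2"  # example checkpoint; any causal LM usable with pipeline() works
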
# --- Device Configuration ---
if torch.cuda.is_available():
    device_for_pipelines = 0  # Use the first GPU for Hugging Face pipelines
    torch_device = "cuda:0"   # PyTorch device string
    # For models that support it and where precision is not critical, float16 can save memory/speed up.
    # However, Whisper models are often more robust with float32 for pipeline usage unless memory is very constrained.
    # GPT-2 also generally runs fine on float32 and doesn't strictly need float16 for basic use.
    dtype_for_pipelines = torch.float16  # or torch.float32, depending on model/GPU
else:
    device_for_pipelines = -1  # Use CPU for Hugging Face pipelines
    torch_device = "cpu"
    dtype_for_pipelines = torch.float32

print(f"Using device: {torch_device} for model loading.")
print(f"Pipelines will use device_id: {device_for_pipelines} and dtype: {dtype_for_pipelines}")
# --- Load Speech-to-Text (ASR) Pipeline ---
asr_pipeline_instance = None
try:
    print(f"Loading ASR model: {whisper_model_id}...")
    asr_pipeline_instance = pipeline(
        "automatic-speech-recognition",
        model=whisper_model_id,
        torch_dtype=dtype_for_pipelines,  # Using the specified dtype
        device=device_for_pipelines
    )
    print(f"ASR model ({whisper_model_id}) loaded successfully.")
except Exception as e:
    print(f"Error loading ASR model ({whisper_model_id}): {e}")
    asr_pipeline_instance = None  # Explicitly set to None on failure

# --- Load Text Generation Pipeline ---
text_gen_pipeline_instance = None
try:
    print(f"Loading text generation model: {text_generation_model_id}...")
    text_gen_pipeline_instance = pipeline(
        "text-generation",
        model=text_generation_model_id,
        torch_dtype=dtype_for_pipelines,  # Using the specified dtype
        device=device_for_pipelines
    )
    print(f"Text generation model ({text_generation_model_id}) loaded successfully.")
except Exception as e:
    print(f"Error loading text generation model ({text_generation_model_id}): {e}")
    text_gen_pipeline_instance = None  # Explicitly set to None on failure
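# Optional smoke test (commented out; "sample.wav" is a placeholder path, not a file shipped with
# this Space): a quick way to confirm both pipelines respond before wiring up the UI.
# if asr_pipeline_instance:
#     print(asr_pipeline_instance("sample.wav", chunk_length_s=30)["text"])
# if text_gen_pipeline_instance:
#     print(text_gen_pipeline_instance("Hello there,", max_new_tokens=20)[0]["generated_text"])
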
# --- Core Functions ---
def transcribe_audio_input(audio_filepath):
    if not asr_pipeline_instance:
        return "ASR model not available. Please check startup logs.", ""
    if audio_filepath is None:
        return "No audio file provided for transcription.", ""
    try:
        print(f"Transcribing: {audio_filepath}")
        # chunk_length_s lets the pipeline handle longer audio files robustly
        result = asr_pipeline_instance(audio_filepath, chunk_length_s=30)
        transcribed_text = result["text"]
        print(f"Transcription: '{transcribed_text}'")
        return transcribed_text, transcribed_text  # Return for UI and next step
    except Exception as e:
        print(f"Transcription error: {e}")
        return f"Error during transcription: {str(e)}", ""

def generate_text_response(prompt_text):
    if not text_gen_pipeline_instance:
        return f"Text generation model ({text_generation_model_id}) not available. Check logs."
    if not prompt_text or not prompt_text.strip():
        return "Prompt is empty. Please provide text for generation."
    try:
        print(f"Generating response for prompt (first 100 chars): '{prompt_text[:100]}...'")
        # max_new_tokens is generally preferred over max_length for more control
        generated_outputs = text_gen_pipeline_instance(prompt_text, max_new_tokens=100, num_return_sequences=1)
        response_text = generated_outputs[0]["generated_text"]
        print(f"Generated response: '{response_text}'")
        return response_text
    except Exception as e:
        print(f"Text generation error: {e}")
        return f"Error during text generation: {str(e)}"
def combined_pipeline_process(audio_filepath):
    if audio_filepath is None:
        return "No audio input.", "No audio input."

    transcribed_text, _ = transcribe_audio_input(audio_filepath)

    if not asr_pipeline_instance or "Error during transcription" in transcribed_text or not transcribed_text.strip():
        error_msg_for_generation = "Cannot generate response: Transcription failed or was empty."
        if not asr_pipeline_instance:
            error_msg_for_generation = "Cannot generate response: ASR model not loaded."
        return transcribed_text, error_msg_for_generation

    if not text_gen_pipeline_instance:
        return transcribed_text, f"Cannot generate response: Text generation model ({text_generation_model_id}) not loaded."

    final_response = generate_text_response(transcribed_text)
    return transcribed_text, final_response

# --- Gradio Interface Definition ---
with gr.Blocks(theme=gr.themes.Soft(), title="Audio to Text Generation") as app_interface:
    gr.Markdown(
        """
        # Speech-to-Text and Text Generation Demo
        This application uses **OpenAI Whisper Tiny** for speech recognition and **GPT-2** (as a stand-in for more complex models like LLaMA-Omni2) for text generation.
        You can upload an audio file, have it transcribed, and then use that transcription as a prompt to generate further text.

        **Note on LLaMA-Omni2-0.5B:** The `ICTNLP/LLaMA-Omni2-0.5B` model is a sophisticated speech-language model designed for real-time spoken chat, generating both text and speech. It requires a specific setup environment (including its own speech synthesis like CosyVoice and potentially a dedicated serving mechanism). It's not plug-and-play with a simple `transformers.pipeline` in the same way standard ASR or text-only LLMs are. Therefore, GPT-2 is used here to demonstrate the Gradio app structure.
        """
    )

    with gr.Tab("Full Pipeline: Audio -> Transcription -> Generation"):
        gr.Markdown("### Step 1: Upload Audio -> Step 2: Transcribe -> Step 3: Generate Text")
        input_audio_pipeline = gr.Audio(type="filepath", label="Upload Your Audio File (.wav, .mp3)")
        submit_button_full = gr.Button("Run Full Process", variant="primary")
        output_transcription_pipeline = gr.Textbox(label="Transcribed Text (from Whisper)", lines=5)
        output_generation_pipeline = gr.Textbox(label=f"Generated Text (from {text_generation_model_id})", lines=7)
        submit_button_full.click(
            fn=combined_pipeline_process,
            inputs=[input_audio_pipeline],
            outputs=[output_transcription_pipeline, output_generation_pipeline]
        )

    with gr.Tab("Test Speech-to-Text (Whisper Tiny)"):
        gr.Markdown("### Transcribe audio to text using Whisper Tiny.")
        input_audio_asr = gr.Audio(type="filepath", label="Upload Audio for ASR")
        submit_button_asr = gr.Button("Transcribe Audio", variant="secondary")
        output_transcription_asr = gr.Textbox(label="Transcription Result", lines=10)

        def asr_only_ui(audio_file):
            if audio_file is None:
                return "Please upload an audio file."
            # transcribe_audio_input returns two values; only the first is needed for this UI.
            transcription, _ = transcribe_audio_input(audio_file)
            return transcription

        submit_button_asr.click(
            fn=asr_only_ui,
            inputs=[input_audio_asr],
            outputs=[output_transcription_asr]
        )

    with gr.Tab(f"Test Text Generation ({text_generation_model_id})"):
        gr.Markdown(f"### Generate text from a prompt using {text_generation_model_id}.")
        input_text_prompt_gen = gr.Textbox(label="Your Text Prompt", placeholder="Enter text here...", lines=5)
        submit_button_gen = gr.Button("Generate Text", variant="secondary")
        output_generation_gen = gr.Textbox(label="Generated Text Result", lines=10)
        submit_button_gen.click(
            fn=generate_text_response,
            inputs=[input_text_prompt_gen],
            outputs=[output_generation_gen]
        )

    gr.Markdown("---")
    gr.Markdown("### Model Loading Status (at application start):")
    asr_load_status = "Successfully Loaded" if asr_pipeline_instance else "Failed to Load (check console logs)"
    text_gen_load_status = "Successfully Loaded" if text_gen_pipeline_instance else "Failed to Load (check console logs)"
    gr.Markdown(f"* **Whisper Model ({whisper_model_id}):** `{asr_load_status}`")
    gr.Markdown(f"* **Text Generation Model ({text_generation_model_id}):** `{text_gen_load_status}`")

# --- Launch the Gradio App ---
if __name__ == "__main__":
    print("Attempting to launch Gradio application...")
    # On Hugging Face Spaces the app is served automatically, so share=True is unnecessary there.
    # When running locally, share=True creates a temporary public link (requires internet access
    # and has security implications):
    # app_interface.launch(share=True)
    app_interface.launch()
    print("Gradio application launched. Check your browser or console for the URL.")