import gradio as gr
import torch
from transformers import pipeline

# --- Model Configuration ---
whisper_model_id = "openai/whisper-tiny"

# Using gpt2 as a placeholder because of LLaMA-Omni2-0.5B's complex setup needs.
# LLaMA-Omni2-0.5B (ICTNLP/LLaMA-Omni2-0.5B) is a speech-language model requiring
# specific dependencies (e.g., CosyVoice) and often its own serving infrastructure;
# it is not typically loaded via a simple transformers.pipeline for text generation alone.
text_generation_model_id = "gpt2"

# --- Device Configuration ---
if torch.cuda.is_available():
    device_for_pipelines = 0  # Use the first GPU for Hugging Face pipelines
    torch_device = "cuda:0"  # PyTorch device string
    # float16 saves memory and can speed things up where precision is not critical.
    # Whisper is often more robust in float32 for pipeline usage unless memory is tight,
    # and GPT-2 also runs fine in float32, so adjust as needed.
    dtype_for_pipelines = torch.float16  # or torch.float32, depending on model/GPU
else:
    device_for_pipelines = -1  # Use CPU for Hugging Face pipelines
    torch_device = "cpu"
    dtype_for_pipelines = torch.float32

print(f"Using device: {torch_device} for model loading.")
print(f"Pipelines will use device_id: {device_for_pipelines} and dtype: {dtype_for_pipelines}")

# --- Load Speech-to-Text (ASR) Pipeline ---
asr_pipeline_instance = None
try:
    print(f"Loading ASR model: {whisper_model_id}...")
    asr_pipeline_instance = pipeline(
        "automatic-speech-recognition",
        model=whisper_model_id,
        torch_dtype=dtype_for_pipelines,
        device=device_for_pipelines
    )
    print(f"ASR model ({whisper_model_id}) loaded successfully.")
except Exception as e:
    print(f"Error loading ASR model ({whisper_model_id}): {e}")
    asr_pipeline_instance = None  # Explicitly set to None on failure

# --- Load Text Generation Pipeline ---
text_gen_pipeline_instance = None
try:
    print(f"Loading text generation model: {text_generation_model_id}...")
    text_gen_pipeline_instance = pipeline(
        "text-generation",
        model=text_generation_model_id,
        torch_dtype=dtype_for_pipelines,
        device=device_for_pipelines
    )
    print(f"Text generation model ({text_generation_model_id}) loaded successfully.")
except Exception as e:
    print(f"Error loading text generation model ({text_generation_model_id}): {e}")
    text_gen_pipeline_instance = None  # Explicitly set to None on failure
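# The comments above flag gpt2 as a placeholder. As a minimal sketch, any text-generation
# model supported by transformers.pipeline could be dropped in instead by changing
# text_generation_model_id; the model id below is purely illustrative (an assumption, not
# something this app ships with), and instruct-tuned models usually respond better when the
# prompt follows their chat template:
#
#   text_generation_model_id = "Qwen/Qwen2.5-0.5B-Instruct"  # illustrative small instruct model
#   text_gen_pipeline_instance = pipeline(
#       "text-generation",
#       model=text_generation_model_id,
#       torch_dtype=dtype_for_pipelines,
#       device=device_for_pipelines
#   )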
# --- Core Functions ---
def transcribe_audio_input(audio_filepath):
    if not asr_pipeline_instance:
        return "ASR model not available. Please check startup logs.", ""
    if audio_filepath is None:
        return "No audio file provided for transcription.", ""
    try:
        print(f"Transcribing: {audio_filepath}")
        # chunk_length_s helps handle longer audio files robustly
        result = asr_pipeline_instance(audio_filepath, chunk_length_s=30)
        transcribed_text = result["text"]
        print(f"Transcription: '{transcribed_text}'")
        return transcribed_text, transcribed_text  # Return for UI and next step
    except Exception as e:
        print(f"Transcription error: {e}")
        return f"Error during transcription: {str(e)}", ""

def generate_text_response(prompt_text):
    if not text_gen_pipeline_instance:
        return f"Text generation model ({text_generation_model_id}) not available. Check logs."
    if not prompt_text or not prompt_text.strip():
        return "Prompt is empty. Please provide text for generation."
    try:
        print(f"Generating response for prompt (first 100 chars): '{prompt_text[:100]}...'")
        # max_new_tokens is generally preferred over max_length for more control
        generated_outputs = text_gen_pipeline_instance(
            prompt_text, max_new_tokens=100, num_return_sequences=1
        )
        response_text = generated_outputs[0]["generated_text"]
        print(f"Generated response: '{response_text}'")
        return response_text
    except Exception as e:
        print(f"Text generation error: {e}")
        return f"Error during text generation: {str(e)}"

def combined_pipeline_process(audio_filepath):
    if audio_filepath is None:
        return "No audio input.", "No audio input."

    transcribed_text, _ = transcribe_audio_input(audio_filepath)

    if not asr_pipeline_instance or "Error during transcription" in transcribed_text or not transcribed_text.strip():
        error_msg_for_generation = "Cannot generate response: Transcription failed or was empty."
        if not asr_pipeline_instance:
            error_msg_for_generation = "Cannot generate response: ASR model not loaded."
        return transcribed_text, error_msg_for_generation

    if not text_gen_pipeline_instance:
        return transcribed_text, f"Cannot generate response: Text generation model ({text_generation_model_id}) not loaded."

    final_response = generate_text_response(transcribed_text)
    return transcribed_text, final_response

# --- Gradio Interface Definition ---
with gr.Blocks(theme=gr.themes.Soft(), title="Audio to Text Generation") as app_interface:
    gr.Markdown(
        """
        # Speech-to-Text and Text Generation Demo
        This application uses **OpenAI Whisper Tiny** for speech recognition and **GPT-2**
        (as a stand-in for more complex models like LLaMA-Omni2) for text generation.
        You can upload an audio file, have it transcribed, and then use that transcription
        as a prompt to generate further text.

        **Note on LLaMA-Omni2-0.5B:** The `ICTNLP/LLaMA-Omni2-0.5B` model is a sophisticated
        speech-language model designed for real-time spoken chat, generating both text and speech.
        It requires a specific setup environment (including its own speech synthesis like CosyVoice
        and potentially a dedicated serving mechanism). It's not plug-and-play with a simple
        `transformers.pipeline` in the same way standard ASR or text-only LLMs are.
        Therefore, GPT-2 is used here to demonstrate the Gradio app structure.
        """
    )

    with gr.Tab("Full Pipeline: Audio -> Transcription -> Generation"):
        gr.Markdown("### Step 1: Upload Audio -> Step 2: Transcribe -> Step 3: Generate Text")
        input_audio_pipeline = gr.Audio(type="filepath", label="Upload Your Audio File (.wav, .mp3)")
        submit_button_full = gr.Button("Run Full Process", variant="primary")
        output_transcription_pipeline = gr.Textbox(label="Transcribed Text (from Whisper)", lines=5)
        output_generation_pipeline = gr.Textbox(label=f"Generated Text (from {text_generation_model_id})", lines=7)

        submit_button_full.click(
            fn=combined_pipeline_process,
            inputs=[input_audio_pipeline],
            outputs=[output_transcription_pipeline, output_generation_pipeline]
        )

    with gr.Tab("Test Speech-to-Text (Whisper Tiny)"):
        gr.Markdown("### Transcribe audio to text using Whisper Tiny.")
        input_audio_asr = gr.Audio(type="filepath", label="Upload Audio for ASR")
        submit_button_asr = gr.Button("Transcribe Audio", variant="secondary")
        output_transcription_asr = gr.Textbox(label="Transcription Result", lines=10)

        def asr_only_ui(audio_file):
            if audio_file is None:
                return "Please upload an audio file."
            # transcribe_audio_input returns two values; we only need the first for this UI.
            transcription, _ = transcribe_audio_input(audio_file)
            return transcription

        submit_button_asr.click(
            fn=asr_only_ui,
            inputs=[input_audio_asr],
            outputs=[output_transcription_asr]
        )

    with gr.Tab(f"Test Text Generation ({text_generation_model_id})"):
        gr.Markdown(f"### Generate text from a prompt using {text_generation_model_id}.")
        input_text_prompt_gen = gr.Textbox(label="Your Text Prompt", placeholder="Enter text here...", lines=5)
        submit_button_gen = gr.Button("Generate Text", variant="secondary")
        output_generation_gen = gr.Textbox(label="Generated Text Result", lines=10)

        submit_button_gen.click(
            fn=generate_text_response,
            inputs=[input_text_prompt_gen],
            outputs=[output_generation_gen]
        )

    gr.Markdown("---")
    gr.Markdown("### Model Loading Status (at application start):")
    asr_load_status = "Successfully Loaded" if asr_pipeline_instance else "Failed to Load (check console logs)"
    text_gen_load_status = "Successfully Loaded" if text_gen_pipeline_instance else "Failed to Load (check console logs)"
    gr.Markdown(f"* **Whisper Model ({whisper_model_id}):** `{asr_load_status}`")
    gr.Markdown(f"* **Text Generation Model ({text_generation_model_id}):** `{text_gen_load_status}`")

# --- Launch the Gradio App ---
if __name__ == "__main__":
    print("Launching Gradio application... check your browser or console for the URL.")
    # share=True creates a temporary public link; it is not needed on Hugging Face Spaces,
    # where the app is already hosted, but can be handy for local runs (requires internet
    # access and has security implications):
    # app_interface.launch(share=True)
    app_interface.launch()
    # launch() blocks until the server is shut down, so this line only prints afterwards.
    print("Gradio application stopped.")
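# --- Runtime notes ---
# A minimal environment sketch of the assumed dependencies (versions left unpinned):
#   pip install gradio torch transformers
# The ASR pipeline decodes uploaded .wav/.mp3 files with ffmpeg, so ffmpeg is assumed to be
# available on the system PATH (on Hugging Face Spaces it is typically added via packages.txt).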