import gradio as gr import torch from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer import os import warnings import importlib import sys import subprocess # Check if we can import LLaMA-Omni2's modules try_native_modules = True native_llama_omni_available = False native_modules_error = None if try_native_modules: try: # Try importing LLaMA-Omni2 specific modules using subprocess to avoid crashing if imports fail print("Checking for LLaMA-Omni2 native modules...") module_check_result = subprocess.run( [sys.executable, "-c", "import llama_omni2; print('LLaMA-Omni2 modules found!')"], capture_output=True, text=True ) if "LLaMA-Omni2 modules found!" in module_check_result.stdout: print("LLaMA-Omni2 native modules are available!") native_llama_omni_available = True else: print(f"LLaMA-Omni2 native modules not found: {module_check_result.stderr}") native_modules_error = module_check_result.stderr except Exception as e: print(f"Error checking for LLaMA-Omni2 native modules: {e}") native_modules_error = str(e) # --- Model Configuration --- whisper_model_id = "openai/whisper-tiny" llama_omni_model_id = "ICTNLP/LLaMA-Omni2-0.5B" # Primary model we'll try to load fallback_model_id = "gpt2" # Fallback if LLaMA-Omni2 fails to load # --- Device Configuration --- if torch.cuda.is_available(): device_for_pipelines = 0 # Use the first GPU for Hugging Face pipelines torch_device = "cuda:0" # PyTorch device string dtype_for_pipelines = torch.float16 else: device_for_pipelines = -1 # Use CPU for Hugging Face pipelines torch_device = "cpu" dtype_for_pipelines = torch.float32 print(f"Using device: {torch_device} for model loading.") print(f"Pipelines will use device_id: {device_for_pipelines} and dtype: {dtype_for_pipelines}") # --- Load Speech-to-Text (ASR) Pipeline --- asr_pipeline_instance = None try: print(f"Loading ASR model: {whisper_model_id}...") asr_pipeline_instance = pipeline( "automatic-speech-recognition", model=whisper_model_id, torch_dtype=dtype_for_pipelines, device=device_for_pipelines ) print(f"ASR model ({whisper_model_id}) loaded successfully.") except Exception as e: print(f"Error loading ASR model ({whisper_model_id}): {e}") asr_pipeline_instance = None # --- Load Text Generation Model --- text_gen_pipeline_instance = None text_generation_model_id = None # Will be set to the model that successfully loads llama_omni_native_module = None # Will hold the native LLaMA-Omni2 module if loaded # Try native LLaMA-Omni2 module first if available if native_llama_omni_available: try: print("Attempting to load LLaMA-Omni2 using native modules...") # Import the required modules import llama_omni2 from llama_omni2.model import Model as LLamaOmniModel # Load the model llama_omni_native_module = LLamaOmniModel.from_pretrained(llama_omni_model_id) text_generation_model_id = llama_omni_model_id print(f"LLaMA-Omni2 native module loaded successfully: {type(llama_omni_native_module)}") except Exception as e: print(f"Error loading native LLaMA-Omni2 module: {e}") llama_omni_native_module = None # If native module failed, try loading using transformers if llama_omni_native_module is None and text_generation_model_id is None: try: print(f"Attempting to load LLaMA-Omni2 using transformers: {llama_omni_model_id}...") # LLaMA models often require specific loading configurations tokenizer = AutoTokenizer.from_pretrained(llama_omni_model_id, trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained( llama_omni_model_id, torch_dtype=dtype_for_pipelines, trust_remote_code=True, device_map="auto" if torch.cuda.is_available() else None ) text_gen_pipeline_instance = pipeline( "text-generation", model=model, tokenizer=tokenizer, torch_dtype=dtype_for_pipelines, device=device_for_pipelines if not torch.cuda.is_available() else None ) text_generation_model_id = llama_omni_model_id print(f"LLaMA-Omni2 model ({llama_omni_model_id}) loaded successfully via transformers.") except Exception as e: warnings.warn(f"Error loading LLaMA-Omni2 model: {e}\nFalling back to {fallback_model_id}") print(f"Error loading LLaMA-Omni2 model via transformers: {e}") print(f"Falling back to {fallback_model_id}") # Fall back to GPT-2 if LLaMA-Omni2 fails to load both ways if text_generation_model_id is None: try: print(f"Loading fallback text generation model: {fallback_model_id}...") text_gen_pipeline_instance = pipeline( "text-generation", model=fallback_model_id, torch_dtype=dtype_for_pipelines, device=device_for_pipelines ) text_generation_model_id = fallback_model_id print(f"Fallback model ({fallback_model_id}) loaded successfully.") except Exception as e: print(f"Error loading fallback model ({fallback_model_id}): {e}") text_gen_pipeline_instance = None # --- Core Functions --- def transcribe_audio_input(audio_filepath): if not asr_pipeline_instance: return "ASR model not available. Please check startup logs.", "" if audio_filepath is None: return "No audio file provided for transcription.", "" try: print(f"Transcribing: {audio_filepath}") result = asr_pipeline_instance(audio_filepath, chunk_length_s=30) transcribed_text = result["text"] print(f"Transcription: '{transcribed_text}'") return transcribed_text, transcribed_text except Exception as e: print(f"Transcription error: {e}") return f"Error during transcription: {str(e)}", "" def generate_text_response(prompt_text): # If we have a native LLaMA-Omni2 module, use it if llama_omni_native_module is not None: if not prompt_text or not prompt_text.strip(): return "Prompt is empty. Please provide text for generation." try: print(f"Generating response with native LLaMA-Omni2 for prompt: '{prompt_text[:100]}...'") # Using the native module's interface for text generation response = llama_omni_native_module.generate(prompt_text, max_length=150) print(f"Generated response: '{response}'") return response except Exception as e: print(f"Error using native LLaMA-Omni2 generation: {e}") return f"Error during native LLaMA-Omni2 text generation: {str(e)}" # Otherwise use the transformers pipeline if not text_gen_pipeline_instance: return f"Text generation model not available. Check logs." if not prompt_text or not prompt_text.strip(): return "Prompt is empty. Please provide text for generation." try: print(f"Generating response for prompt (first 100 chars): '{prompt_text[:100]}...'") # Different generation parameters based on model if text_generation_model_id == llama_omni_model_id: # Parameters optimized for LLaMA-Omni2 generated_outputs = text_gen_pipeline_instance( prompt_text, max_new_tokens=150, do_sample=True, temperature=0.7, top_p=0.9, num_return_sequences=1 ) else: # Parameters for fallback model generated_outputs = text_gen_pipeline_instance( prompt_text, max_new_tokens=100, num_return_sequences=1 ) response_text = generated_outputs[0]["generated_text"] print(f"Generated response: '{response_text}'") return response_text except Exception as e: print(f"Text generation error: {e}") return f"Error during text generation: {str(e)}" def combined_pipeline_process(audio_filepath): if audio_filepath is None: return "No audio input.", "No audio input." transcribed_text, _ = transcribe_audio_input(audio_filepath) if not asr_pipeline_instance or "Error during transcription" in transcribed_text or not transcribed_text.strip(): error_msg_for_generation = "Cannot generate response: Transcription failed or was empty." if not asr_pipeline_instance: error_msg_for_generation = "Cannot generate response: ASR model not loaded." return transcribed_text, error_msg_for_generation if not text_gen_pipeline_instance and llama_omni_native_module is None: return transcribed_text, f"Cannot generate response: No text generation model available." final_response = generate_text_response(transcribed_text) return transcribed_text, final_response # Determine model status for UI if llama_omni_native_module is not None: llama_model_status = "Native LLaMA-Omni2 module loaded successfully" using_model = "LLaMA-Omni2-0.5B (native modules)" elif text_generation_model_id == llama_omni_model_id: llama_model_status = "LLaMA-Omni2 loaded via transformers" using_model = "LLaMA-Omni2-0.5B (via transformers)" elif text_generation_model_id == fallback_model_id: llama_model_status = "Failed to load - Using GPT-2 as fallback" using_model = "GPT-2 (fallback model)" else: llama_model_status = "Failed to load any text generation model" using_model = "No model available" # --- Gradio Interface Definition --- with gr.Blocks(theme=gr.themes.Soft(), title="Whisper + LLaMA-Omni2 Demo") as app_interface: gr.Markdown( f""" # Speech-to-Text and Text Generation Demo This application uses **OpenAI Whisper Tiny** for speech recognition and attempts to use **LLaMA-Omni2-0.5B** for text generation. If LLaMA-Omni2 cannot be loaded, it falls back to GPT-2. **Currently using:** {using_model} Upload an audio file to transcribe it. The transcribed text will then be used as a prompt for the text generation model. """ ) with gr.Tab("Full Pipeline: Audio -> Transcription -> Generation"): gr.Markdown("### Step 1: Upload Audio -> Step 2: Transcribe -> Step 3: Generate Text") input_audio_pipeline = gr.Audio(type="filepath", label="Upload Your Audio File (.wav, .mp3)") submit_button_full = gr.Button("Run Full Process", variant="primary") output_transcription_pipeline = gr.Textbox(label="Transcribed Text (from Whisper)", lines=5) model_label = f"Generated Text (from {using_model})" output_generation_pipeline = gr.Textbox(label=model_label, lines=7) submit_button_full.click( fn=combined_pipeline_process, inputs=[input_audio_pipeline], outputs=[output_transcription_pipeline, output_generation_pipeline] ) with gr.Tab("Test Speech-to-Text (Whisper Tiny)"): gr.Markdown("### Transcribe audio to text using Whisper Tiny.") input_audio_asr = gr.Audio(type="filepath", label="Upload Audio for ASR") submit_button_asr = gr.Button("Transcribe Audio", variant="secondary") output_transcription_asr = gr.Textbox(label="Transcription Result", lines=10) def asr_only_ui(audio_file): if audio_file is None: return "Please upload an audio file." transcription, _ = transcribe_audio_input(audio_file) return transcription submit_button_asr.click( fn=asr_only_ui, inputs=[input_audio_asr], outputs=[output_transcription_asr] ) with gr.Tab(f"Test Text Generation"): model_name_gen = using_model gr.Markdown(f"### Generate text from a prompt using {model_name_gen}.") input_text_prompt_gen = gr.Textbox(label="Your Text Prompt", placeholder="Enter text here...", lines=5) submit_button_gen = gr.Button("Generate Text", variant="secondary") output_generation_gen = gr.Textbox(label="Generated Text Result", lines=10) submit_button_gen.click( fn=generate_text_response, inputs=[input_text_prompt_gen], outputs=[output_generation_gen] ) gr.Markdown("--- ") gr.Markdown("### Model Loading Status (at application start):") asr_load_status = "Successfully Loaded" if asr_pipeline_instance else "Failed to Load (check console logs)" gr.Markdown(f"* **Whisper Model ({whisper_model_id}):** `{asr_load_status}`") gr.Markdown(f"* **LLaMA-Omni2 Model ({llama_omni_model_id}):** `{llama_model_status}`") if native_llama_omni_available: gr.Markdown("* **LLaMA-Omni2 Native Modules:** `Available`") else: native_error = f": {native_modules_error}" if native_modules_error else "" gr.Markdown(f"* **LLaMA-Omni2 Native Modules:** `Not Available{native_error}`") if using_model.startswith("GPT-2"): gr.Markdown( """ **Note about LLaMA-Omni2-0.5B:** This model has complex dependencies and requires a specific setup environment. The system attempted to load it but fell back to GPT-2. For full functionality with LLaMA-Omni2, you should: 1. Clone the [LLaMA-Omni2 repository](https://github.com/ictnlp/LLaMA-Omni2) 2. Install the required dependencies including CosyVoice 2 3. Download the Whisper-large-v3 model and flow-matching model and vocoder of CosyVoice 2 4. Set up the controller, model worker, and web server as described in the repository Note that LLaMA-Omni2 is designed for generating both text and speech responses simultaneously. For the full experience with speech synthesis, you need the complete setup. """ ) # --- Launch the Gradio App --- if __name__ == "__main__": print("Launching Gradio demo...") try: app_interface.launch(share=True) except Exception as e: print(f"Error launching with share=True: {e}") print("Trying to launch without sharing...") app_interface.launch()