# gradio_app.py
import os

import gradio as gr
import numpy as np
import torch
import librosa
import google.generativeai as genai
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer, AutoModel  # AutoModel, as suggested by the ASR model card

# --- Configuration ---
ASR_MODEL_NAME = "ai4bharat/indic-conformer-600m-multilingual"
TARGET_SAMPLE_RATE = 16000  # The ASR model expects 16 kHz audio
TTS_MODEL_NAME = "ai4bharat/indic-parler-tts"
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")  # Set via environment; never hardcode API keys
GEMINI_MODEL_NAME_GRADIO = "gemini-1.5-flash-latest"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ParlerTTS and Gemini manage their own precision; the ASR model's custom code
# may do the same internally.

# --- Global Model Variables ---
asr_model_gradio = None  # AutoModel instance (custom code via trust_remote_code)
gemini_model_instance_gradio = None
tts_model_gradio = None
tts_tokenizer_gradio = None  # Tokenizer for ParlerTTS

# --- Model Loading & API Configuration ---
def load_all_resources_gradio():
    global asr_model_gradio, tts_model_gradio, tts_tokenizer_gradio, gemini_model_instance_gradio
    print(f"Gradio: Loading resources. ASR will be on device: {DEVICE}")

    if asr_model_gradio is None:
        print(f"Gradio: Loading ASR model: {ASR_MODEL_NAME} using AutoModel")
        try:
            # Load via AutoModel with trust_remote_code, as the model card implies
            asr_model_gradio = AutoModel.from_pretrained(ASR_MODEL_NAME, trust_remote_code=True)
            asr_model_gradio.to(DEVICE)
            # The custom model code may manage its own precision; on CUDA we
            # additionally try .half() if the model exposes it.
            if DEVICE == "cuda" and hasattr(asr_model_gradio, "half"):
                print("Gradio: Applying .half() to ASR model.")
                asr_model_gradio.half()
            asr_model_gradio.eval()
            print(f"Gradio: ASR model ({ASR_MODEL_NAME}) loaded using AutoModel.")
        except Exception as e:
            print(f"Gradio: Failed to load ASR model {ASR_MODEL_NAME} using AutoModel: {e}")
            import traceback
            traceback.print_exc()
            asr_model_gradio = None

    if tts_model_gradio is None:
        print(f"Gradio: Loading IndicParler-TTS model: {TTS_MODEL_NAME}")
        # ParlerTTS needs its own tokenizer; the ASR model's custom code handles
        # its own tokenization/processing internally.
        tts_tokenizer_gradio = AutoTokenizer.from_pretrained(TTS_MODEL_NAME, trust_remote_code=True)
        tts_model_gradio = ParlerTTSForConditionalGeneration.from_pretrained(
            TTS_MODEL_NAME, trust_remote_code=True
        ).to(DEVICE)
        print("Gradio: IndicParler-TTS model loaded.")

    if gemini_model_instance_gradio is None:
        if not GEMINI_API_KEY:
            print("Gradio: GEMINI_API_KEY not found. LLM functionality via Gemini will be limited.")
        else:
            try:
                genai.configure(api_key=GEMINI_API_KEY)
                gemini_model_instance_gradio = genai.GenerativeModel(GEMINI_MODEL_NAME_GRADIO)
                print(f"Gradio: Gemini API configured with model: {GEMINI_MODEL_NAME_GRADIO}")
            except Exception as e:
                print(f"Gradio: Failed to configure Gemini API: {e}")
                gemini_model_instance_gradio = None

    print("Gradio: All resources loaded (or attempted).")
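# Optional warm-up (sketch, not part of the original pipeline): running one
# dummy forward pass after loading can pre-initialize CUDA kernels so the first
# real request is not unusually slow. The call reuses the model-card signature
# `model(wav, lang_id, strategy)` used below; the one-second silent input and
# this helper's name are illustrative assumptions. Call it after
# load_all_resources_gradio() if desired.
def warmup_asr_gradio():
    if asr_model_gradio is None:
        return
    try:
        # 1 s of silence at 16 kHz, batch dimension included; dtype handling is
        # left to the model's custom code.
        dummy_wav = torch.zeros(1, TARGET_SAMPLE_RATE, device=DEVICE)
        with torch.no_grad():
            asr_model_gradio(dummy_wav, "hi", "ctc")
        print("Gradio: ASR warm-up pass completed.")
    except Exception as e:
        print(f"Gradio: ASR warm-up failed (non-fatal): {e}")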
# --- Helper Functions ---
def transcribe_audio_gradio(audio_input_tuple):
    if asr_model_gradio is None:
        return f"Error: ASR model ({ASR_MODEL_NAME}) not loaded."
    if audio_input_tuple is None:
        print("Gradio: No audio provided to transcribe_audio_gradio.")
        return "No audio provided."

    sample_rate, audio_numpy = audio_input_tuple
    if audio_numpy is None or audio_numpy.size == 0:
        print("Gradio: Audio numpy array is empty.")
        return "Empty audio received."

    # Downmix stereo to mono; Gradio may deliver (channels, samples) or (samples, channels).
    if audio_numpy.ndim == 2:
        if audio_numpy.shape[0] == 2:
            audio_numpy = librosa.to_mono(audio_numpy)
        elif audio_numpy.shape[1] == 2:
            audio_numpy = np.mean(audio_numpy, axis=1)

    # Normalize to float32 in [-1, 1], the common expectation for ASR models.
    if audio_numpy.dtype != np.float32:
        if np.issubdtype(audio_numpy.dtype, np.integer):
            audio_numpy = audio_numpy.astype(np.float32) / np.iinfo(audio_numpy.dtype).max
        else:
            audio_numpy = audio_numpy.astype(np.float32)

    # Resample to the 16 kHz the model expects.
    if sample_rate != TARGET_SAMPLE_RATE:
        print(f"Gradio: Resampling audio from {sample_rate} Hz to {TARGET_SAMPLE_RATE} Hz.")
        try:
            audio_numpy = librosa.resample(y=audio_numpy, orig_sr=sample_rate, target_sr=TARGET_SAMPLE_RATE)
        except Exception as e:
            print(f"Gradio: Error during resampling: {e}")
            return f"Error during audio resampling: {str(e)}"

    try:
        print(f"Gradio: Preparing to transcribe with {ASR_MODEL_NAME}. Input audio shape: {audio_numpy.shape}")

        # The model card example `model(wav, "hi", "ctc")` takes a waveform
        # tensor (torchaudio.load returns one), so convert the numpy array to a
        # 1-D tensor on the right device, then add a batch dimension.
        if audio_numpy.ndim > 1:
            audio_numpy = audio_numpy.squeeze()  # drop singleton dimensions
        if audio_numpy.ndim > 1:
            print(f"Gradio: Audio numpy array for ASR has unexpected dimensions after processing: {audio_numpy.shape}")
            return "Error: Audio processing resulted in unexpected dimensions."

        wav_tensor = torch.from_numpy(audio_numpy).to(DEVICE)
        if wav_tensor.ndim == 1:
            wav_tensor = wav_tensor.unsqueeze(0)  # [1, num_samples]

        print(f"Gradio: Transcribing with {ASR_MODEL_NAME} using CTC. Input tensor shape: {wav_tensor.shape}")

        # The model card shows `model(wav, "hi", "ctc")`: a language ID plus a
        # decoding strategy ("ctc" here; "rnnt" is also mentioned). The model is
        # trained on 22 languages, but the card does not document language
        # auto-detection, so Hindi is hardcoded for now; making it a UI option
        # would be a natural improvement.
        with torch.no_grad():  # inference only
            transcription_result = asr_model_gradio(wav_tensor, "hi", "ctc")

        # The output format is not documented precisely: it may be the
        # transcribed string itself or a list of strings (e.g., when batched).
        if isinstance(transcription_result, list) and len(transcription_result) > 0:
            transcribed_text = transcription_result[0]  # first result for non-batched input
        elif isinstance(transcription_result, str):
            transcribed_text = transcription_result
        else:
            print(f"Gradio: Unexpected ASR result format: {type(transcription_result)}, value: {transcription_result}")
            transcribed_text = "ASR result format not recognized."

        transcribed_text = transcribed_text.strip()
        print(f"Gradio: Transcription ({ASR_MODEL_NAME}, CTC): {transcribed_text}")
        return transcribed_text if transcribed_text else "Transcription resulted in empty text."
    except Exception as e:
        print(f"Gradio: Error during {ASR_MODEL_NAME} transcription (AutoModel callable): {e}")
        import traceback
        traceback.print_exc()
        return f"Error during transcription ({ASR_MODEL_NAME}): {str(e)}"
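# Decoding-strategy sketch (not wired into the pipeline): the function above
# hardcodes Hindi ("hi") and CTC decoding. Since the model card also mentions
# an "rnnt" strategy, a thin wrapper like this could expose both knobs. The
# helper's name is an assumption, and "rnnt" availability per language is not
# verified here.
def transcribe_tensor_gradio(wav_tensor, lang_id: str = "hi", strategy: str = "ctc"):
    if asr_model_gradio is None:
        return f"Error: ASR model ({ASR_MODEL_NAME}) not loaded."
    with torch.no_grad():
        # Same callable interface as the model-card example `model(wav, "hi", "ctc")`
        return asr_model_gradio(wav_tensor, lang_id, strategy)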
def generate_gemini_response_gradio(text_input: str):
    if not gemini_model_instance_gradio:
        return "Error: Gemini LLM not configured or API key missing."
    # Skip the LLM when upstream ASR produced an error or sentinel string.
    asr_sentinels = (
        "No audio provided", "Transcription resulted in empty text",
        "Empty audio received", "ASR result format not recognized",
    )
    if (not isinstance(text_input, str) or not text_input.strip()
            or text_input.startswith("Error:")
            or any(s in text_input for s in asr_sentinels)):
        print(f"Gradio: Invalid input to Gemini: '{text_input}'. Skipping LLM response.")
        return "LLM (Gemini) skipped due to transcription issue or no input."
    try:
        print(f"Gradio: Sending to Gemini: '{text_input}'")
        full_prompt = f"User: {text_input}\nAssistant:"
        response = gemini_model_instance_gradio.generate_content(full_prompt)
        response_text = ""
        if response.candidates and response.candidates[0].content.parts:
            response_text = response.candidates[0].content.parts[0].text.strip()
        else:
            feedback_info = ""
            if hasattr(response, "prompt_feedback") and response.prompt_feedback:
                feedback_info = f" Feedback: {response.prompt_feedback}"
            print(f"Gradio: Gemini response did not contain expected content.{feedback_info}")
            response_text = f"I'm sorry, I couldn't generate a response for that (Gemini).{feedback_info}"
        print(f"Gradio: Gemini LLM Response: {response_text}")
        return response_text if response_text else "Gemini LLM generated an empty response."
    except Exception as e:
        print(f"Gradio: Error during Gemini LLM generation: {e}")
        import traceback
        traceback.print_exc()
        return f"Error during Gemini LLM generation: {str(e)}"
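# Streaming sketch (not wired into the pipeline): google-generativeai also
# supports incremental output via `stream=True`, which could feed a
# token-by-token UI instead of waiting for the full reply. The helper's name
# is an assumption; error handling is minimal by design.
def generate_gemini_response_streaming_gradio(text_input: str):
    if not gemini_model_instance_gradio:
        return "Error: Gemini LLM not configured or API key missing."
    try:
        response = gemini_model_instance_gradio.generate_content(
            f"User: {text_input}\nAssistant:", stream=True
        )
        # chunk.text can raise if a chunk was blocked by safety filters,
        # which the except below catches.
        return "".join(chunk.text for chunk in response).strip()
    except Exception as e:
        return f"Error during Gemini LLM streaming: {str(e)}"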
def synthesize_speech_gradio(text_input: str, description: str = "A clear, female voice speaking in English."):
    if tts_model_gradio is None or tts_tokenizer_gradio is None:
        return "Error: TTS model or its tokenizer not loaded."
    # Skip synthesis when upstream stages produced an error or sentinel string.
    # Note: "LLM (Gemini) skipped" is included because that is the string the
    # Gemini function actually returns; checking only "LLM skipped" would miss it.
    llm_sentinels = (
        "LLM skipped", "LLM (Gemini) skipped", "generated an empty response",
        "not configured", "ASR result format not recognized",
    )
    if (not isinstance(text_input, str) or not text_input.strip()
            or text_input.startswith("Error:")
            or any(s in text_input for s in llm_sentinels)):
        print(f"Gradio: Invalid input to TTS: '{text_input}'. Skipping synthesis.")
        return "TTS skipped due to LLM issue or no input."
    try:
        print(f"Gradio: Synthesizing speech for: '{text_input}'")
        # ParlerTTS takes two token streams: a voice *description* and the *prompt* text.
        description_tokenized = tts_tokenizer_gradio(
            description, return_tensors="pt", padding=True, truncation=True, max_length=128
        )
        description_ids = description_tokenized.input_ids.to(DEVICE)
        description_attention_mask = description_tokenized.attention_mask.to(DEVICE)

        prompt_tokenized = tts_tokenizer_gradio(
            text_input, return_tensors="pt", padding=True, truncation=True, max_length=512
        )
        prompt_ids = prompt_tokenized.input_ids.to(DEVICE)
        if prompt_ids.shape[-1] == 0:  # tokenized prompt is empty
            print(f"Gradio: Tokenized prompt for TTS is empty. Text was: '{text_input}'. Skipping synthesis.")
            return "TTS skipped: Input text resulted in empty tokens."

        generation = tts_model_gradio.generate(
            input_ids=description_ids,
            attention_mask=description_attention_mask,
            prompt_input_ids=prompt_ids,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
        ).cpu().numpy().squeeze()
        sampling_rate = tts_model_gradio.config.sampling_rate
        print(f"Gradio: Speech synthesized. Array shape: {generation.shape}, Sample rate: {sampling_rate}")
        return (sampling_rate, generation)
    except Exception as e:
        print(f"Gradio: Error during speech synthesis: {e}")
        import traceback
        traceback.print_exc()
        if "You need to specify either `text` or `text_target`" in str(e):
            return "Error in TTS: Model requires 'text' or 'text_target'. Input might be too short or problematic."
        return f"Error during speech synthesis: {str(e)}"
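# Offline debug sketch: Parler-TTS examples typically write generated audio to
# disk with soundfile, which is handy for testing the TTS stage in isolation.
# Assumes `soundfile` is installed; the prompt text is illustrative. Kept
# commented out so it never runs as part of the app.
#
#     import soundfile as sf
#     result = synthesize_speech_gradio("Namaste! How can I help you today?")
#     if isinstance(result, tuple):
#         sr, wav = result
#         sf.write("tts_debug.wav", wav, sr)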
# --- Gradio Interface Definition ---
load_all_resources_gradio()

def full_pipeline_gradio(audio_input):
    transcribed_text_output = transcribe_audio_gradio(audio_input)
    print(f"DEBUG full_pipeline_gradio - Step 1 (Transcription): '{transcribed_text_output}' (type: {type(transcribed_text_output)})")

    llm_response_text_output = generate_gemini_response_gradio(transcribed_text_output)
    print(f"DEBUG full_pipeline_gradio - Step 2 (LLM Response): '{llm_response_text_output}' (type: {type(llm_response_text_output)})")

    tts_synthesis_result = synthesize_speech_gradio(llm_response_text_output)
    final_audio_output = None
    if (isinstance(tts_synthesis_result, tuple) and len(tts_synthesis_result) == 2
            and isinstance(tts_synthesis_result[1], np.ndarray)):
        final_audio_output = tts_synthesis_result
        print(f"DEBUG full_pipeline_gradio - Step 3 (TTS Success): Audio tuple with shape {tts_synthesis_result[1].shape}")
    else:
        error_message_from_tts = (str(tts_synthesis_result) if isinstance(tts_synthesis_result, str)
                                  else "TTS synthesis failed or returned unexpected type")
        print(f"DEBUG full_pipeline_gradio - Step 3 (TTS Failed/Non-audio): {error_message_from_tts}. Providing silent audio.")
        # Append the TTS error to the LLM text; if the LLM stage itself failed,
        # keep its error and just note that TTS also had an issue. Matching on
        # "skipped" covers the "LLM (Gemini) skipped ..." sentinel as well.
        llm_failed = (not llm_response_text_output
                      or llm_response_text_output.startswith("Error:")
                      or "skipped" in llm_response_text_output
                      or "ASR result format not recognized" in llm_response_text_output)
        if llm_failed:
            llm_response_text_output = f"{llm_response_text_output} (TTS also had an issue: {error_message_from_tts})"
        else:
            llm_response_text_output = f"{llm_response_text_output} | (TTS Problem: {error_message_from_tts})"
        default_sample_rate = (tts_model_gradio.config.sampling_rate
                               if tts_model_gradio and hasattr(tts_model_gradio, "config")
                               else TARGET_SAMPLE_RATE)
        final_audio_output = (default_sample_rate, np.array([0.0], dtype=np.float32))
        print("DEBUG full_pipeline_gradio - Step 3 (TTS Fallback): Silent audio tuple")

    print(f"DEBUG full_pipeline_gradio - RETURNING: Transcription='{transcribed_text_output}', LLM_Text='{llm_response_text_output}', Audio_Type={type(final_audio_output)}")
    return transcribed_text_output, llm_response_text_output, final_audio_output

with gr.Blocks(title="Conversational AI Demo") as demo:
    gr.Markdown("# Conversational AI Demo (STT -> Gemini LLM -> TTS)")
    with gr.Row():
        audio_in = gr.Audio(sources=["microphone"], type="numpy", label="Speak Here")
    process_button = gr.Button("Process Audio")
    with gr.Accordion("Outputs", open=True):
        transcription_out = gr.Textbox(label="You Said (Transcription)", lines=2)
        llm_response_out = gr.Textbox(label="Gemini Assistant Says (Text)", lines=5)
        audio_out = gr.Audio(label="Assistant Says (Audio)")

    process_button.click(
        fn=full_pipeline_gradio,
        inputs=[audio_in],
        outputs=[transcription_out, llm_response_out, audio_out],
    )

    gr.Markdown("---")
    gr.Markdown("### How to Use:")
    gr.Markdown("1. Ensure your `GEMINI_API_KEY` environment variable is set.")
    gr.Markdown("2. Click into the 'Speak Here' box and record your audio.")
    gr.Markdown("3. Click the 'Process Audio' button.")
    gr.Markdown("4. View the transcription, Gemini's text response, and listen to the audio response.")

if __name__ == "__main__":
    demo.launch(share=False)
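# NOTE (sketch): each request here runs three blocking model calls (ASR, LLM,
# TTS), so for shared deployments Gradio's request queue is worth enabling to
# avoid timeouts, e.g. `demo.queue().launch(share=False)` in place of the
# launch call above.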