wishitwerethe90s commited on
Commit
c2ac364
·
verified ·
1 Parent(s): 3c4b548

Upload folder using huggingface_hub

Browse files
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31
+ -----END CERTIFICATE-----
README.md CHANGED
@@ -1,12 +1,112 @@
1
  ---
2
- title: Voice Assistant
3
- emoji: 💻
4
- colorFrom: indigo
5
- colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 5.30.0
8
- app_file: app.py
9
- pinned: false
10
  ---
 
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: voice-assistant
3
+ app_file: gradio_app.py
 
 
4
  sdk: gradio
5
+ sdk_version: 5.29.1
 
 
6
  ---
7
+ # Real-time Conversational AI Chatbot Backend
8
 
9
+ This project implements a Python-based backend for a real-time conversational AI chatbot. It features Speech-to-Text (STT), Language Model (LLM) processing via Google's Gemini API, and streaming Text-to-Speech (TTS) capabilities, all orchestrated through a FastAPI web server with WebSocket support for interactive conversations.
10
+
11
+ ## Core Features
12
+
13
+ - **Speech-to-Text (STT):** Utilizes OpenAI's Whisper model to transcribe user's spoken audio into text.
14
+ - **Language Model (LLM):** Integrates with Google's Gemini API (e.g., `gemini-1.5-flash-latest`) for generating intelligent and contextual responses.
15
+ - **Text-to-Speech (TTS) with Streaming:** Employs AI4Bharat's IndicParler-TTS model (via `parler-tts` library) with `ParlerTTSStreamer` to convert the LLM's text response into audible speech, streamed chunk by chunk for faster time-to-first-audio.
16
+ - **Real-time Interaction:** A WebSocket endpoint (`/ws/conversation`) manages the live, bidirectional flow of audio and text data between the client and server.
17
+ - **Component Testing:** Includes individual HTTP RESTful endpoints for testing STT, LLM, and TTS functionalities separately.
18
+ - **Basic Client Demo:** Provides a simple HTML/JavaScript client served at the root (`/`) for demonstrating the WebSocket conversation flow.
19
+
20
+ ## Technologies Used
21
+
22
+ - **Backend Framework:** FastAPI
23
+ - **ASR (STT):** OpenAI Whisper
24
+ - **LLM:** Google Gemini API (via `google-generativeai` SDK)
25
+ - **TTS:** AI4Bharat IndicParler-TTS (via `parler-tts` and `transformers`)
26
+ - **Audio Processing:** `soundfile`, `librosa`
27
+ - **Async & Concurrency:** `asyncio`, `threading` (for ParlerTTSStreamer)
28
+ - **ML/DL:** PyTorch
29
+ - **Web Server:** Uvicorn
30
+
31
+ ## Setup and Installation
32
+
33
+ 1. **Clone the Repository (if applicable)**
34
+
35
+ ```bash
36
+ git clone <your-repo-url>
37
+ cd <your-repo-name>
38
+ ```
39
+
40
+ 2. **Create a Python Virtual Environment**
41
+
42
+ - Using `venv`:
43
+ ```bash
44
+ python -m venv venv
45
+ source venv/bin/activate # On Windows: venv\Scripts\activate
46
+ ```
47
+ - Or using `conda`:
48
+ ```bash
49
+ conda create -n voicebot_env python=3.10 # Or your preferred Python 3.9+
50
+ conda activate voicebot_env
51
+ ```
52
+
53
+ 3. **Install Dependencies**
54
+
55
+ ```bash
56
+ pip install -r requirements.txt
57
+ ```
58
+
59
+ Ensure you have `ffmpeg` installed on your system, as Whisper requires it.
60
+ (e.g., `sudo apt update && sudo apt install ffmpeg` on Debian/Ubuntu)
61
+
62
+ 4. **Set Environment Variables:**
63
+ - **Gemini API Key:** Obtain an API key from [Google AI Studio](https://aistudio.google.com/). Set it as an environment variable:
64
+ ```bash
65
+ export GEMINI_API_KEY="YOUR_ACTUAL_GEMINI_API_KEY"
66
+ ```
67
+ (For Windows PowerShell: `$env:GEMINI_API_KEY="YOUR_ACTUAL_GEMINI_API_KEY"`)
68
+ - **(Optional) Whisper Model Size:**
69
+ ```bash
70
+ export WHISPER_MODEL_SIZE="base" # (e.g., tiny, base, small, medium, large)
71
+ ```
72
+ Defaults to "base" if not set.
73
+
74
+ ### HTTP RESTful Endpoints
75
+
76
+ These are standard FastAPI path operations for testing individual components:
77
+
78
+ - **`POST /api/stt`**: Upload an audio file to get its transcription.
79
+ - **`POST /api/llm`**: Send text in a JSON payload to get a response from Gemini.
80
+ - **`POST /api/tts`**: Send text in a JSON payload to get synthesized audio (non-streaming for this HTTP endpoint, returns base64 encoded WAV).
81
+
82
+ ### WebSocket Endpoint: `/ws/conversation`
83
+
84
+ This is the primary endpoint for real-time, bidirectional conversational interaction:
85
+
86
+ - `@app.websocket("/ws/conversation")` defines the WebSocket route.
87
+ - **Connection Handling:** Accepts new WebSocket connections.
88
+ - **Main Interaction Loop:**
89
+ 1. **Receive Audio:** Waits to receive audio data (bytes) from the client (`await websocket.receive_bytes()`).
90
+ 2. **STT:** Calls `transcribe_audio_bytes()` to get text from the user's audio. Sends `USER_TRANSCRIPT: <text>` back to the client.
91
+ 3. **LLM:** Calls `generate_gemini_response()` with the transcribed text. Sends `ASSISTANT_RESPONSE_TEXT: <text>` back to the client.
92
+ 4. **Streaming TTS:**
93
+ - Sends a `TTS_STREAM_START: {<audio_params>}` message to the client, informing it about the sample rate, channels, and bit depth of the upcoming audio stream.
94
+ - Iterates through the `synthesize_speech_streaming()` asynchronous generator.
95
+ - For each `audio_chunk_bytes` yielded, it sends these raw audio bytes to the client using `await websocket.send_bytes()`.
96
+ - If `websocket.send_bytes()` fails (e.g., client disconnected), the loop breaks, and the `cancellation_event` is set to signal the TTS thread.
97
+ - After the stream is complete (or cancelled), it sends a `TTS_STREAM_END` message.
98
+ - **Error Handling:** Includes `try...except WebSocketDisconnect` to handle client disconnections gracefully and a general exception handler.
99
+ - **Cleanup:** The `finally` block ensures the `cancellation_event` for TTS is set and attempts to close the WebSocket.
100
+
101
+ ## How to Run
102
+
103
+ 1. Ensure all setup steps (environment, dependencies, API key) are complete.
104
+ 2. Execute the script:
105
+ ```bash
106
+ python main.py
107
+ ```
108
+ Or, for development with auto-reload:
109
+ ```bash
110
+ uvicorn main:app --reload --host 0.0.0.0 --port 8000
111
+ ```
112
+ 3. The server will start, and you should see logs indicating that models are being loaded.
__pycache__/main.cpython-310.pyc ADDED
Binary file (31.2 kB). View file
 
gradio_app.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # gradio_app.py
2
+ import gradio as gr
3
+ import io
4
+ import os
5
+ import torch
6
+ from parler_tts import ParlerTTSForConditionalGeneration
7
+ from transformers import AutoTokenizer, AutoModel # CHANGED: Using AutoModel as per model card
8
+ import numpy as np
9
+ import google.generativeai as genai
10
+ import asyncio
11
+ import librosa
12
+ import torchaudio # Often used by models like this for audio loading/processing internally or as input type
13
+
14
+ # --- Configuration ---
15
+ ASR_MODEL_NAME = "ai4bharat/indic-conformer-600m-multilingual"
16
+ TARGET_SAMPLE_RATE = 16000 # Model expects 16kHz
17
+
18
+ TTS_MODEL_NAME = "ai4bharat/indic-parler-tts"
19
+ GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "AIzaSyD6x3Yoby4eQ6QL2kaaG_Rz3fG3rh7wPB8")
20
+ GEMINI_MODEL_NAME_GRADIO = "gemini-1.5-flash-latest"
21
+
22
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
23
+ # torch_dtype for ParlerTTS, Gemini etc. For ASR model, it might handle its own precision.
24
+
25
+ # --- Global Model Variables ---
26
+ asr_model_gradio = None # This will be the AutoModel instance
27
+
28
+ gemini_model_instance_gradio = None
29
+ tts_model_gradio = None
30
+ tts_tokenizer_gradio = None # For ParlerTTS
31
+
32
+ # --- Model Loading & API Configuration ---
33
+ def load_all_resources_gradio():
34
+ global asr_model_gradio, tts_model_gradio, tts_tokenizer_gradio, gemini_model_instance_gradio
35
+ print(f"Gradio: Loading resources. ASR will be on device: {DEVICE}")
36
+
37
+ if asr_model_gradio is None:
38
+ print(f"Gradio: Loading ASR model: {ASR_MODEL_NAME} using AutoModel")
39
+ try:
40
+ # Load using AutoModel as per the model card's implication
41
+ asr_model_gradio = AutoModel.from_pretrained(ASR_MODEL_NAME, trust_remote_code=True)
42
+ asr_model_gradio.to(DEVICE) # Move model to device
43
+ # The model might handle its own precision (e.g. .half()) internally if `trust_remote_code` allows
44
+ # Or you might need to call asr_model_gradio.half() if it supports it and you're on CUDA.
45
+ if DEVICE == "cuda" and hasattr(asr_model_gradio, 'half'):
46
+ print("Gradio: Applying .half() to ASR model.")
47
+ asr_model_gradio.half()
48
+ asr_model_gradio.eval()
49
+ print(f"Gradio: ASR model ({ASR_MODEL_NAME}) loaded using AutoModel.")
50
+ except Exception as e:
51
+ print(f"Gradio: Failed to load ASR model {ASR_MODEL_NAME} using AutoModel: {e}")
52
+ import traceback
53
+ traceback.print_exc()
54
+ asr_model_gradio = None
55
+
56
+ if tts_model_gradio is None: # ParlerTTS loading
57
+ print(f"Gradio: Loading IndicParler-TTS model: {TTS_MODEL_NAME}")
58
+ # Ensure ParlerTTS specific tokenizer is loaded for TTS
59
+ # Note: ASR model might have its own internal tokenizer/processor handled by its custom code
60
+ tts_parler_tokenizer = AutoTokenizer.from_pretrained(TTS_MODEL_NAME, trust_remote_code=True)
61
+ tts_model_gradio = ParlerTTSForConditionalGeneration.from_pretrained(TTS_MODEL_NAME, trust_remote_code=True).to(DEVICE)
62
+ tts_tokenizer_gradio = tts_parler_tokenizer
63
+ print("Gradio: IndicParler-TTS model loaded.")
64
+
65
+ if gemini_model_instance_gradio is None: # Gemini loading
66
+ if not GEMINI_API_KEY:
67
+ print("Gradio: GEMINI_API_KEY not found. LLM functionality via Gemini will be limited.")
68
+ else:
69
+ try:
70
+ genai.configure(api_key=GEMINI_API_KEY)
71
+ gemini_model_instance_gradio = genai.GenerativeModel(GEMINI_MODEL_NAME_GRADIO)
72
+ print(f"Gradio: Gemini API configured with model: {GEMINI_MODEL_NAME_GRADIO}")
73
+ except Exception as e:
74
+ print(f"Gradio: Failed to configure Gemini API: {e}")
75
+ gemini_model_instance_gradio = None
76
+
77
+ print("Gradio: All resources loaded (or attempted).")
78
+
79
+
80
+ # --- Helper Functions ---
81
+ def transcribe_audio_gradio(audio_input_tuple):
82
+ if asr_model_gradio is None:
83
+ return f"Error: ASR model ({ASR_MODEL_NAME}) not loaded."
84
+
85
+ if audio_input_tuple is None:
86
+ print("Gradio: No audio provided to transcribe_audio_gradio.")
87
+ return "No audio provided."
88
+
89
+ sample_rate, audio_numpy = audio_input_tuple
90
+
91
+ if audio_numpy is None or audio_numpy.size == 0:
92
+ print("Gradio: Audio numpy array is empty.")
93
+ return "Empty audio received."
94
+
95
+ # Ensure audio is mono float32, which is a common expectation
96
+ if audio_numpy.ndim > 1:
97
+ if audio_numpy.shape[0] == 2 and audio_numpy.ndim == 2:
98
+ audio_numpy = librosa.to_mono(audio_numpy)
99
+ elif audio_numpy.shape[1] == 2 and audio_numpy.ndim == 2:
100
+ audio_numpy = np.mean(audio_numpy, axis=1)
101
+
102
+ if audio_numpy.dtype != np.float32:
103
+ if np.issubdtype(audio_numpy.dtype, np.integer):
104
+ audio_numpy = audio_numpy.astype(np.float32) / np.iinfo(audio_numpy.dtype).max
105
+ else:
106
+ audio_numpy = audio_numpy.astype(np.float32)
107
+
108
+ # Resample to TARGET_SAMPLE_RATE (16kHz)
109
+ if sample_rate != TARGET_SAMPLE_RATE:
110
+ print(f"Gradio: Resampling audio from {sample_rate} Hz to {TARGET_SAMPLE_RATE} Hz.")
111
+ try:
112
+ audio_numpy = librosa.resample(y=audio_numpy, orig_sr=sample_rate, target_sr=TARGET_SAMPLE_RATE)
113
+ # After resampling, the audio_numpy is at TARGET_SAMPLE_RATE
114
+ except Exception as e:
115
+ print(f"Gradio: Error during resampling: {e}")
116
+ return f"Error during audio resampling: {str(e)}"
117
+
118
+ try:
119
+ print(f"Gradio: Preparing to transcribe with {ASR_MODEL_NAME}. Input audio shape: {audio_numpy.shape}")
120
+
121
+ # The model card example `model(wav, "hi", "ctc")` implies it might take a waveform tensor.
122
+ # We have a numpy array. We need to convert it to a PyTorch tensor.
123
+ # The model card uses torchaudio.load which returns a tensor.
124
+ # Let's convert our numpy array to a tensor and ensure it's on the correct device.
125
+
126
+ # Ensure the audio_numpy is 1D as expected by many ASR models for a single channel
127
+ if audio_numpy.ndim > 1:
128
+ audio_numpy = audio_numpy.squeeze() # Attempt to remove singleton dimensions
129
+ if audio_numpy.ndim > 1 : # If still more than 1D, problem
130
+ print(f"Gradio: Audio numpy array for ASR has unexpected dimensions after processing: {audio_numpy.shape}")
131
+ return "Error: Audio processing resulted in unexpected dimensions."
132
+
133
+ wav_tensor = torch.from_numpy(audio_numpy).to(DEVICE)
134
+ # The model might expect a batch dimension, e.g., [1, num_samples]
135
+ if wav_tensor.ndim == 1:
136
+ wav_tensor = wav_tensor.unsqueeze(0)
137
+
138
+ print(f"Gradio: Transcribing with {ASR_MODEL_NAME} using CTC. Input tensor shape: {wav_tensor.shape}")
139
+
140
+ # Perform ASR with CTC decoding (you can choose "rnnt" if preferred and supported)
141
+ # The language code "hi" is for Hindi. You might want to make this configurable
142
+ # or see if the model supports language auto-detection if you pass None or omit it.
143
+ # For now, assuming "hi" or that the model handles mixed language if lang_id is not strictly enforced.
144
+ # The model card doesn't specify if language ID is optional or how auto-detection works.
145
+ # Let's try "auto" or a common language like "en" or "hi" to start.
146
+ # The model card indicates training on 22 languages, so it's multilingual.
147
+ # If language_id is required, you'll need to provide it.
148
+ # Let's assume for now we try with a common Indian language or let the model try to auto-detect if "auto" or None is valid.
149
+ # The snippet "model(wav, "hi", "ctc")" is specific.
150
+
151
+ # The `model()` call is synchronous. Gradio handles this in a thread.
152
+ with torch.no_grad(): # Good practice for inference
153
+ transcription_result = asr_model_gradio(wav_tensor, "hi", "ctc") # Using lang_id="hi" and strategy="ctc" as per example
154
+
155
+ # The output format needs to be checked. The model card implies it's the transcribed string directly.
156
+ # It might be a list of transcriptions if batching occurs, or a dict.
157
+ if isinstance(transcription_result, list) and len(transcription_result) > 0:
158
+ transcribed_text = transcription_result[0] # Assuming first result for non-batched input
159
+ elif isinstance(transcription_result, str):
160
+ transcribed_text = transcription_result
161
+ else:
162
+ print(f"Gradio: Unexpected ASR result format: {type(transcription_result)}, value: {transcription_result}")
163
+ transcribed_text = "ASR result format not recognized."
164
+
165
+ transcribed_text = transcribed_text.strip()
166
+ print(f"Gradio: Transcription ({ASR_MODEL_NAME}, CTC): {transcribed_text}")
167
+ return transcribed_text if transcribed_text else "Transcription resulted in empty text."
168
+ except Exception as e:
169
+ print(f"Gradio: Error during {ASR_MODEL_NAME} transcription (AutoModel callable): {e}")
170
+ import traceback
171
+ traceback.print_exc()
172
+ return f"Error during transcription ({ASR_MODEL_NAME}): {str(e)}"
173
+
174
+
175
+ # ... (Gemini LLM and TTS functions remain the same) ...
176
+ def generate_gemini_response_gradio(text_input: str):
177
+ if not gemini_model_instance_gradio:
178
+ return "Error: Gemini LLM not configured or API key missing."
179
+ if not isinstance(text_input, str) or not text_input.strip() or text_input.startswith("Error:") or "No audio provided" in text_input or "Transcription resulted in empty text" in text_input or "Empty audio received" in text_input or "ASR result format not recognized" in text_input:
180
+ print(f"Gradio: Invalid input to Gemini: '{text_input}'. Skipping LLM response.")
181
+ return "LLM (Gemini) skipped due to transcription issue or no input."
182
+ try:
183
+ print(f"Gradio: Sending to Gemini: '{text_input}'")
184
+ full_prompt = f"User: {text_input}\nAssistant:"
185
+ response = gemini_model_instance_gradio.generate_content(full_prompt)
186
+ response_text = ""
187
+ if response.candidates and response.candidates[0].content.parts:
188
+ response_text = response.candidates[0].content.parts[0].text.strip()
189
+ else:
190
+ feedback_info = ""
191
+ if hasattr(response, 'prompt_feedback') and response.prompt_feedback:
192
+ feedback_info = f" Feedback: {response.prompt_feedback}"
193
+ print(f"Gradio: Gemini response did not contain expected content.{feedback_info}")
194
+ response_text = f"I'm sorry, I couldn't generate a response for that (Gemini).{feedback_info}"
195
+
196
+ print(f"Gradio: Gemini LLM Response: {response_text}")
197
+ return response_text if response_text else "Gemini LLM generated an empty response."
198
+ except Exception as e:
199
+ print(f"Gradio: Error during Gemini LLM generation: {e}")
200
+ import traceback
201
+ traceback.print_exc()
202
+ return f"Error during Gemini LLM generation: {str(e)}"
203
+
204
+ def synthesize_speech_gradio(text_input: str, description: str = "A clear, female voice speaking in English."):
205
+ if tts_model_gradio is None or tts_tokenizer_gradio is None:
206
+ return "Error: TTS model or its tokenizer not loaded."
207
+ if not isinstance(text_input, str) or not text_input.strip() or text_input.startswith("Error:") or "LLM skipped" in text_input or "generated an empty response" in text_input or "not configured" in text_input or "ASR result format not recognized" in text_input :
208
+ print(f"Gradio: Invalid input to TTS: '{text_input}'. Skipping synthesis.")
209
+ return "TTS skipped due to LLM issue or no input."
210
+ try:
211
+ print(f"Gradio: Synthesizing speech for: '{text_input}'")
212
+ description_tokenized = tts_tokenizer_gradio(description, return_tensors="pt", padding=True, truncation=True, max_length=128)
213
+ description_ids = description_tokenized.input_ids.to(DEVICE)
214
+ description_attention_mask = description_tokenized.attention_mask.to(DEVICE)
215
+
216
+ prompt_tokenized = tts_tokenizer_gradio(text_input, return_tensors="pt", padding=True, truncation=True, max_length=512)
217
+ prompt_ids = prompt_tokenized.input_ids.to(DEVICE)
218
+
219
+ if prompt_ids.shape[-1] == 0: # Check if tokenized prompt is empty
220
+ print(f"Gradio: Tokenized prompt for TTS is empty. Text was: '{text_input}'. Skipping synthesis.")
221
+ return "TTS skipped: Input text resulted in empty tokens."
222
+
223
+
224
+ generation = tts_model_gradio.generate(
225
+ input_ids=description_ids,
226
+ attention_mask=description_attention_mask,
227
+ prompt_input_ids=prompt_ids,
228
+ do_sample=True, temperature=0.7, top_k=50, top_p=0.95
229
+ ).cpu().numpy().squeeze()
230
+
231
+ sampling_rate = tts_model_gradio.config.sampling_rate
232
+ print(f"Gradio: Speech synthesized. Array shape: {generation.shape}, Sample rate: {sampling_rate}")
233
+ return (sampling_rate, generation)
234
+ except Exception as e:
235
+ print(f"Gradio: Error during speech synthesis: {e}")
236
+ import traceback
237
+ traceback.print_exc()
238
+ if "You need to specify either `text` or `text_target`" in str(e):
239
+ return "Error in TTS: Model requires 'text' or 'text_target'. Input might be too short or problematic."
240
+ return f"Error during speech synthesis: {str(e)}"
241
+
242
+ # --- Gradio Interface Definition ---
243
+ load_all_resources_gradio()
244
+
245
+ def full_pipeline_gradio(audio_input):
246
+ transcribed_text_output = transcribe_audio_gradio(audio_input)
247
+ print(f"DEBUG full_pipeline_gradio - Step 1 (Transcription): '{transcribed_text_output}' (type: {type(transcribed_text_output)})")
248
+ llm_response_text_output = generate_gemini_response_gradio(transcribed_text_output)
249
+ print(f"DEBUG full_pipeline_gradio - Step 2 (LLM Response): '{llm_response_text_output}' (type: {type(llm_response_text_output)})")
250
+ tts_synthesis_result = synthesize_speech_gradio(llm_response_text_output)
251
+ final_audio_output = None
252
+ if isinstance(tts_synthesis_result, tuple) and len(tts_synthesis_result) == 2 and isinstance(tts_synthesis_result[1], np.ndarray):
253
+ final_audio_output = tts_synthesis_result
254
+ print(f"DEBUG full_pipeline_gradio - Step 3 (TTS Success): Audio tuple with shape {tts_synthesis_result[1].shape if isinstance(tts_synthesis_result[1], np.ndarray) else 'N/A'}")
255
+ else:
256
+ error_message_from_tts = str(tts_synthesis_result) if isinstance(tts_synthesis_result, str) else "TTS synthesis failed or returned unexpected type"
257
+ print(f"DEBUG full_pipeline_gradio - Step 3 (TTS Failed/Non-audio): {error_message_from_tts}. Providing silent audio.")
258
+ # Append TTS error to LLM text only if LLM text was valid
259
+ if llm_response_text_output and not llm_response_text_output.startswith("Error:") and "LLM skipped" not in llm_response_text_output and "ASR result format not recognized" not in llm_response_text_output:
260
+ llm_response_text_output = f"{llm_response_text_output} | (TTS Problem: {error_message_from_tts})"
261
+ elif not llm_response_text_output or llm_response_text_output.startswith("Error:") or "LLM skipped" in llm_response_text_output or "ASR result format not recognized" in llm_response_text_output:
262
+ # If LLM already had an error, just keep that error, maybe note TTS also had an issue
263
+ llm_response_text_output = f"{llm_response_text_output} (TTS also had an issue: {error_message_from_tts})"
264
+
265
+ default_sample_rate = tts_model_gradio.config.sampling_rate if tts_model_gradio and hasattr(tts_model_gradio, 'config') else TARGET_SAMPLE_RATE
266
+ final_audio_output = (default_sample_rate, np.array([0.0], dtype=np.float32))
267
+ print(f"DEBUG full_pipeline_gradio - Step 3 (TTS Fallback): Silent audio tuple")
268
+ print(f"DEBUG full_pipeline_gradio - RETURNING: Transcription='{transcribed_text_output}', LLM_Text='{llm_response_text_output}', Audio_Type={type(final_audio_output)}")
269
+ return transcribed_text_output, llm_response_text_output, final_audio_output
270
+
271
+ with gr.Blocks(title="Conversational AI Demo") as demo:
272
+ gr.Markdown("# Conversational AI Demo (STT -> Gemini LLM -> TTS)")
273
+ with gr.Row():
274
+ audio_in = gr.Audio(sources=["microphone"], type="numpy", label="Speak Here")
275
+ process_button = gr.Button("Process Audio")
276
+ with gr.Accordion("Outputs", open=True):
277
+ transcription_out = gr.Textbox(label="You Said (Transcription)", lines=2)
278
+ llm_response_out = gr.Textbox(label="Gemini Assistant Says (Text)", lines=5)
279
+ audio_out = gr.Audio(label="Assistant Says (Audio)")
280
+
281
+ process_button.click(
282
+ fn=full_pipeline_gradio,
283
+ inputs=[audio_in],
284
+ outputs=[transcription_out, llm_response_out, audio_out]
285
+ )
286
+ gr.Markdown("---")
287
+ gr.Markdown("### How to Use:")
288
+ gr.Markdown("1. Ensure your `GEMINI_API_KEY` environment variable is set.")
289
+ gr.Markdown("2. Click into the 'Speak Here' box and record your audio.")
290
+ gr.Markdown("3. Click the 'Process Audio' button.")
291
+ gr.Markdown("4. View the transcription, Gemini's text response, and listen to the audio response.")
292
+
293
+ if __name__ == "__main__":
294
+ demo.launch(share=False)
infereless.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import json
3
+ URL = 'https://serverless-region-v1.inferless.com/api/v1/parler-tts-streaming-1_ae4e81bb5d604799b573df3f0b3c9518/infer'
4
+ headers = {"Content-Type": "application/json", "Authorization": "Bearer 1e01145781a0639d830555d5e4e4e5e1752726750db75e995e0e246f32c4b7c9f442bd6f8caec8acc6b9684ec78e5b633db04370815ca1748bf5a7db80245411"}
5
+
6
+ data = json.loads('''{
7
+ "parameters": {
8
+ "prompt_value": "A male speaker with a low-pitched voice delivering his words at a fast pace in a small, confined space with a very clear audio and an animated tone.",
9
+ "input_value": "Remember - this is only the first iteration of the model! To improve the prosody and naturalness of the speech further, we're scaling up the amount of training data by a factor of five times."
10
+ }
11
+ }''')
12
+
13
+ response = requests.post(URL, headers=headers, data=json.dumps(data))
14
+ print(response.json())
main.py ADDED
@@ -0,0 +1,740 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # main.py
2
+ import asyncio
3
+ import base64
4
+ import io
5
+ import logging
6
+ import os
7
+ from threading import Thread, Event # Added Event for better thread control
8
+ import time # For timeout checks
9
+
10
+ import soundfile as sf
11
+ import torch
12
+ import uvicorn
13
+ import whisper
14
+ from fastapi import FastAPI, File, UploadFile, WebSocket, WebSocketDisconnect
15
+ from fastapi.responses import HTMLResponse, JSONResponse
16
+ from fastapi.middleware.cors import CORSMiddleware
17
+ from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSStreamer
18
+ from transformers import AutoTokenizer, GenerationConfig # Keep transformers.GenerationConfig
19
+ import google.generativeai as genai
20
+ import numpy as np
21
+
22
+ # --- Configuration ---
23
+ WHISPER_MODEL_SIZE = os.getenv("WHISPER_MODEL_SIZE", "tiny")
24
+ TTS_MODEL_NAME = "ai4bharat/indic-parler-tts"
25
+ GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "AIzaSyD6x3Yoby4eQ6QL2kaaG_Rz3fG3rh7wPB8")
26
+ GEMINI_MODEL_NAME = "gemini-1.5-flash-latest"
27
+
28
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
29
+
30
+ attn_implementation = "flash_attention_2" if torch.cuda.is_available() else "eager"
31
+
32
+ torch_dtype_tts = torch.bfloat16 if DEVICE == "cuda" and torch.cuda.is_bf16_supported() else (torch.float16 if DEVICE == "cuda" else torch.float32)
33
+ torch_dtype_whisper = torch.float16 if DEVICE == "cuda" else torch.float32
34
+
35
+ TTS_DEFAULT_PARAMS = {
36
+ "do_sample": True,
37
+ "temperature": 1.0,
38
+ "top_k": 50,
39
+ "top_p": 0.95,
40
+ "min_new_tokens": 5, # Reduced for quicker start with streamer
41
+ # "max_new_tokens": 256, # Optional global cap
42
+ }
43
+
44
+ # --- Logging ---
45
+ logging.basicConfig(level=logging.INFO)
46
+ logger = logging.getLogger(__name__)
47
+
48
+ # --- FastAPI App Initialization ---
49
+ app = FastAPI(title="Conversational AI Chatbot with Enhanced Stream Abortion")
50
+
51
+ app.add_middleware(
52
+ CORSMiddleware,
53
+ allow_origins=["*"],
54
+ allow_credentials=True,
55
+ allow_methods=["*"],
56
+ allow_headers=["*"],
57
+ )
58
+
59
+ # --- Global Model Variables ---
60
+ whisper_model = None
61
+ gemini_model_instance = None
62
+ tts_model = None
63
+ tts_tokenizer = None
64
+ # We will build the GenerationConfig object from TTS_DEFAULT_PARAMS inside the functions
65
+ # or store it globally if preferred, initialized from transformers.GenerationConfig
66
+
67
+ # --- Model Loading & API Configuration ---
68
+ @app.on_event("startup")
69
+ async def load_resources():
70
+ global whisper_model, tts_model, tts_tokenizer, gemini_model_instance
71
+ logger.info(f"Loading local models. Whisper on {DEVICE} with {torch_dtype_whisper}, TTS on {DEVICE} with {torch_dtype_tts}")
72
+
73
+ try:
74
+ logger.info(f"Loading Whisper model: {WHISPER_MODEL_SIZE}")
75
+ whisper_model = whisper.load_model(WHISPER_MODEL_SIZE, device=DEVICE)
76
+ logger.info("Whisper model loaded successfully.")
77
+
78
+ logger.info(f"Loading IndicParler-TTS model: {TTS_MODEL_NAME}")
79
+ tts_model = ParlerTTSForConditionalGeneration.from_pretrained(TTS_MODEL_NAME, attn_implementation=attn_implementation).to(DEVICE, dtype=torch_dtype_tts)
80
+ tts_tokenizer = AutoTokenizer.from_pretrained(TTS_MODEL_NAME)
81
+
82
+ if tts_tokenizer:
83
+ if tts_tokenizer.pad_token_id is not None:
84
+ TTS_DEFAULT_PARAMS["pad_token_id"] = tts_tokenizer.pad_token_id
85
+ # ParlerTTS uses a special token_id for silence, not eos_token_id for generation end.
86
+ # eos_token_id is more for text models.
87
+ # if tts_tokenizer.eos_token_id is not None:
88
+ # TTS_DEFAULT_PARAMS["eos_token_id"] = tts_tokenizer.eos_token_id
89
+ logger.info(f"IndicParler-TTS model loaded. Default generation params: {TTS_DEFAULT_PARAMS}")
90
+
91
+ if not GEMINI_API_KEY:
92
+ logger.warning("GEMINI_API_KEY not found. LLM functionality will be limited.")
93
+ else:
94
+ try:
95
+ genai.configure(api_key=GEMINI_API_KEY)
96
+ gemini_model_instance = genai.GenerativeModel(GEMINI_MODEL_NAME)
97
+ logger.info(f"Gemini API configured with model: {GEMINI_MODEL_NAME}")
98
+ except Exception as e:
99
+ logger.error(f"Failed to configure Gemini API: {e}", exc_info=True)
100
+ gemini_model_instance = None
101
+
102
+ except Exception as e:
103
+ logger.error(f"Error loading models: {e}", exc_info=True)
104
+ logger.info("Local models and API configurations loaded.")
105
+
106
+
107
+ # --- Helper Functions ---
108
+ async def transcribe_audio_bytes(audio_bytes: bytes) -> str:
109
+ if not whisper_model:
110
+ raise RuntimeError("Whisper model not loaded.")
111
+ temp_audio_path = f"temp_audio_main_{os.urandom(4).hex()}.wav"
112
+ try:
113
+ with open(temp_audio_path, "wb") as f:
114
+ f.write(audio_bytes)
115
+ result = whisper_model.transcribe(temp_audio_path, fp16=(DEVICE == "cuda" and torch_dtype_whisper == torch.float16))
116
+ transcribed_text = result["text"].strip()
117
+ logger.info(f"Transcription: {transcribed_text}")
118
+ return transcribed_text
119
+ except Exception as e:
120
+ logger.error(f"Error during transcription: {e}", exc_info=True)
121
+ return ""
122
+ finally:
123
+ if os.path.exists(temp_audio_path):
124
+ try:
125
+ os.remove(temp_audio_path)
126
+ except Exception as e_del:
127
+ logger.error(f"Error deleting temp audio file {temp_audio_path}: {e_del}")
128
+
129
+
130
+ async def generate_gemini_response(text: str) -> str:
131
+ if not gemini_model_instance:
132
+ logger.error("Gemini model instance not available.")
133
+ return "Sorry, the language model is currently unavailable."
134
+ try:
135
+ full_prompt = f"User: {text}\nAssistant:"
136
+ loop = asyncio.get_event_loop()
137
+ response = await loop.run_in_executor(None, gemini_model_instance.generate_content, full_prompt)
138
+ response_text = "I'm sorry, I couldn't generate a response for that."
139
+ if hasattr(response, 'text') and response.text: # For simple text responses
140
+ response_text = response.text.strip()
141
+ elif response.parts: # New way to access parts for gemini-1.5-flash and pro
142
+ response_text = "".join(part.text for part in response.parts).strip()
143
+ elif response.candidates and response.candidates[0].content.parts: # Older way
144
+ response_text = response.candidates[0].content.parts[0].text.strip()
145
+ else:
146
+ safety_feedback = ""
147
+ if hasattr(response, 'prompt_feedback') and response.prompt_feedback:
148
+ safety_feedback = f" Safety Feedback: {response.prompt_feedback}"
149
+ elif response.candidates and hasattr(response.candidates[0], 'finish_reason') and response.candidates[0].finish_reason != "STOP":
150
+ safety_feedback = f" Finish Reason: {response.candidates[0].finish_reason}"
151
+ logger.warning(f"Gemini response might be empty or blocked.{safety_feedback}")
152
+ logger.info(f"Gemini LLM Response: {response_text}")
153
+ return response_text
154
+ except Exception as e:
155
+ logger.error(f"Error during Gemini LLM generation: {e}", exc_info=True)
156
+ return "Sorry, I encountered an error trying to respond."
157
+
158
+
159
+ async def synthesize_speech_streaming(text: str, description: str = "A clear, female voice speaking in English.", play_steps_in_s: float = 0.4, cancellation_event: Event = Event()):
160
+ if not tts_model or not tts_tokenizer:
161
+ logger.error("TTS model or tokenizer not loaded.")
162
+ if cancellation_event and cancellation_event.is_set(): logger.info("TTS cancelled before start."); yield b""; return
163
+ yield b""
164
+ return
165
+
166
+ if not text or not text.strip():
167
+ logger.warning("TTS input text is empty. Yielding empty audio.")
168
+ if cancellation_event and cancellation_event.is_set(): logger.info("TTS cancelled before start (empty text)."); yield b""; return
169
+ yield b""
170
+ return
171
+
172
+ streamer = None
173
+ thread = None
174
+
175
+ try:
176
+ logger.info(f"Starting TTS streaming with ParlerTTSStreamer for: \"{text[:50]}...\"")
177
+
178
+ # Ensure sampling_rate is correctly accessed from the model's config
179
+ # For ParlerTTS, it's usually under model.config.audio_encoder.sampling_rate
180
+ if hasattr(tts_model.config, 'audio_encoder') and hasattr(tts_model.config.audio_encoder, 'sampling_rate'):
181
+ sampling_rate = tts_model.config.audio_encoder.sampling_rate
182
+ else:
183
+ logger.warning("Could not find tts_model.config.audio_encoder.sampling_rate, defaulting to 24000")
184
+ sampling_rate = 24000 # A common default for ParlerTTS if not found
185
+
186
+ try:
187
+ frame_rate = getattr(tts_model.config.audio_encoder, 'frame_rate', 100)
188
+ except AttributeError:
189
+ logger.warning("frame_rate not found in tts_model.config.audio_encoder. Using default of 100 Hz for play_steps calculation.")
190
+ frame_rate = 100
191
+
192
+ play_steps = int(frame_rate * play_steps_in_s)
193
+ if play_steps == 0 : play_steps = 1
194
+
195
+ logger.info(f"Streamer params: sampling_rate={sampling_rate}, frame_rate={frame_rate}, play_steps_in_s={play_steps_in_s}, play_steps={play_steps}")
196
+
197
+ streamer = ParlerTTSStreamer(tts_model, device=DEVICE, play_steps=play_steps)
198
+
199
+ description_inputs = tts_tokenizer(description, return_tensors="pt")
200
+ prompt_inputs = tts_tokenizer(text, return_tensors="pt")
201
+
202
+ gen_config_dict = TTS_DEFAULT_PARAMS.copy()
203
+ # ParlerTTS generate method might not take a GenerationConfig object directly,
204
+ # but rather individual kwargs. The streamer example passes them as kwargs.
205
+ # We ensure pad_token_id and eos_token_id are set if the tokenizer has them.
206
+ if tts_tokenizer.pad_token_id is not None:
207
+ gen_config_dict["pad_token_id"] = tts_tokenizer.pad_token_id
208
+ # ParlerTTS might not use eos_token_id in the same way as text models.
209
+ # if tts_tokenizer.eos_token_id is not None:
210
+ # gen_config_dict["eos_token_id"] = tts_tokenizer.eos_token_id
211
+
212
+
213
+ thread_generation_kwargs = {
214
+ "input_ids": description_inputs.input_ids.to(DEVICE),
215
+ "prompt_input_ids": prompt_inputs.input_ids.to(DEVICE),
216
+ "attention_mask": description_inputs.attention_mask.to(DEVICE) if hasattr(description_inputs, 'attention_mask') else None,
217
+ "streamer": streamer,
218
+ **gen_config_dict # Spread the generation parameters
219
+ }
220
+ if thread_generation_kwargs["attention_mask"] is None:
221
+ del thread_generation_kwargs["attention_mask"]
222
+
223
+ def _generate_in_thread():
224
+ try:
225
+ logger.info(f"TTS generation thread started.")
226
+ with torch.no_grad():
227
+ tts_model.generate(**thread_generation_kwargs)
228
+ logger.info("TTS generation thread finished model.generate().")
229
+ except Exception as e_thread:
230
+ logger.error(f"Error in TTS generation thread: {e_thread}", exc_info=True)
231
+ finally:
232
+ if streamer: streamer.end()
233
+ logger.info("TTS generation thread called streamer.end().")
234
+
235
+ thread = Thread(target=_generate_in_thread)
236
+ thread.daemon = True
237
+ thread.start()
238
+
239
+ loop = asyncio.get_event_loop()
240
+ while True:
241
+ if cancellation_event and cancellation_event.is_set():
242
+ logger.info("TTS streaming cancelled by event.")
243
+ break
244
+
245
+ try:
246
+ # Run the blocking streamer.__next__() in an executor
247
+ audio_chunk_tensor = await loop.run_in_executor(None, streamer.__next__)
248
+
249
+ if audio_chunk_tensor is None:
250
+ logger.info("Streamer yielded None explicitly, ending stream.")
251
+ break
252
+ # This check for numel == 0 is important as streamer might yield empty tensors
253
+ if not isinstance(audio_chunk_tensor, torch.Tensor) or audio_chunk_tensor.numel() == 0:
254
+ # REMOVED: if streamer.is_done(): (AttributeError)
255
+ # Instead, rely on StopIteration or explicit None from streamer
256
+ await asyncio.sleep(0.01) # Small sleep if empty but not done
257
+ continue
258
+
259
+ audio_chunk_np = audio_chunk_tensor.cpu().to(torch.float32).numpy().squeeze()
260
+ if audio_chunk_np.size == 0:
261
+ continue
262
+
263
+ audio_chunk_int16 = np.clip(audio_chunk_np * 32767, -32768, 32767).astype(np.int16)
264
+ yield audio_chunk_int16.tobytes()
265
+ # No need for sleep here if chunks are substantial, client will process
266
+ # await asyncio.sleep(0.001) # Can be removed or made very small
267
+
268
+ except StopIteration:
269
+ logger.info("Streamer finished (StopIteration).")
270
+ break
271
+ except Exception as e_stream_iter:
272
+ logger.error(f"Error iterating streamer: {e_stream_iter}", exc_info=True)
273
+ break
274
+
275
+ logger.info(f"Finished TTS streaming iteration for: \"{text[:50]}...\"")
276
+
277
+ except Exception as e:
278
+ logger.error(f"Error in synthesize_speech_streaming function: {e}", exc_info=True)
279
+ yield b""
280
+ finally:
281
+ logger.info("Exiting synthesize_speech_streaming. Ensuring streamer is ended and thread is joined.")
282
+ if streamer:
283
+ streamer.end()
284
+ if thread and thread.is_alive():
285
+ logger.info("Waiting for TTS generation thread to complete in finally block...")
286
+ final_join_start_time = time.time()
287
+ thread.join(timeout=2.0)
288
+ if thread.is_alive():
289
+ logger.warning(f"TTS generation thread still alive after {time.time() - final_join_start_time:.2f}s in finally block.")
290
+
291
+
292
+ # --- FastAPI HTTP Endpoints ---
293
+ @app.post("/api/stt", summary="Speech to Text")
294
+ async def speech_to_text_endpoint(file: UploadFile = File(...)):
295
+ if not whisper_model:
296
+ return JSONResponse(content={"error": "Whisper model not loaded"}, status_code=503)
297
+ try:
298
+ audio_bytes = await file.read()
299
+ transcribed_text = await transcribe_audio_bytes(audio_bytes)
300
+ return {"transcription": transcribed_text}
301
+ except Exception as e:
302
+ return JSONResponse(content={"error": str(e)}, status_code=500)
303
+
304
+ @app.post("/api/llm", summary="LLM Response Generation (Gemini)")
305
+ async def llm_endpoint(payload: dict):
306
+ if not gemini_model_instance:
307
+ return JSONResponse(content={"error": "Gemini LLM not configured or API key missing"}, status_code=503)
308
+ try:
309
+ text = payload.get("text")
310
+ if not text:
311
+ return JSONResponse(content={"error": "No text provided"}, status_code=400)
312
+ response = await generate_gemini_response(text)
313
+ return {"response": response}
314
+ except Exception as e:
315
+ return JSONResponse(content={"error": str(e)}, status_code=500)
316
+
317
+ @app.post("/api/tts", summary="Text to Speech (Non-Streaming for HTTP)")
318
+ async def text_to_speech_endpoint(payload: dict):
319
+ if not tts_model or not tts_tokenizer:
320
+ return JSONResponse(content={"error": "TTS model/tokenizer not loaded"}, status_code=503)
321
+ try:
322
+ text = payload.get("text")
323
+ description = payload.get("description", "A clear, female voice speaking in English.")
324
+ if not text:
325
+ return JSONResponse(content={"error": "No text provided"}, status_code=400)
326
+
327
+ description_inputs = tts_tokenizer(description, return_tensors="pt")
328
+ prompt_inputs = tts_tokenizer(text, return_tensors="pt")
329
+
330
+ # Use a GenerationConfig object for clarity and consistency
331
+ gen_config_dict = TTS_DEFAULT_PARAMS.copy()
332
+ if tts_tokenizer.pad_token_id is not None:
333
+ gen_config_dict["pad_token_id"] = tts_tokenizer.pad_token_id
334
+ # if tts_tokenizer.eos_token_id is not None: # ParlerTTS might not use standard eos
335
+ # gen_config_dict["eos_token_id"] = tts_tokenizer.eos_token_id
336
+
337
+ # Create GenerationConfig from transformers
338
+ generation_config_obj = GenerationConfig(**gen_config_dict)
339
+
340
+
341
+ with torch.no_grad():
342
+ generation = tts_model.generate(
343
+ input_ids=description_inputs.input_ids.to(DEVICE),
344
+ prompt_input_ids=prompt_inputs.input_ids.to(DEVICE),
345
+ attention_mask=description_inputs.attention_mask.to(DEVICE) if hasattr(description_inputs, 'attention_mask') else None,
346
+ generation_config=generation_config_obj # Pass the config object
347
+ ).cpu().to(torch.float32).numpy().squeeze()
348
+
349
+ audio_io = io.BytesIO()
350
+ scaled_generation = np.clip(generation * 32767, -32768, 32767).astype(np.int16)
351
+
352
+ current_sampling_rate = tts_model.config.audio_encoder.sampling_rate if hasattr(tts_model.config, 'audio_encoder') else 24000
353
+ sf.write(audio_io, scaled_generation, samplerate=current_sampling_rate, format='WAV', subtype='PCM_16')
354
+ audio_io.seek(0)
355
+ audio_bytes = audio_io.read()
356
+
357
+ if not audio_bytes:
358
+ return JSONResponse(content={"error": "TTS failed to generate audio"}, status_code=500)
359
+ audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
360
+ return {"audio_base64": audio_base64, "format": "wav", "sample_rate": current_sampling_rate}
361
+ except Exception as e:
362
+ logger.error(f"TTS endpoint error: {e}", exc_info=True)
363
+ return JSONResponse(content={"error": str(e)}, status_code=500)
364
+
365
+ # --- WebSocket Endpoint for Real-time Conversation ---
366
+ @app.websocket("/ws/conversation")
367
+ async def conversation_websocket(websocket: WebSocket):
368
+ await websocket.accept()
369
+ logger.info(f"WebSocket connection accepted from: {websocket.client}")
370
+
371
+ tts_cancellation_event = Event() # For this specific connection
372
+
373
+ try:
374
+ while True:
375
+ if websocket.client_state.name != 'CONNECTED': # Check if client disconnected before receive
376
+ logger.info(f"WebSocket client {websocket.client} disconnected before receive.")
377
+ break
378
+
379
+ audio_data = await websocket.receive_bytes()
380
+ logger.info(f"Received {len(audio_data)} bytes of user audio data from {websocket.client}.")
381
+
382
+ if not audio_data:
383
+ logger.warning(f"Received empty audio data from user {websocket.client}.")
384
+ continue
385
+
386
+ transcribed_text = await transcribe_audio_bytes(audio_data)
387
+ if not transcribed_text:
388
+ logger.warning(f"Transcription failed for {websocket.client}.")
389
+ await websocket.send_text("SYSTEM_ERROR: Transcription failed.")
390
+ continue
391
+ await websocket.send_text(f"USER_TRANSCRIPT: {transcribed_text}")
392
+
393
+ llm_response_text = await generate_gemini_response(transcribed_text)
394
+ if not llm_response_text or "Sorry, I encountered an error" in llm_response_text or "unavailable" in llm_response_text:
395
+ logger.warning(f"LLM (Gemini) failed for {websocket.client}: {llm_response_text}")
396
+ await websocket.send_text(f"SYSTEM_ERROR: LLM failed. ({llm_response_text})")
397
+ continue
398
+ await websocket.send_text(f"ASSISTANT_RESPONSE_TEXT: {llm_response_text}")
399
+
400
+ tts_description = "A clear, female voice speaking in English."
401
+
402
+ current_sampling_rate = tts_model.config.audio_encoder.sampling_rate if hasattr(tts_model.config, 'audio_encoder') else 24000
403
+ audio_params_msg = f"TTS_STREAM_START:{{\"sample_rate\": {current_sampling_rate}, \"channels\": 1, \"bit_depth\": 16}}"
404
+ await websocket.send_text(audio_params_msg)
405
+ logger.info(f"Sent to client {websocket.client}: {audio_params_msg}")
406
+
407
+ chunk_count = 0
408
+ tts_cancellation_event.clear() # Reset event for new TTS task
409
+
410
+ async for audio_chunk_bytes in synthesize_speech_streaming(llm_response_text, tts_description, cancellation_event=tts_cancellation_event):
411
+ if not audio_chunk_bytes:
412
+ logger.debug(f"Received empty bytes from streaming generator for {websocket.client}, might be end or error in generator.")
413
+ continue
414
+ try:
415
+ if websocket.client_state.name != 'CONNECTED':
416
+ logger.warning(f"Client {websocket.client} disconnected during TTS stream. Aborting TTS.")
417
+ tts_cancellation_event.set() # Signal TTS thread to stop
418
+ break
419
+ await websocket.send_bytes(audio_chunk_bytes)
420
+ chunk_count += 1
421
+ except Exception as send_err:
422
+ logger.warning(f"Error sending audio chunk to {websocket.client}: {send_err}. Client likely disconnected.")
423
+ tts_cancellation_event.set() # Signal TTS thread to stop
424
+ break
425
+
426
+ if not tts_cancellation_event.is_set(): # Only send END if not cancelled
427
+ logger.info(f"Sent {chunk_count} TTS audio chunks to client {websocket.client}.")
428
+ await websocket.send_text("TTS_STREAM_END")
429
+ logger.info(f"Sent TTS_STREAM_END to client {websocket.client}.")
430
+ else:
431
+ logger.info(f"TTS stream for {websocket.client} was cancelled. Sent {chunk_count} chunks before cancellation.")
432
+
433
+
434
+ except WebSocketDisconnect:
435
+ logger.info(f"WebSocket connection closed by client {websocket.client}.")
436
+ tts_cancellation_event.set() # Signal any active TTS to stop
437
+ except Exception as e:
438
+ logger.error(f"Error in WebSocket conversation with {websocket.client}: {e}", exc_info=True)
439
+ tts_cancellation_event.set() # Signal any active TTS to stop
440
+ try:
441
+ if websocket.client_state.name == 'CONNECTED':
442
+ await websocket.send_text(f"SYSTEM_ERROR: An unexpected error occurred: {str(e)}")
443
+ except Exception: pass
444
+ finally:
445
+ logger.info(f"Cleaning up WebSocket connection for {websocket.client}.")
446
+ tts_cancellation_event.set() # Ensure event is set on any exit path
447
+ if websocket.client_state.name == 'CONNECTED' or websocket.client_state.name == 'CONNECTING':
448
+ try: await websocket.close()
449
+ except Exception: pass
450
+ logger.info(f"WebSocket connection resources cleaned up for {websocket.client}.")
451
+
452
+ # ... (HTML serving and main execution block remain the same) ...
453
+ @app.get("/", response_class=HTMLResponse)
454
+ async def get_home():
455
+ html_content = """
456
+ <!DOCTYPE html>
457
+ <html>
458
+ <head>
459
+ <title>Conversational AI Chatbot (Streaming)</title>
460
+ <style>
461
+ body { font-family: Arial, sans-serif; margin: 20px; background-color: #f4f4f4; color: #333; }
462
+ #chatbox { width: 80%; max-width: 600px; margin: auto; background-color: #fff; padding: 20px; box-shadow: 0 0 10px rgba(0,0,0,0.1); border-radius: 8px; }
463
+ .message { padding: 10px; margin-bottom: 10px; border-radius: 5px; }
464
+ .user { background-color: #e1f5fe; text-align: right; }
465
+ .assistant { background-color: #f1f8e9; }
466
+ .system { background-color: #ffebee; color: #c62828; font-style: italic;}
467
+ #audioPlayerContainer { margin-top: 10px; }
468
+ #audioPlayer { display: none; width: 100%; }
469
+ button { padding: 10px 15px; background-color: #007bff; color: white; border: none; border-radius: 5px; cursor: pointer; margin-top:10px; }
470
+ button:disabled { background-color: #ccc; }
471
+ #status { margin-top: 10px; font-style: italic; color: #666; }
472
+ #transcriptionArea, #llmResponseArea { margin-top: 10px; padding: 5px; border: 1px solid #eee; background: #fafafa; word-wrap: break-word;}
473
+ </style>
474
+ </head>
475
+ <body>
476
+ <div id="chatbox">
477
+ <h2>Real-time AI Chatbot (Streaming TTS)</h2>
478
+ <div id="messages"></div>
479
+ <div id="transcriptionArea"><strong>You (transcribed):</strong> <span id="userTranscript">...</span></div>
480
+ <div id="llmResponseArea"><strong>Assistant (text):</strong> <span id="assistantTranscript">...</span></div>
481
+
482
+ <button id="startRecordButton">Start Recording</button>
483
+ <button id="stopRecordButton" disabled>Stop Recording</button>
484
+ <p id="status">Status: Idle</p>
485
+ <div id="audioPlayerContainer">
486
+ <audio id="audioPlayer" controls></audio>
487
+ </div>
488
+ </div>
489
+
490
+ <script>
491
+ const startRecordButton = document.getElementById('startRecordButton');
492
+ const stopRecordButton = document.getElementById('stopRecordButton');
493
+ const audioPlayer = document.getElementById('audioPlayer');
494
+ const messagesDiv = document.getElementById('messages');
495
+ const statusDiv = document.getElementById('status');
496
+ const userTranscriptSpan = document.getElementById('userTranscript');
497
+ const assistantTranscriptSpan = document.getElementById('assistantTranscript');
498
+
499
+ let websocket;
500
+ let mediaRecorder;
501
+ let userAudioChunks = [];
502
+
503
+ let assistantAudioBufferQueue = [];
504
+ let audioContext;
505
+ let expectedSampleRate;
506
+ let ttsStreaming = false;
507
+ let audioPlaying = false;
508
+ let sourceNode = null;
509
+
510
+ function initAudioContext() {
511
+ if (!audioContext || audioContext.state === 'closed') {
512
+ try {
513
+ audioContext = new (window.AudioContext || window.webkitAudioContext)();
514
+ console.log("AudioContext initialized or re-initialized.");
515
+ } catch (e) {
516
+ console.error("Web Audio API is not supported in this browser.", e);
517
+ addMessage("Error: Web Audio API not supported. Cannot play streamed audio.", "system");
518
+ audioContext = null;
519
+ }
520
+ }
+ }
+
+
+ function connectWebSocket() {
+ const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
+ const wsUrl = `${protocol}//${window.location.host}/ws/conversation`;
+ websocket = new WebSocket(wsUrl);
+ websocket.binaryType = 'arraybuffer';
+
+ websocket.onopen = () => {
+ statusDiv.textContent = 'Status: Connected. Ready to record.';
+ startRecordButton.disabled = false;
+ initAudioContext();
+ };
+
+ websocket.onmessage = (event) => {
+ if (event.data instanceof ArrayBuffer) {
+ if (ttsStreaming && audioContext && expectedSampleRate) {
+ const pcmDataInt16 = new Int16Array(event.data);
+ if (pcmDataInt16.length > 0) {
+ assistantAudioBufferQueue.push(pcmDataInt16);
+ playNextChunkFromQueue();
+ }
+ } else {
+ console.warn("Received ArrayBuffer data but not in TTS streaming mode or AudioContext not ready.");
+ }
+ } else {
+ const messageText = event.data;
+ if (messageText.startsWith("USER_TRANSCRIPT:")) {
+ const transcript = messageText.substring("USER_TRANSCRIPT:".length).trim();
+ userTranscriptSpan.textContent = transcript;
+ } else if (messageText.startsWith("ASSISTANT_RESPONSE_TEXT:")) {
+ const llmResponse = messageText.substring("ASSISTANT_RESPONSE_TEXT:".length).trim();
+ assistantTranscriptSpan.textContent = llmResponse;
+ addMessage(`Assistant: ${llmResponse}`, 'assistant');
+ } else if (messageText.startsWith("TTS_STREAM_START:")) {
+ ttsStreaming = true;
+ assistantAudioBufferQueue = [];
+ audioPlaying = false;
+ if (sourceNode) {
+ try { sourceNode.stop(); } catch(e) { console.warn("Error stopping previous sourceNode:", e); }
+ sourceNode = null;
+ }
+ audioPlayer.style.display = 'none';
+ audioPlayer.src = "";
+ try {
+ const paramsText = messageText.substring("TTS_STREAM_START:".length);
+ const params = JSON.parse(paramsText);
+ expectedSampleRate = params.sample_rate;
+ initAudioContext();
+ statusDiv.textContent = 'Status: Receiving audio stream...';
+ addMessage('Assistant (Audio stream starting...)', 'assistant');
+ } catch (e) {
+ console.error("Could not parse TTS_STREAM_START params:", e);
+ statusDiv.textContent = 'Error: Could not parse audio stream parameters.';
+ ttsStreaming = false;
+ }
+ } else if (messageText === "TTS_STREAM_END") {
+ ttsStreaming = false;
+ if (!audioPlaying && assistantAudioBufferQueue.length === 0) {
+ statusDiv.textContent = 'Status: Audio stream finished (or was empty).';
+ } else if (!audioPlaying && assistantAudioBufferQueue.length > 0) {
+ playNextChunkFromQueue();
+ statusDiv.textContent = 'Status: Audio stream finished. Playing remaining...';
+ } else {
+ statusDiv.textContent = 'Status: Audio stream finished. Playing remaining...';
+ }
+ addMessage('Assistant (Audio stream ended)', 'assistant');
+ } else if (messageText.startsWith("SYSTEM_ERROR:")) {
+ const errorMsg = messageText.substring("SYSTEM_ERROR:".length).trim();
+ addMessage(`System Error: ${errorMsg}`, 'system');
+ statusDiv.textContent = `Error: ${errorMsg}`;
+ ttsStreaming = false;
+ assistantAudioBufferQueue = [];
+ } else {
+ addMessage(messageText, 'system');
+ }
+ }
+ };
+ websocket.onerror = (error) => {
+ console.error('WebSocket Error:', error);
+ statusDiv.textContent = 'Status: WebSocket error. Try reconnecting.';
+ addMessage('WebSocket Error. Check console.', 'system');
+ ttsStreaming = false;
+ };
+
+ websocket.onclose = () => {
+ statusDiv.textContent = 'Status: Disconnected. Please refresh to reconnect.';
+ startRecordButton.disabled = true;
+ stopRecordButton.disabled = true;
+ addMessage('Disconnected from server.', 'system');
+ ttsStreaming = false;
+ if (audioContext && audioContext.state !== 'closed') {
+ audioContext.close().catch(e => console.warn("Error closing AudioContext:", e));
+ audioContext = null;
+ console.log("AudioContext closed.");
+ }
+ };
+ }
+
+ function playNextChunkFromQueue() {
+ if (audioPlaying || assistantAudioBufferQueue.length === 0 || !audioContext || audioContext.state !== 'running' || !expectedSampleRate) {
+ if (assistantAudioBufferQueue.length === 0 && !ttsStreaming && !audioPlaying) {
+ console.log("Queue empty, not streaming, not playing: Playback complete.");
+ statusDiv.textContent = 'Status: Audio playback complete.';
+ }
+ return;
+ }
+ audioPlaying = true;
+
+ const pcmDataInt16 = assistantAudioBufferQueue.shift();
+
+ const float32Pcm = new Float32Array(pcmDataInt16.length);
+ for (let i = 0; i < pcmDataInt16.length; i++) {
+ float32Pcm[i] = pcmDataInt16[i] / 32768.0;
+ }
+
+ const audioBuffer = audioContext.createBuffer(1, float32Pcm.length, expectedSampleRate);
+ audioBuffer.getChannelData(0).set(float32Pcm);
+
+ sourceNode = audioContext.createBufferSource();
+ sourceNode.buffer = audioBuffer;
+ sourceNode.connect(audioContext.destination);
+ sourceNode.onended = () => {
+ audioPlaying = false;
+ if (ttsStreaming || assistantAudioBufferQueue.length > 0) {
+ playNextChunkFromQueue();
+ } else {
+ statusDiv.textContent = 'Status: Audio playback finished.';
+ console.log("All queued audio chunks played.");
+ }
+ };
+ sourceNode.start();
+ statusDiv.textContent = 'Status: Playing audio chunk...';
+ }
+
+ function addMessage(text, type) {
+ const messageElement = document.createElement('div');
+ messageElement.classList.add('message', type);
+ messageElement.textContent = text;
+ messagesDiv.appendChild(messageElement);
+ messagesDiv.scrollTop = messagesDiv.scrollHeight;
+ }
+
+ startRecordButton.onclick = async () => {
+ if (!websocket || websocket.readyState !== WebSocket.OPEN) {
+ alert("WebSocket is not connected. Please wait or refresh.");
+ return;
+ }
+ if (audioContext && audioContext.state === 'suspended') {
+ audioContext.resume().catch(e => console.error("Error resuming AudioContext:", e));
+ }
+ initAudioContext();
+
+ try {
+ const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
+ let options = { mimeType: 'audio/webm;codecs=opus' };
+ if (!MediaRecorder.isTypeSupported(options.mimeType)) {
+ console.warn(`${options.mimeType} is not supported, trying default.`);
+ options = {};
+ }
+ mediaRecorder = new MediaRecorder(stream, options);
+ userAudioChunks = [];
+
+ mediaRecorder.ondataavailable = event => {
+ if (event.data.size > 0) userAudioChunks.push(event.data);
+ };
+
+ mediaRecorder.onstop = () => {
+ if (userAudioChunks.length === 0) {
+ console.log("No audio data recorded.");
+ statusDiv.textContent = 'Status: No audio data recorded. Try again.';
+ startRecordButton.disabled = false;
+ stopRecordButton.disabled = true;
+ return;
+ }
+ const audioBlob = new Blob(userAudioChunks, { type: mediaRecorder.mimeType });
+ if (websocket && websocket.readyState === WebSocket.OPEN) {
+ websocket.send(audioBlob);
+ statusDiv.textContent = 'Status: Audio sent. Waiting for response...';
+ } else {
+ statusDiv.textContent = 'Status: WebSocket not open. Cannot send audio.';
+ }
+ userAudioChunks = [];
+ };
+
+ mediaRecorder.start(250);
+ startRecordButton.disabled = true;
+ stopRecordButton.disabled = false;
+ statusDiv.textContent = 'Status: Recording...';
+ userTranscriptSpan.textContent = "...";
+ assistantTranscriptSpan.textContent = "...";
+ audioPlayer.style.display = 'none';
+ audioPlayer.src = '';
+ assistantAudioBufferQueue = [];
+ if (sourceNode) { try {sourceNode.stop();} catch(e){} sourceNode = null; }
+ } catch (err) {
+ console.error('Error accessing microphone:', err);
+ statusDiv.textContent = 'Status: Error accessing microphone.';
+ alert('Could not access microphone: ' + err.message);
+ }
+ };
+
+ stopRecordButton.onclick = () => {
+ if (mediaRecorder && mediaRecorder.state === "recording") {
+ mediaRecorder.stop();
+ startRecordButton.disabled = false;
+ stopRecordButton.disabled = true;
+ }
+ };
+
+ connectWebSocket();
+ </script>
+ </body>
+ </html>
+ """
+ return HTMLResponse(content=html_content)
+
+ if __name__ == "__main__":
+ uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")
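The JavaScript above defines a simple wire protocol over `/ws/conversation`: one binary WebM/Opus blob up, then text markers (`USER_TRANSCRIPT:`, `ASSISTANT_RESPONSE_TEXT:`, `TTS_STREAM_START:{...}`, `TTS_STREAM_END`, `SYSTEM_ERROR:`) and raw 16-bit PCM frames down. As a minimal sketch (not part of the app), the same flow can be exercised without a browser using the `websockets` package from requirements.txt; the local URL and the `sample.webm` path are assumptions you would adjust:

```python
# ws_smoke_test.py - hypothetical smoke test for the /ws/conversation protocol shown above
import asyncio
import json

import websockets  # already listed in requirements.txt


async def run(audio_path: str = "sample.webm"):
    uri = "ws://localhost:8000/ws/conversation"  # assumes the FastAPI app is running locally
    async with websockets.connect(uri) as ws:
        with open(audio_path, "rb") as f:
            await ws.send(f.read())  # one utterance as a single binary message, like the browser

        pcm_chunks, sample_rate = [], None
        while True:
            msg = await ws.recv()
            if isinstance(msg, bytes):
                pcm_chunks.append(msg)  # raw int16 PCM chunk of the TTS stream
            elif msg.startswith("TTS_STREAM_START:"):
                sample_rate = json.loads(msg[len("TTS_STREAM_START:"):])["sample_rate"]
            elif msg == "TTS_STREAM_END":
                break
            else:
                print(msg)  # transcripts, assistant text, or SYSTEM_ERROR messages

        print(f"Got {sum(len(c) for c in pcm_chunks) // 2} samples at {sample_rate} Hz")


if __name__ == "__main__":
    asyncio.run(run())
```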
parler-streaming.py ADDED
@@ -0,0 +1,402 @@
+ import io
+ import math
+ from queue import Queue
+ from threading import Thread
+ from typing import Optional
+
+ import numpy as np
+ # import spaces
+ import gradio as gr
+ import torch
+
+ from parler_tts import ParlerTTSForConditionalGeneration
+ from pydub import AudioSegment
+ from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
+ from transformers.generation.streamers import BaseStreamer
+
+ device = "cuda:0" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+ torch_dtype = torch.float16 if device != "cpu" else torch.float32
+
+ repo_id = "ai4bharat/indic-parler-tts"
+ jenny_repo_id = "ylacombe/parler-tts-mini-jenny-30H"
+
+ model = ParlerTTSForConditionalGeneration.from_pretrained(
+ repo_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True
+ ).to(device)
+ # jenny_model = ParlerTTSForConditionalGeneration.from_pretrained(
+ # jenny_repo_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True
+ # ).to(device)
+
+ tokenizer = AutoTokenizer.from_pretrained(repo_id)
+ feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
+
+ SAMPLE_RATE = feature_extractor.sampling_rate
+ SEED = 42
+
+ default_text = "Please surprise me and speak in whatever voice you enjoy."
+ examples = [
+ [
+ "Remember - this is only the first iteration of the model! To improve the prosody and naturalness of the speech further, we're scaling up the amount of training data by a factor of five times.",
+ "A male speaker with a low-pitched voice delivering his words at a fast pace in a small, confined space with a very clear audio and an animated tone.",
+ 3.0,
+ ],
+ [
+ "'This is the best time of my life, Bartley,' she said happily.",
+ "A female speaker with a slightly low-pitched, quite monotone voice delivers her words at a slightly faster-than-average pace in a confined space with very clear audio.",
+ 3.0,
+ ],
+ [
+ "Montrose also, after having experienced still more variety of good and bad fortune, threw down his arms, and retired out of the kingdom.",
+ "A male speaker with a slightly high-pitched voice delivering his words at a slightly slow pace in a small, confined space with a touch of background noise and a quite monotone tone.",
+ 3.0,
+ ],
+ [
+ "Montrose also, after having experienced still more variety of good and bad fortune, threw down his arms, and retired out of the kingdom.",
+ "A male speaker with a low-pitched voice delivers his words at a fast pace and an animated tone, in a very spacious environment, accompanied by noticeable background noise.",
+ 3.0,
+ ],
+ ]
+
+ jenny_examples = [
+ [
+ "Remember, this is only the first iteration of the model! To improve the prosody and naturalness of the speech further, we're scaling up the amount of training data by a factor of five times.",
+ "Jenny speaks at an average pace with a slightly animated delivery in a very confined sounding environment with clear audio quality.",
+ 3.0,
+ ],
+ [
+ "'This is the best time of my life, Bartley,' she said happily.",
+ "Jenny speaks in quite a monotone voice at a slightly faster-than-average pace in a confined space with very clear audio.",
+ 3.0,
+ ],
+ [
+ "Montrose also, after having experienced still more variety of good and bad fortune, threw down his arms, and retired out of the kingdom.",
+ "Jenny delivers her words at a slightly slow pace in a small, confined space with a touch of background noise and a quite monotone tone.",
+ 3.0,
+ ],
+ [
+ "Montrose also, after having experienced still more variety of good and bad fortune, threw down his arms, and retired out of the kingdom.",
+ "Jenny delivers her words at a fast pace and an animated tone, in a very spacious environment, accompanied by noticeable background noise.",
+ 3.0,
+ ],
+ ]
+
+
+ class ParlerTTSStreamer(BaseStreamer):
+ def __init__(
+ self,
+ model: ParlerTTSForConditionalGeneration,
+ device: Optional[str] = None,
+ play_steps: Optional[int] = 10,
+ stride: Optional[int] = None,
+ timeout: Optional[float] = None,
+ ):
+ """
+ Streamer that stores playback-ready audio in a queue, to be used by a downstream application as an iterator. This is
+ useful for applications that benefit from accessing the generated audio in a non-blocking way (e.g. in an interactive
+ Gradio demo).
+ Parameters:
+ model (`ParlerTTSForConditionalGeneration`):
+ The Parler-TTS model used to generate the audio waveform.
+ device (`str`, *optional*):
+ The torch device on which to run the computation. If `None`, will default to the device of the model.
+ play_steps (`int`, *optional*, defaults to 10):
+ The number of generation steps with which to return the generated audio array. Using fewer steps will
+ mean the first chunk is ready faster, but will require more codec decoding steps overall. This value
+ should be tuned to your device and latency requirements.
+ stride (`int`, *optional*):
+ The window (stride) between adjacent audio samples. Using a stride between adjacent audio samples reduces
+ the hard boundary between them, giving smoother playback. If `None`, will default to a value equivalent to
+ play_steps // 6 in the audio space.
+ timeout (`int`, *optional*):
+ The timeout for the audio queue. If `None`, the queue will block indefinitely. Useful to handle exceptions
+ in `.generate()`, when it is called in a separate thread.
+ """
+ self.decoder = model.decoder
+ self.audio_encoder = model.audio_encoder
+ self.generation_config = model.generation_config
+ self.device = device if device is not None else model.device
+
+ # variables used in the streaming process
+ self.play_steps = play_steps
+ if stride is not None:
+ self.stride = stride
+ else:
+ hop_length = math.floor(self.audio_encoder.config.sampling_rate / self.audio_encoder.config.frame_rate)
+ self.stride = hop_length * (play_steps - self.decoder.num_codebooks) // 6
+ self.token_cache = None
+ self.to_yield = 0
+
+ # variables used in the thread process
+ self.audio_queue = Queue()
+ self.stop_signal = None
+ self.timeout = timeout
+
+ def apply_delay_pattern_mask(self, input_ids):
+ # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Parler)
+ _, delay_pattern_mask = self.decoder.build_delay_pattern_mask(
+ input_ids[:, :1],
+ bos_token_id=self.generation_config.bos_token_id,
+ pad_token_id=self.generation_config.decoder_start_token_id,
+ max_length=input_ids.shape[-1],
+ )
+ # apply the pattern mask to the input ids
+ input_ids = self.decoder.apply_delay_pattern_mask(input_ids, delay_pattern_mask)
+
+ # revert the pattern delay mask by filtering the pad token id
+ mask = (delay_pattern_mask != self.generation_config.bos_token_id) & (delay_pattern_mask != self.generation_config.pad_token_id)
+ input_ids = input_ids[mask].reshape(1, self.decoder.num_codebooks, -1)
+ # append the frame dimension back to the audio codes
+ input_ids = input_ids[None, ...]
+
+ # send the input_ids to the correct device
+ input_ids = input_ids.to(self.audio_encoder.device)
+
+ decode_sequentially = (
+ self.generation_config.bos_token_id in input_ids
+ or self.generation_config.pad_token_id in input_ids
+ or self.generation_config.eos_token_id in input_ids
+ )
+ if not decode_sequentially:
+ output_values = self.audio_encoder.decode(
+ input_ids,
+ audio_scales=[None],
+ )
+ else:
+ sample = input_ids[:, 0]
+ sample_mask = (sample >= self.audio_encoder.config.codebook_size).sum(dim=(0, 1)) == 0
+ sample = sample[:, :, sample_mask]
+ output_values = self.audio_encoder.decode(sample[None, ...], [None])
+
+ audio_values = output_values.audio_values[0, 0]
+ return audio_values.cpu().float().numpy()
+
+ def put(self, value):
+ batch_size = value.shape[0] // self.decoder.num_codebooks
+ if batch_size > 1:
+ raise ValueError("ParlerTTSStreamer only supports batch size 1")
+
+ if self.token_cache is None:
+ self.token_cache = value
+ else:
+ self.token_cache = torch.concatenate([self.token_cache, value[:, None]], dim=-1)
+
+ if self.token_cache.shape[-1] % self.play_steps == 0:
+ audio_values = self.apply_delay_pattern_mask(self.token_cache)
+ self.on_finalized_audio(audio_values[self.to_yield : -self.stride])
+ self.to_yield += len(audio_values) - self.to_yield - self.stride
+
+ def end(self):
+ """Flushes any remaining cache and appends the stop symbol."""
+ if self.token_cache is not None:
+ audio_values = self.apply_delay_pattern_mask(self.token_cache)
+ else:
+ audio_values = np.zeros(self.to_yield)
+
+ self.on_finalized_audio(audio_values[self.to_yield :], stream_end=True)
+
+ def on_finalized_audio(self, audio: np.ndarray, stream_end: bool = False):
+ """Put the new audio in the queue. If the stream is ending, also put a stop signal in the queue."""
+ self.audio_queue.put(audio, timeout=self.timeout)
+ if stream_end:
+ self.audio_queue.put(self.stop_signal, timeout=self.timeout)
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ value = self.audio_queue.get(timeout=self.timeout)
+ if not isinstance(value, np.ndarray) and value == self.stop_signal:
+ raise StopIteration()
+ else:
+ return value
+
+ def numpy_to_mp3(audio_array, sampling_rate):
+ # Normalize audio_array if it's floating-point
+ if np.issubdtype(audio_array.dtype, np.floating):
+ max_val = np.max(np.abs(audio_array))
+ audio_array = (audio_array / max_val) * 32767 # Normalize to 16-bit range
+ audio_array = audio_array.astype(np.int16)
+
+ # Create an audio segment from the numpy array
+ audio_segment = AudioSegment(
+ audio_array.tobytes(),
+ frame_rate=sampling_rate,
+ sample_width=audio_array.dtype.itemsize,
+ channels=1
+ )
+
+ # Export the audio segment to MP3 bytes - use a high bitrate to maximise quality
+ mp3_io = io.BytesIO()
+ audio_segment.export(mp3_io, format="mp3", bitrate="320k")
+
+ # Get the MP3 bytes
+ mp3_bytes = mp3_io.getvalue()
+ mp3_io.close()
+
+ return mp3_bytes
+
+ sampling_rate = model.audio_encoder.config.sampling_rate
+ frame_rate = model.audio_encoder.config.frame_rate
+
+ # @spaces.GPU
+ def generate_base(text, description, play_steps_in_s=2.0):
+ play_steps = int(frame_rate * play_steps_in_s)
+ streamer = ParlerTTSStreamer(model, device=device, play_steps=play_steps)
+
+ inputs = tokenizer(description, return_tensors="pt").to(device)
+ prompt = tokenizer(text, return_tensors="pt").to(device)
+
+ generation_kwargs = dict(
+ input_ids=inputs.input_ids,
+ prompt_input_ids=prompt.input_ids,
+ streamer=streamer,
+ do_sample=True,
+ temperature=1.0,
+ min_new_tokens=10,
+ )
+
+ set_seed(SEED)
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
+ thread.start()
+
+ for new_audio in streamer:
+ print(f"Sample of length: {round(new_audio.shape[0] / sampling_rate, 2)} seconds")
+ yield numpy_to_mp3(new_audio, sampling_rate=sampling_rate)
+
+ # @spaces.GPU
+ def generate_jenny(text, description, play_steps_in_s=2.0):
+ # NOTE: this path needs the fine-tuned Jenny checkpoint; uncomment the jenny_model
+ # loading block near the top of the file before using this tab.
+ play_steps = int(frame_rate * play_steps_in_s)
+ streamer = ParlerTTSStreamer(jenny_model, device=device, play_steps=play_steps)
+
+ inputs = tokenizer(description, return_tensors="pt").to(device)
+ prompt = tokenizer(text, return_tensors="pt").to(device)
+
+ generation_kwargs = dict(
+ input_ids=inputs.input_ids,
+ prompt_input_ids=prompt.input_ids,
+ streamer=streamer,
+ do_sample=True,
+ temperature=1.0,
+ min_new_tokens=10,
+ )
+
+ set_seed(SEED)
+ thread = Thread(target=jenny_model.generate, kwargs=generation_kwargs)
+ thread.start()
+
+ for new_audio in streamer:
+ print(f"Sample of length: {round(new_audio.shape[0] / sampling_rate, 2)} seconds")
+ yield sampling_rate, new_audio
+
+
+ css = """
+ #share-btn-container {
+ display: flex;
+ padding-left: 0.5rem !important;
+ padding-right: 0.5rem !important;
+ background-color: #000000;
+ justify-content: center;
+ align-items: center;
+ border-radius: 9999px !important;
+ width: 13rem;
+ margin-top: 10px;
+ margin-left: auto;
+ flex: unset !important;
+ }
+ #share-btn {
+ all: initial;
+ color: #ffffff;
+ font-weight: 600;
+ cursor: pointer;
+ font-family: 'IBM Plex Sans', sans-serif;
+ margin-left: 0.5rem !important;
+ padding-top: 0.25rem !important;
+ padding-bottom: 0.25rem !important;
+ right:0;
+ }
+ #share-btn * {
+ all: unset !important;
+ }
+ #share-btn-container div:nth-child(-n+2){
+ width: auto !important;
+ min-height: 0px !important;
+ }
+ #share-btn-container .wrap {
+ display: none !important;
+ }
+ """
+ with gr.Blocks(css=css) as block:
+ gr.HTML(
+ """
+ <div style="text-align: center; max-width: 700px; margin: 0 auto;">
+ <div
+ style="
+ display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;
+ "
+ >
+ <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
+ Parler-TTS 🗣️
+ </h1>
+ </div>
+ </div>
+ """
+ )
+ gr.HTML(
+ f"""
+ <p><a href="https://github.com/huggingface/parler-tts"> Parler-TTS</a> is a training and inference library for
+ high-fidelity text-to-speech (TTS) models. Two models are demonstrated here: <a href="https://huggingface.co/parler-tts/parler_tts_mini_v0.1"> Parler-TTS Mini v0.1</a>,
+ the first iteration of the model, trained on 10k hours of narrated audiobooks, and <a href="https://huggingface.co/ylacombe/parler-tts-mini-jenny-30H"> Parler-TTS Jenny</a>,
+ a model fine-tuned on the <a href="https://huggingface.co/datasets/reach-vb/jenny_tts_dataset"> Jenny dataset</a>.
+ Both models generate high-quality speech with features that can be controlled using a simple text prompt (e.g. gender, background noise, speaking rate, pitch and reverberation).</p>
+ <p>Tips for ensuring good generation:
+ <ul>
+ <li>Include the term <b>"very clear audio"</b> to generate the highest quality audio, and "very noisy audio" for high levels of background noise</li>
+ <li>When using the fine-tuned model, include the term <b>"Jenny"</b> to pick out her voice</li>
+ <li>Punctuation can be used to control the prosody of the generations, e.g. use commas to add small breaks in speech</li>
+ <li>The remaining speech features (gender, speaking rate, pitch and reverberation) can be controlled directly through the prompt</li>
+ </ul>
+ </p>
+ """
+ )
+ with gr.Tab("Base"):
+ with gr.Row():
+ with gr.Column():
+ input_text = gr.Textbox(label="Input Text", lines=2, value=default_text, elem_id="input_text")
+ description = gr.Textbox(label="Description", lines=2, value="", elem_id="input_description")
+ play_seconds = gr.Slider(3.0, 7.0, value=3.0, step=2, label="Streaming interval in seconds", info="Lower = shorter chunks, lower latency, more codec steps")
+ run_button = gr.Button("Generate Audio", variant="primary")
+ with gr.Column():
+ audio_out = gr.Audio(label="Parler-TTS generation", format="mp3", elem_id="audio_out", streaming=True, autoplay=True)
+
+ inputs = [input_text, description, play_seconds]
+ outputs = [audio_out]
+ gr.Examples(examples=examples, fn=generate_base, inputs=inputs, outputs=outputs, cache_examples=False)
+ run_button.click(fn=generate_base, inputs=inputs, outputs=outputs, queue=True)
+
+ with gr.Tab("Jenny"):
+ with gr.Row():
+ with gr.Column():
+ input_text = gr.Textbox(label="Input Text", lines=2, value=jenny_examples[0][0], elem_id="input_text")
+ description = gr.Textbox(label="Description", lines=2, value=jenny_examples[0][1], elem_id="input_description")
+ play_seconds = gr.Slider(3.0, 7.0, value=jenny_examples[0][2], step=2, label="Streaming interval in seconds", info="Lower = shorter chunks, lower latency, more codec steps")
+ run_button = gr.Button("Generate Audio", variant="primary")
+ with gr.Column():
+ audio_out = gr.Audio(label="Parler-TTS generation", format="mp3", elem_id="audio_out", streaming=True, autoplay=True)
+
+ inputs = [input_text, description, play_seconds]
+ outputs = [audio_out]
+ gr.Examples(examples=jenny_examples, fn=generate_jenny, inputs=inputs, outputs=outputs, cache_examples=False)
+ run_button.click(fn=generate_jenny, inputs=inputs, outputs=outputs, queue=True)
+
+ gr.HTML(
+ """
+ <p>To improve the prosody and naturalness of the speech further, we're scaling up the amount of training data to 50k hours of speech.
+ The v1 release of the model will be trained on this data, as well as inference optimisations, such as flash attention
+ and torch compile, that will improve the latency by 2-4x. If you want to find out more about how this model was trained and even fine-tune it yourself, check out the
+ <a href="https://github.com/huggingface/parler-tts"> Parler-TTS</a> repository on GitHub. The Parler-TTS codebase and its
+ associated checkpoints are licensed under <a href='https://github.com/huggingface/parler-tts?tab=Apache-2.0-1-ov-file#readme'> Apache 2.0</a>.</p>
+ """
+ )
+
+ block.queue()
+ block.launch(share=True)
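For context on the `play_steps` arithmetic used in `generate_base`/`generate_jenny` above: the streaming interval in seconds is multiplied by the codec's `frame_rate` to get decoder steps per chunk, and inside the streamer each step maps to roughly `hop_length = sampling_rate / frame_rate` output samples. A back-of-the-envelope sketch follows; the concrete `frame_rate` and `sampling_rate` defaults below are placeholders, since the real values come from `model.audio_encoder.config`:

```python
# chunk_math.py - illustrative only; mirrors the play_steps formula used above
def chunk_plan(play_steps_in_s: float, frame_rate: int = 86, sampling_rate: int = 44_100):
    """Return (decoder steps per chunk, approx. samples per chunk, approx. seconds of audio)."""
    play_steps = int(frame_rate * play_steps_in_s)  # same formula as generate_base()
    hop_length = sampling_rate // frame_rate        # samples produced per decoder step
    samples = play_steps * hop_length
    return play_steps, samples, samples / sampling_rate


if __name__ == "__main__":
    for interval in (1.0, 2.0, 3.0):
        steps, samples, seconds = chunk_plan(interval)
        print(f"{interval:.1f}s interval -> {steps} steps, ~{samples} samples, ~{seconds:.2f}s of audio")
```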
requirements.txt ADDED
@@ -0,0 +1,22 @@
+ # requirements.txt
+ fastapi
+ uvicorn[standard]
+ websockets
+ openai-whisper
+ torch
+ torchaudio
+ transformers
+ accelerate # Often useful for transformers
+ python-multipart # For file uploads in traditional endpoints
+ soundfile # For handling audio files
+ librosa
+ parler-tts # For AI4Bharat's IndicParler-TTS
+ onnx
+ onnxruntime
+ # For specific hardware acceleration (optional, choose based on your setup)
+ # bitsandbytes # For 8-bit quantization of LLM (further RAM reduction)
+ # sentencepiece # Often a dependency for tokenizers
+ # For demo purposes
+ gradio
+ # For Gemini integration
+ google-generativeai
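Note: `parler-streaming.py` above also imports `pydub`, which is not listed in this requirements file, and both pydub's MP3 export and `openai-whisper`'s audio loading rely on a system-level `ffmpeg` binary that has to be installed separately from these pip packages.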
streaming_nb.ipynb ADDED
File without changes
test_notebook.ipynb ADDED
The diff for this file is too large to render. See raw diff