Upload folder using huggingface_hub
Browse files- .gradio/certificate.pem +31 -0
- README.md +108 -8
- __pycache__/main.cpython-310.pyc +0 -0
- gradio_app.py +294 -0
- infereless.py +14 -0
- main.py +740 -0
- parler-streaming.py +402 -0
- requirements.txt +22 -0
- streaming_nb.ipynb +0 -0
- test_notebook.ipynb +0 -0
.gradio/certificate.pem
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
-----BEGIN CERTIFICATE-----
|
2 |
+
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
|
3 |
+
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
|
4 |
+
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
|
5 |
+
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
|
6 |
+
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
|
7 |
+
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
|
8 |
+
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
|
9 |
+
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
|
10 |
+
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
|
11 |
+
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
|
12 |
+
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
|
13 |
+
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
|
14 |
+
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
|
15 |
+
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
|
16 |
+
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
|
17 |
+
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
|
18 |
+
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
|
19 |
+
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
|
20 |
+
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
|
21 |
+
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
|
22 |
+
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
|
23 |
+
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
|
24 |
+
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
|
25 |
+
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
|
26 |
+
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
|
27 |
+
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
|
28 |
+
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
|
29 |
+
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
|
30 |
+
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
|
31 |
+
-----END CERTIFICATE-----
|
README.md
CHANGED
@@ -1,12 +1,112 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
|
4 |
-
colorFrom: indigo
|
5 |
-
colorTo: yellow
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 5.
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
---
|
|
|
11 |
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
+
title: voice-assistant
|
3 |
+
app_file: gradio_app.py
|
|
|
|
|
4 |
sdk: gradio
|
5 |
+
sdk_version: 5.29.1
|
|
|
|
|
6 |
---
|
7 |
+
# Real-time Conversational AI Chatbot Backend
|
8 |
|
9 |
+
This project implements a Python-based backend for a real-time conversational AI chatbot. It features Speech-to-Text (STT), Language Model (LLM) processing via Google's Gemini API, and streaming Text-to-Speech (TTS) capabilities, all orchestrated through a FastAPI web server with WebSocket support for interactive conversations.
|
10 |
+
|
11 |
+
## Core Features
|
12 |
+
|
13 |
+
- **Speech-to-Text (STT):** Utilizes OpenAI's Whisper model to transcribe user's spoken audio into text.
|
14 |
+
- **Language Model (LLM):** Integrates with Google's Gemini API (e.g., `gemini-1.5-flash-latest`) for generating intelligent and contextual responses.
|
15 |
+
- **Text-to-Speech (TTS) with Streaming:** Employs AI4Bharat's IndicParler-TTS model (via `parler-tts` library) with `ParlerTTSStreamer` to convert the LLM's text response into audible speech, streamed chunk by chunk for faster time-to-first-audio.
|
16 |
+
- **Real-time Interaction:** A WebSocket endpoint (`/ws/conversation`) manages the live, bidirectional flow of audio and text data between the client and server.
|
17 |
+
- **Component Testing:** Includes individual HTTP RESTful endpoints for testing STT, LLM, and TTS functionalities separately.
|
18 |
+
- **Basic Client Demo:** Provides a simple HTML/JavaScript client served at the root (`/`) for demonstrating the WebSocket conversation flow.
|
19 |
+
|
20 |
+
## Technologies Used
|
21 |
+
|
22 |
+
- **Backend Framework:** FastAPI
|
23 |
+
- **ASR (STT):** OpenAI Whisper
|
24 |
+
- **LLM:** Google Gemini API (via `google-generativeai` SDK)
|
25 |
+
- **TTS:** AI4Bharat IndicParler-TTS (via `parler-tts` and `transformers`)
|
26 |
+
- **Audio Processing:** `soundfile`, `librosa`
|
27 |
+
- **Async & Concurrency:** `asyncio`, `threading` (for ParlerTTSStreamer)
|
28 |
+
- **ML/DL:** PyTorch
|
29 |
+
- **Web Server:** Uvicorn
|
30 |
+
|
31 |
+
## Setup and Installation
|
32 |
+
|
33 |
+
1. **Clone the Repository (if applicable)**
|
34 |
+
|
35 |
+
```bash
|
36 |
+
git clone <your-repo-url>
|
37 |
+
cd <your-repo-name>
|
38 |
+
```
|
39 |
+
|
40 |
+
2. **Create a Python Virtual Environment**
|
41 |
+
|
42 |
+
- Using `venv`:
|
43 |
+
```bash
|
44 |
+
python -m venv venv
|
45 |
+
source venv/bin/activate # On Windows: venv\Scripts\activate
|
46 |
+
```
|
47 |
+
- Or using `conda`:
|
48 |
+
```bash
|
49 |
+
conda create -n voicebot_env python=3.10 # Or your preferred Python 3.9+
|
50 |
+
conda activate voicebot_env
|
51 |
+
```
|
52 |
+
|
53 |
+
3. **Install Dependencies**
|
54 |
+
|
55 |
+
```bash
|
56 |
+
pip install -r requirements.txt
|
57 |
+
```
|
58 |
+
|
59 |
+
Ensure you have `ffmpeg` installed on your system, as Whisper requires it.
|
60 |
+
(e.g., `sudo apt update && sudo apt install ffmpeg` on Debian/Ubuntu)
|
61 |
+
|
62 |
+
4. **Set Environment Variables:**
|
63 |
+
- **Gemini API Key:** Obtain an API key from [Google AI Studio](https://aistudio.google.com/). Set it as an environment variable:
|
64 |
+
```bash
|
65 |
+
export GEMINI_API_KEY="YOUR_ACTUAL_GEMINI_API_KEY"
|
66 |
+
```
|
67 |
+
(For Windows PowerShell: `$env:GEMINI_API_KEY="YOUR_ACTUAL_GEMINI_API_KEY"`)
|
68 |
+
- **(Optional) Whisper Model Size:**
|
69 |
+
```bash
|
70 |
+
export WHISPER_MODEL_SIZE="base" # (e.g., tiny, base, small, medium, large)
|
71 |
+
```
|
72 |
+
Defaults to "base" if not set.
|
73 |
+
|
74 |
+
### HTTP RESTful Endpoints
|
75 |
+
|
76 |
+
These are standard FastAPI path operations for testing individual components:
|
77 |
+
|
78 |
+
- **`POST /api/stt`**: Upload an audio file to get its transcription.
|
79 |
+
- **`POST /api/llm`**: Send text in a JSON payload to get a response from Gemini.
|
80 |
+
- **`POST /api/tts`**: Send text in a JSON payload to get synthesized audio (non-streaming for this HTTP endpoint, returns base64 encoded WAV).
|
81 |
+
|
82 |
+
### WebSocket Endpoint: `/ws/conversation`
|
83 |
+
|
84 |
+
This is the primary endpoint for real-time, bidirectional conversational interaction:
|
85 |
+
|
86 |
+
- `@app.websocket("/ws/conversation")` defines the WebSocket route.
|
87 |
+
- **Connection Handling:** Accepts new WebSocket connections.
|
88 |
+
- **Main Interaction Loop:**
|
89 |
+
1. **Receive Audio:** Waits to receive audio data (bytes) from the client (`await websocket.receive_bytes()`).
|
90 |
+
2. **STT:** Calls `transcribe_audio_bytes()` to get text from the user's audio. Sends `USER_TRANSCRIPT: <text>` back to the client.
|
91 |
+
3. **LLM:** Calls `generate_gemini_response()` with the transcribed text. Sends `ASSISTANT_RESPONSE_TEXT: <text>` back to the client.
|
92 |
+
4. **Streaming TTS:**
|
93 |
+
- Sends a `TTS_STREAM_START: {<audio_params>}` message to the client, informing it about the sample rate, channels, and bit depth of the upcoming audio stream.
|
94 |
+
- Iterates through the `synthesize_speech_streaming()` asynchronous generator.
|
95 |
+
- For each `audio_chunk_bytes` yielded, it sends these raw audio bytes to the client using `await websocket.send_bytes()`.
|
96 |
+
- If `websocket.send_bytes()` fails (e.g., client disconnected), the loop breaks, and the `cancellation_event` is set to signal the TTS thread.
|
97 |
+
- After the stream is complete (or cancelled), it sends a `TTS_STREAM_END` message.
|
98 |
+
- **Error Handling:** Includes `try...except WebSocketDisconnect` to handle client disconnections gracefully and a general exception handler.
|
99 |
+
- **Cleanup:** The `finally` block ensures the `cancellation_event` for TTS is set and attempts to close the WebSocket.
|
100 |
+
|
101 |
+
## How to Run
|
102 |
+
|
103 |
+
1. Ensure all setup steps (environment, dependencies, API key) are complete.
|
104 |
+
2. Execute the script:
|
105 |
+
```bash
|
106 |
+
python main.py
|
107 |
+
```
|
108 |
+
Or, for development with auto-reload:
|
109 |
+
```bash
|
110 |
+
uvicorn main:app --reload --host 0.0.0.0 --port 8000
|
111 |
+
```
|
112 |
+
3. The server will start, and you should see logs indicating that models are being loaded.
|
__pycache__/main.cpython-310.pyc
ADDED
Binary file (31.2 kB). View file
|
|
gradio_app.py
ADDED
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# gradio_app.py
|
2 |
+
import gradio as gr
|
3 |
+
import io
|
4 |
+
import os
|
5 |
+
import torch
|
6 |
+
from parler_tts import ParlerTTSForConditionalGeneration
|
7 |
+
from transformers import AutoTokenizer, AutoModel # CHANGED: Using AutoModel as per model card
|
8 |
+
import numpy as np
|
9 |
+
import google.generativeai as genai
|
10 |
+
import asyncio
|
11 |
+
import librosa
|
12 |
+
import torchaudio # Often used by models like this for audio loading/processing internally or as input type
|
13 |
+
|
14 |
+
# --- Configuration ---
|
15 |
+
ASR_MODEL_NAME = "ai4bharat/indic-conformer-600m-multilingual"
|
16 |
+
TARGET_SAMPLE_RATE = 16000 # Model expects 16kHz
|
17 |
+
|
18 |
+
TTS_MODEL_NAME = "ai4bharat/indic-parler-tts"
|
19 |
+
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "AIzaSyD6x3Yoby4eQ6QL2kaaG_Rz3fG3rh7wPB8")
|
20 |
+
GEMINI_MODEL_NAME_GRADIO = "gemini-1.5-flash-latest"
|
21 |
+
|
22 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
23 |
+
# torch_dtype for ParlerTTS, Gemini etc. For ASR model, it might handle its own precision.
|
24 |
+
|
25 |
+
# --- Global Model Variables ---
|
26 |
+
asr_model_gradio = None # This will be the AutoModel instance
|
27 |
+
|
28 |
+
gemini_model_instance_gradio = None
|
29 |
+
tts_model_gradio = None
|
30 |
+
tts_tokenizer_gradio = None # For ParlerTTS
|
31 |
+
|
32 |
+
# --- Model Loading & API Configuration ---
|
33 |
+
def load_all_resources_gradio():
|
34 |
+
global asr_model_gradio, tts_model_gradio, tts_tokenizer_gradio, gemini_model_instance_gradio
|
35 |
+
print(f"Gradio: Loading resources. ASR will be on device: {DEVICE}")
|
36 |
+
|
37 |
+
if asr_model_gradio is None:
|
38 |
+
print(f"Gradio: Loading ASR model: {ASR_MODEL_NAME} using AutoModel")
|
39 |
+
try:
|
40 |
+
# Load using AutoModel as per the model card's implication
|
41 |
+
asr_model_gradio = AutoModel.from_pretrained(ASR_MODEL_NAME, trust_remote_code=True)
|
42 |
+
asr_model_gradio.to(DEVICE) # Move model to device
|
43 |
+
# The model might handle its own precision (e.g. .half()) internally if `trust_remote_code` allows
|
44 |
+
# Or you might need to call asr_model_gradio.half() if it supports it and you're on CUDA.
|
45 |
+
if DEVICE == "cuda" and hasattr(asr_model_gradio, 'half'):
|
46 |
+
print("Gradio: Applying .half() to ASR model.")
|
47 |
+
asr_model_gradio.half()
|
48 |
+
asr_model_gradio.eval()
|
49 |
+
print(f"Gradio: ASR model ({ASR_MODEL_NAME}) loaded using AutoModel.")
|
50 |
+
except Exception as e:
|
51 |
+
print(f"Gradio: Failed to load ASR model {ASR_MODEL_NAME} using AutoModel: {e}")
|
52 |
+
import traceback
|
53 |
+
traceback.print_exc()
|
54 |
+
asr_model_gradio = None
|
55 |
+
|
56 |
+
if tts_model_gradio is None: # ParlerTTS loading
|
57 |
+
print(f"Gradio: Loading IndicParler-TTS model: {TTS_MODEL_NAME}")
|
58 |
+
# Ensure ParlerTTS specific tokenizer is loaded for TTS
|
59 |
+
# Note: ASR model might have its own internal tokenizer/processor handled by its custom code
|
60 |
+
tts_parler_tokenizer = AutoTokenizer.from_pretrained(TTS_MODEL_NAME, trust_remote_code=True)
|
61 |
+
tts_model_gradio = ParlerTTSForConditionalGeneration.from_pretrained(TTS_MODEL_NAME, trust_remote_code=True).to(DEVICE)
|
62 |
+
tts_tokenizer_gradio = tts_parler_tokenizer
|
63 |
+
print("Gradio: IndicParler-TTS model loaded.")
|
64 |
+
|
65 |
+
if gemini_model_instance_gradio is None: # Gemini loading
|
66 |
+
if not GEMINI_API_KEY:
|
67 |
+
print("Gradio: GEMINI_API_KEY not found. LLM functionality via Gemini will be limited.")
|
68 |
+
else:
|
69 |
+
try:
|
70 |
+
genai.configure(api_key=GEMINI_API_KEY)
|
71 |
+
gemini_model_instance_gradio = genai.GenerativeModel(GEMINI_MODEL_NAME_GRADIO)
|
72 |
+
print(f"Gradio: Gemini API configured with model: {GEMINI_MODEL_NAME_GRADIO}")
|
73 |
+
except Exception as e:
|
74 |
+
print(f"Gradio: Failed to configure Gemini API: {e}")
|
75 |
+
gemini_model_instance_gradio = None
|
76 |
+
|
77 |
+
print("Gradio: All resources loaded (or attempted).")
|
78 |
+
|
79 |
+
|
80 |
+
# --- Helper Functions ---
|
81 |
+
def transcribe_audio_gradio(audio_input_tuple):
|
82 |
+
if asr_model_gradio is None:
|
83 |
+
return f"Error: ASR model ({ASR_MODEL_NAME}) not loaded."
|
84 |
+
|
85 |
+
if audio_input_tuple is None:
|
86 |
+
print("Gradio: No audio provided to transcribe_audio_gradio.")
|
87 |
+
return "No audio provided."
|
88 |
+
|
89 |
+
sample_rate, audio_numpy = audio_input_tuple
|
90 |
+
|
91 |
+
if audio_numpy is None or audio_numpy.size == 0:
|
92 |
+
print("Gradio: Audio numpy array is empty.")
|
93 |
+
return "Empty audio received."
|
94 |
+
|
95 |
+
# Ensure audio is mono float32, which is a common expectation
|
96 |
+
if audio_numpy.ndim > 1:
|
97 |
+
if audio_numpy.shape[0] == 2 and audio_numpy.ndim == 2:
|
98 |
+
audio_numpy = librosa.to_mono(audio_numpy)
|
99 |
+
elif audio_numpy.shape[1] == 2 and audio_numpy.ndim == 2:
|
100 |
+
audio_numpy = np.mean(audio_numpy, axis=1)
|
101 |
+
|
102 |
+
if audio_numpy.dtype != np.float32:
|
103 |
+
if np.issubdtype(audio_numpy.dtype, np.integer):
|
104 |
+
audio_numpy = audio_numpy.astype(np.float32) / np.iinfo(audio_numpy.dtype).max
|
105 |
+
else:
|
106 |
+
audio_numpy = audio_numpy.astype(np.float32)
|
107 |
+
|
108 |
+
# Resample to TARGET_SAMPLE_RATE (16kHz)
|
109 |
+
if sample_rate != TARGET_SAMPLE_RATE:
|
110 |
+
print(f"Gradio: Resampling audio from {sample_rate} Hz to {TARGET_SAMPLE_RATE} Hz.")
|
111 |
+
try:
|
112 |
+
audio_numpy = librosa.resample(y=audio_numpy, orig_sr=sample_rate, target_sr=TARGET_SAMPLE_RATE)
|
113 |
+
# After resampling, the audio_numpy is at TARGET_SAMPLE_RATE
|
114 |
+
except Exception as e:
|
115 |
+
print(f"Gradio: Error during resampling: {e}")
|
116 |
+
return f"Error during audio resampling: {str(e)}"
|
117 |
+
|
118 |
+
try:
|
119 |
+
print(f"Gradio: Preparing to transcribe with {ASR_MODEL_NAME}. Input audio shape: {audio_numpy.shape}")
|
120 |
+
|
121 |
+
# The model card example `model(wav, "hi", "ctc")` implies it might take a waveform tensor.
|
122 |
+
# We have a numpy array. We need to convert it to a PyTorch tensor.
|
123 |
+
# The model card uses torchaudio.load which returns a tensor.
|
124 |
+
# Let's convert our numpy array to a tensor and ensure it's on the correct device.
|
125 |
+
|
126 |
+
# Ensure the audio_numpy is 1D as expected by many ASR models for a single channel
|
127 |
+
if audio_numpy.ndim > 1:
|
128 |
+
audio_numpy = audio_numpy.squeeze() # Attempt to remove singleton dimensions
|
129 |
+
if audio_numpy.ndim > 1 : # If still more than 1D, problem
|
130 |
+
print(f"Gradio: Audio numpy array for ASR has unexpected dimensions after processing: {audio_numpy.shape}")
|
131 |
+
return "Error: Audio processing resulted in unexpected dimensions."
|
132 |
+
|
133 |
+
wav_tensor = torch.from_numpy(audio_numpy).to(DEVICE)
|
134 |
+
# The model might expect a batch dimension, e.g., [1, num_samples]
|
135 |
+
if wav_tensor.ndim == 1:
|
136 |
+
wav_tensor = wav_tensor.unsqueeze(0)
|
137 |
+
|
138 |
+
print(f"Gradio: Transcribing with {ASR_MODEL_NAME} using CTC. Input tensor shape: {wav_tensor.shape}")
|
139 |
+
|
140 |
+
# Perform ASR with CTC decoding (you can choose "rnnt" if preferred and supported)
|
141 |
+
# The language code "hi" is for Hindi. You might want to make this configurable
|
142 |
+
# or see if the model supports language auto-detection if you pass None or omit it.
|
143 |
+
# For now, assuming "hi" or that the model handles mixed language if lang_id is not strictly enforced.
|
144 |
+
# The model card doesn't specify if language ID is optional or how auto-detection works.
|
145 |
+
# Let's try "auto" or a common language like "en" or "hi" to start.
|
146 |
+
# The model card indicates training on 22 languages, so it's multilingual.
|
147 |
+
# If language_id is required, you'll need to provide it.
|
148 |
+
# Let's assume for now we try with a common Indian language or let the model try to auto-detect if "auto" or None is valid.
|
149 |
+
# The snippet "model(wav, "hi", "ctc")" is specific.
|
150 |
+
|
151 |
+
# The `model()` call is synchronous. Gradio handles this in a thread.
|
152 |
+
with torch.no_grad(): # Good practice for inference
|
153 |
+
transcription_result = asr_model_gradio(wav_tensor, "hi", "ctc") # Using lang_id="hi" and strategy="ctc" as per example
|
154 |
+
|
155 |
+
# The output format needs to be checked. The model card implies it's the transcribed string directly.
|
156 |
+
# It might be a list of transcriptions if batching occurs, or a dict.
|
157 |
+
if isinstance(transcription_result, list) and len(transcription_result) > 0:
|
158 |
+
transcribed_text = transcription_result[0] # Assuming first result for non-batched input
|
159 |
+
elif isinstance(transcription_result, str):
|
160 |
+
transcribed_text = transcription_result
|
161 |
+
else:
|
162 |
+
print(f"Gradio: Unexpected ASR result format: {type(transcription_result)}, value: {transcription_result}")
|
163 |
+
transcribed_text = "ASR result format not recognized."
|
164 |
+
|
165 |
+
transcribed_text = transcribed_text.strip()
|
166 |
+
print(f"Gradio: Transcription ({ASR_MODEL_NAME}, CTC): {transcribed_text}")
|
167 |
+
return transcribed_text if transcribed_text else "Transcription resulted in empty text."
|
168 |
+
except Exception as e:
|
169 |
+
print(f"Gradio: Error during {ASR_MODEL_NAME} transcription (AutoModel callable): {e}")
|
170 |
+
import traceback
|
171 |
+
traceback.print_exc()
|
172 |
+
return f"Error during transcription ({ASR_MODEL_NAME}): {str(e)}"
|
173 |
+
|
174 |
+
|
175 |
+
# ... (Gemini LLM and TTS functions remain the same) ...
|
176 |
+
def generate_gemini_response_gradio(text_input: str):
|
177 |
+
if not gemini_model_instance_gradio:
|
178 |
+
return "Error: Gemini LLM not configured or API key missing."
|
179 |
+
if not isinstance(text_input, str) or not text_input.strip() or text_input.startswith("Error:") or "No audio provided" in text_input or "Transcription resulted in empty text" in text_input or "Empty audio received" in text_input or "ASR result format not recognized" in text_input:
|
180 |
+
print(f"Gradio: Invalid input to Gemini: '{text_input}'. Skipping LLM response.")
|
181 |
+
return "LLM (Gemini) skipped due to transcription issue or no input."
|
182 |
+
try:
|
183 |
+
print(f"Gradio: Sending to Gemini: '{text_input}'")
|
184 |
+
full_prompt = f"User: {text_input}\nAssistant:"
|
185 |
+
response = gemini_model_instance_gradio.generate_content(full_prompt)
|
186 |
+
response_text = ""
|
187 |
+
if response.candidates and response.candidates[0].content.parts:
|
188 |
+
response_text = response.candidates[0].content.parts[0].text.strip()
|
189 |
+
else:
|
190 |
+
feedback_info = ""
|
191 |
+
if hasattr(response, 'prompt_feedback') and response.prompt_feedback:
|
192 |
+
feedback_info = f" Feedback: {response.prompt_feedback}"
|
193 |
+
print(f"Gradio: Gemini response did not contain expected content.{feedback_info}")
|
194 |
+
response_text = f"I'm sorry, I couldn't generate a response for that (Gemini).{feedback_info}"
|
195 |
+
|
196 |
+
print(f"Gradio: Gemini LLM Response: {response_text}")
|
197 |
+
return response_text if response_text else "Gemini LLM generated an empty response."
|
198 |
+
except Exception as e:
|
199 |
+
print(f"Gradio: Error during Gemini LLM generation: {e}")
|
200 |
+
import traceback
|
201 |
+
traceback.print_exc()
|
202 |
+
return f"Error during Gemini LLM generation: {str(e)}"
|
203 |
+
|
204 |
+
def synthesize_speech_gradio(text_input: str, description: str = "A clear, female voice speaking in English."):
|
205 |
+
if tts_model_gradio is None or tts_tokenizer_gradio is None:
|
206 |
+
return "Error: TTS model or its tokenizer not loaded."
|
207 |
+
if not isinstance(text_input, str) or not text_input.strip() or text_input.startswith("Error:") or "LLM skipped" in text_input or "generated an empty response" in text_input or "not configured" in text_input or "ASR result format not recognized" in text_input :
|
208 |
+
print(f"Gradio: Invalid input to TTS: '{text_input}'. Skipping synthesis.")
|
209 |
+
return "TTS skipped due to LLM issue or no input."
|
210 |
+
try:
|
211 |
+
print(f"Gradio: Synthesizing speech for: '{text_input}'")
|
212 |
+
description_tokenized = tts_tokenizer_gradio(description, return_tensors="pt", padding=True, truncation=True, max_length=128)
|
213 |
+
description_ids = description_tokenized.input_ids.to(DEVICE)
|
214 |
+
description_attention_mask = description_tokenized.attention_mask.to(DEVICE)
|
215 |
+
|
216 |
+
prompt_tokenized = tts_tokenizer_gradio(text_input, return_tensors="pt", padding=True, truncation=True, max_length=512)
|
217 |
+
prompt_ids = prompt_tokenized.input_ids.to(DEVICE)
|
218 |
+
|
219 |
+
if prompt_ids.shape[-1] == 0: # Check if tokenized prompt is empty
|
220 |
+
print(f"Gradio: Tokenized prompt for TTS is empty. Text was: '{text_input}'. Skipping synthesis.")
|
221 |
+
return "TTS skipped: Input text resulted in empty tokens."
|
222 |
+
|
223 |
+
|
224 |
+
generation = tts_model_gradio.generate(
|
225 |
+
input_ids=description_ids,
|
226 |
+
attention_mask=description_attention_mask,
|
227 |
+
prompt_input_ids=prompt_ids,
|
228 |
+
do_sample=True, temperature=0.7, top_k=50, top_p=0.95
|
229 |
+
).cpu().numpy().squeeze()
|
230 |
+
|
231 |
+
sampling_rate = tts_model_gradio.config.sampling_rate
|
232 |
+
print(f"Gradio: Speech synthesized. Array shape: {generation.shape}, Sample rate: {sampling_rate}")
|
233 |
+
return (sampling_rate, generation)
|
234 |
+
except Exception as e:
|
235 |
+
print(f"Gradio: Error during speech synthesis: {e}")
|
236 |
+
import traceback
|
237 |
+
traceback.print_exc()
|
238 |
+
if "You need to specify either `text` or `text_target`" in str(e):
|
239 |
+
return "Error in TTS: Model requires 'text' or 'text_target'. Input might be too short or problematic."
|
240 |
+
return f"Error during speech synthesis: {str(e)}"
|
241 |
+
|
242 |
+
# --- Gradio Interface Definition ---
|
243 |
+
load_all_resources_gradio()
|
244 |
+
|
245 |
+
def full_pipeline_gradio(audio_input):
|
246 |
+
transcribed_text_output = transcribe_audio_gradio(audio_input)
|
247 |
+
print(f"DEBUG full_pipeline_gradio - Step 1 (Transcription): '{transcribed_text_output}' (type: {type(transcribed_text_output)})")
|
248 |
+
llm_response_text_output = generate_gemini_response_gradio(transcribed_text_output)
|
249 |
+
print(f"DEBUG full_pipeline_gradio - Step 2 (LLM Response): '{llm_response_text_output}' (type: {type(llm_response_text_output)})")
|
250 |
+
tts_synthesis_result = synthesize_speech_gradio(llm_response_text_output)
|
251 |
+
final_audio_output = None
|
252 |
+
if isinstance(tts_synthesis_result, tuple) and len(tts_synthesis_result) == 2 and isinstance(tts_synthesis_result[1], np.ndarray):
|
253 |
+
final_audio_output = tts_synthesis_result
|
254 |
+
print(f"DEBUG full_pipeline_gradio - Step 3 (TTS Success): Audio tuple with shape {tts_synthesis_result[1].shape if isinstance(tts_synthesis_result[1], np.ndarray) else 'N/A'}")
|
255 |
+
else:
|
256 |
+
error_message_from_tts = str(tts_synthesis_result) if isinstance(tts_synthesis_result, str) else "TTS synthesis failed or returned unexpected type"
|
257 |
+
print(f"DEBUG full_pipeline_gradio - Step 3 (TTS Failed/Non-audio): {error_message_from_tts}. Providing silent audio.")
|
258 |
+
# Append TTS error to LLM text only if LLM text was valid
|
259 |
+
if llm_response_text_output and not llm_response_text_output.startswith("Error:") and "LLM skipped" not in llm_response_text_output and "ASR result format not recognized" not in llm_response_text_output:
|
260 |
+
llm_response_text_output = f"{llm_response_text_output} | (TTS Problem: {error_message_from_tts})"
|
261 |
+
elif not llm_response_text_output or llm_response_text_output.startswith("Error:") or "LLM skipped" in llm_response_text_output or "ASR result format not recognized" in llm_response_text_output:
|
262 |
+
# If LLM already had an error, just keep that error, maybe note TTS also had an issue
|
263 |
+
llm_response_text_output = f"{llm_response_text_output} (TTS also had an issue: {error_message_from_tts})"
|
264 |
+
|
265 |
+
default_sample_rate = tts_model_gradio.config.sampling_rate if tts_model_gradio and hasattr(tts_model_gradio, 'config') else TARGET_SAMPLE_RATE
|
266 |
+
final_audio_output = (default_sample_rate, np.array([0.0], dtype=np.float32))
|
267 |
+
print(f"DEBUG full_pipeline_gradio - Step 3 (TTS Fallback): Silent audio tuple")
|
268 |
+
print(f"DEBUG full_pipeline_gradio - RETURNING: Transcription='{transcribed_text_output}', LLM_Text='{llm_response_text_output}', Audio_Type={type(final_audio_output)}")
|
269 |
+
return transcribed_text_output, llm_response_text_output, final_audio_output
|
270 |
+
|
271 |
+
with gr.Blocks(title="Conversational AI Demo") as demo:
|
272 |
+
gr.Markdown("# Conversational AI Demo (STT -> Gemini LLM -> TTS)")
|
273 |
+
with gr.Row():
|
274 |
+
audio_in = gr.Audio(sources=["microphone"], type="numpy", label="Speak Here")
|
275 |
+
process_button = gr.Button("Process Audio")
|
276 |
+
with gr.Accordion("Outputs", open=True):
|
277 |
+
transcription_out = gr.Textbox(label="You Said (Transcription)", lines=2)
|
278 |
+
llm_response_out = gr.Textbox(label="Gemini Assistant Says (Text)", lines=5)
|
279 |
+
audio_out = gr.Audio(label="Assistant Says (Audio)")
|
280 |
+
|
281 |
+
process_button.click(
|
282 |
+
fn=full_pipeline_gradio,
|
283 |
+
inputs=[audio_in],
|
284 |
+
outputs=[transcription_out, llm_response_out, audio_out]
|
285 |
+
)
|
286 |
+
gr.Markdown("---")
|
287 |
+
gr.Markdown("### How to Use:")
|
288 |
+
gr.Markdown("1. Ensure your `GEMINI_API_KEY` environment variable is set.")
|
289 |
+
gr.Markdown("2. Click into the 'Speak Here' box and record your audio.")
|
290 |
+
gr.Markdown("3. Click the 'Process Audio' button.")
|
291 |
+
gr.Markdown("4. View the transcription, Gemini's text response, and listen to the audio response.")
|
292 |
+
|
293 |
+
if __name__ == "__main__":
|
294 |
+
demo.launch(share=False)
|
infereless.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
import json
|
3 |
+
URL = 'https://serverless-region-v1.inferless.com/api/v1/parler-tts-streaming-1_ae4e81bb5d604799b573df3f0b3c9518/infer'
|
4 |
+
headers = {"Content-Type": "application/json", "Authorization": "Bearer 1e01145781a0639d830555d5e4e4e5e1752726750db75e995e0e246f32c4b7c9f442bd6f8caec8acc6b9684ec78e5b633db04370815ca1748bf5a7db80245411"}
|
5 |
+
|
6 |
+
data = json.loads('''{
|
7 |
+
"parameters": {
|
8 |
+
"prompt_value": "A male speaker with a low-pitched voice delivering his words at a fast pace in a small, confined space with a very clear audio and an animated tone.",
|
9 |
+
"input_value": "Remember - this is only the first iteration of the model! To improve the prosody and naturalness of the speech further, we're scaling up the amount of training data by a factor of five times."
|
10 |
+
}
|
11 |
+
}''')
|
12 |
+
|
13 |
+
response = requests.post(URL, headers=headers, data=json.dumps(data))
|
14 |
+
print(response.json())
|
main.py
ADDED
@@ -0,0 +1,740 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# main.py
|
2 |
+
import asyncio
|
3 |
+
import base64
|
4 |
+
import io
|
5 |
+
import logging
|
6 |
+
import os
|
7 |
+
from threading import Thread, Event # Added Event for better thread control
|
8 |
+
import time # For timeout checks
|
9 |
+
|
10 |
+
import soundfile as sf
|
11 |
+
import torch
|
12 |
+
import uvicorn
|
13 |
+
import whisper
|
14 |
+
from fastapi import FastAPI, File, UploadFile, WebSocket, WebSocketDisconnect
|
15 |
+
from fastapi.responses import HTMLResponse, JSONResponse
|
16 |
+
from fastapi.middleware.cors import CORSMiddleware
|
17 |
+
from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSStreamer
|
18 |
+
from transformers import AutoTokenizer, GenerationConfig # Keep transformers.GenerationConfig
|
19 |
+
import google.generativeai as genai
|
20 |
+
import numpy as np
|
21 |
+
|
22 |
+
# --- Configuration ---
|
23 |
+
WHISPER_MODEL_SIZE = os.getenv("WHISPER_MODEL_SIZE", "tiny")
|
24 |
+
TTS_MODEL_NAME = "ai4bharat/indic-parler-tts"
|
25 |
+
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "AIzaSyD6x3Yoby4eQ6QL2kaaG_Rz3fG3rh7wPB8")
|
26 |
+
GEMINI_MODEL_NAME = "gemini-1.5-flash-latest"
|
27 |
+
|
28 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
29 |
+
|
30 |
+
attn_implementation = "flash_attention_2" if torch.cuda.is_available() else "eager"
|
31 |
+
|
32 |
+
torch_dtype_tts = torch.bfloat16 if DEVICE == "cuda" and torch.cuda.is_bf16_supported() else (torch.float16 if DEVICE == "cuda" else torch.float32)
|
33 |
+
torch_dtype_whisper = torch.float16 if DEVICE == "cuda" else torch.float32
|
34 |
+
|
35 |
+
TTS_DEFAULT_PARAMS = {
|
36 |
+
"do_sample": True,
|
37 |
+
"temperature": 1.0,
|
38 |
+
"top_k": 50,
|
39 |
+
"top_p": 0.95,
|
40 |
+
"min_new_tokens": 5, # Reduced for quicker start with streamer
|
41 |
+
# "max_new_tokens": 256, # Optional global cap
|
42 |
+
}
|
43 |
+
|
44 |
+
# --- Logging ---
|
45 |
+
logging.basicConfig(level=logging.INFO)
|
46 |
+
logger = logging.getLogger(__name__)
|
47 |
+
|
48 |
+
# --- FastAPI App Initialization ---
|
49 |
+
app = FastAPI(title="Conversational AI Chatbot with Enhanced Stream Abortion")
|
50 |
+
|
51 |
+
app.add_middleware(
|
52 |
+
CORSMiddleware,
|
53 |
+
allow_origins=["*"],
|
54 |
+
allow_credentials=True,
|
55 |
+
allow_methods=["*"],
|
56 |
+
allow_headers=["*"],
|
57 |
+
)
|
58 |
+
|
59 |
+
# --- Global Model Variables ---
|
60 |
+
whisper_model = None
|
61 |
+
gemini_model_instance = None
|
62 |
+
tts_model = None
|
63 |
+
tts_tokenizer = None
|
64 |
+
# We will build the GenerationConfig object from TTS_DEFAULT_PARAMS inside the functions
|
65 |
+
# or store it globally if preferred, initialized from transformers.GenerationConfig
|
66 |
+
|
67 |
+
# --- Model Loading & API Configuration ---
|
68 |
+
@app.on_event("startup")
|
69 |
+
async def load_resources():
|
70 |
+
global whisper_model, tts_model, tts_tokenizer, gemini_model_instance
|
71 |
+
logger.info(f"Loading local models. Whisper on {DEVICE} with {torch_dtype_whisper}, TTS on {DEVICE} with {torch_dtype_tts}")
|
72 |
+
|
73 |
+
try:
|
74 |
+
logger.info(f"Loading Whisper model: {WHISPER_MODEL_SIZE}")
|
75 |
+
whisper_model = whisper.load_model(WHISPER_MODEL_SIZE, device=DEVICE)
|
76 |
+
logger.info("Whisper model loaded successfully.")
|
77 |
+
|
78 |
+
logger.info(f"Loading IndicParler-TTS model: {TTS_MODEL_NAME}")
|
79 |
+
tts_model = ParlerTTSForConditionalGeneration.from_pretrained(TTS_MODEL_NAME, attn_implementation=attn_implementation).to(DEVICE, dtype=torch_dtype_tts)
|
80 |
+
tts_tokenizer = AutoTokenizer.from_pretrained(TTS_MODEL_NAME)
|
81 |
+
|
82 |
+
if tts_tokenizer:
|
83 |
+
if tts_tokenizer.pad_token_id is not None:
|
84 |
+
TTS_DEFAULT_PARAMS["pad_token_id"] = tts_tokenizer.pad_token_id
|
85 |
+
# ParlerTTS uses a special token_id for silence, not eos_token_id for generation end.
|
86 |
+
# eos_token_id is more for text models.
|
87 |
+
# if tts_tokenizer.eos_token_id is not None:
|
88 |
+
# TTS_DEFAULT_PARAMS["eos_token_id"] = tts_tokenizer.eos_token_id
|
89 |
+
logger.info(f"IndicParler-TTS model loaded. Default generation params: {TTS_DEFAULT_PARAMS}")
|
90 |
+
|
91 |
+
if not GEMINI_API_KEY:
|
92 |
+
logger.warning("GEMINI_API_KEY not found. LLM functionality will be limited.")
|
93 |
+
else:
|
94 |
+
try:
|
95 |
+
genai.configure(api_key=GEMINI_API_KEY)
|
96 |
+
gemini_model_instance = genai.GenerativeModel(GEMINI_MODEL_NAME)
|
97 |
+
logger.info(f"Gemini API configured with model: {GEMINI_MODEL_NAME}")
|
98 |
+
except Exception as e:
|
99 |
+
logger.error(f"Failed to configure Gemini API: {e}", exc_info=True)
|
100 |
+
gemini_model_instance = None
|
101 |
+
|
102 |
+
except Exception as e:
|
103 |
+
logger.error(f"Error loading models: {e}", exc_info=True)
|
104 |
+
logger.info("Local models and API configurations loaded.")
|
105 |
+
|
106 |
+
|
107 |
+
# --- Helper Functions ---
|
108 |
+
async def transcribe_audio_bytes(audio_bytes: bytes) -> str:
|
109 |
+
if not whisper_model:
|
110 |
+
raise RuntimeError("Whisper model not loaded.")
|
111 |
+
temp_audio_path = f"temp_audio_main_{os.urandom(4).hex()}.wav"
|
112 |
+
try:
|
113 |
+
with open(temp_audio_path, "wb") as f:
|
114 |
+
f.write(audio_bytes)
|
115 |
+
result = whisper_model.transcribe(temp_audio_path, fp16=(DEVICE == "cuda" and torch_dtype_whisper == torch.float16))
|
116 |
+
transcribed_text = result["text"].strip()
|
117 |
+
logger.info(f"Transcription: {transcribed_text}")
|
118 |
+
return transcribed_text
|
119 |
+
except Exception as e:
|
120 |
+
logger.error(f"Error during transcription: {e}", exc_info=True)
|
121 |
+
return ""
|
122 |
+
finally:
|
123 |
+
if os.path.exists(temp_audio_path):
|
124 |
+
try:
|
125 |
+
os.remove(temp_audio_path)
|
126 |
+
except Exception as e_del:
|
127 |
+
logger.error(f"Error deleting temp audio file {temp_audio_path}: {e_del}")
|
128 |
+
|
129 |
+
|
130 |
+
async def generate_gemini_response(text: str) -> str:
|
131 |
+
if not gemini_model_instance:
|
132 |
+
logger.error("Gemini model instance not available.")
|
133 |
+
return "Sorry, the language model is currently unavailable."
|
134 |
+
try:
|
135 |
+
full_prompt = f"User: {text}\nAssistant:"
|
136 |
+
loop = asyncio.get_event_loop()
|
137 |
+
response = await loop.run_in_executor(None, gemini_model_instance.generate_content, full_prompt)
|
138 |
+
response_text = "I'm sorry, I couldn't generate a response for that."
|
139 |
+
if hasattr(response, 'text') and response.text: # For simple text responses
|
140 |
+
response_text = response.text.strip()
|
141 |
+
elif response.parts: # New way to access parts for gemini-1.5-flash and pro
|
142 |
+
response_text = "".join(part.text for part in response.parts).strip()
|
143 |
+
elif response.candidates and response.candidates[0].content.parts: # Older way
|
144 |
+
response_text = response.candidates[0].content.parts[0].text.strip()
|
145 |
+
else:
|
146 |
+
safety_feedback = ""
|
147 |
+
if hasattr(response, 'prompt_feedback') and response.prompt_feedback:
|
148 |
+
safety_feedback = f" Safety Feedback: {response.prompt_feedback}"
|
149 |
+
elif response.candidates and hasattr(response.candidates[0], 'finish_reason') and response.candidates[0].finish_reason != "STOP":
|
150 |
+
safety_feedback = f" Finish Reason: {response.candidates[0].finish_reason}"
|
151 |
+
logger.warning(f"Gemini response might be empty or blocked.{safety_feedback}")
|
152 |
+
logger.info(f"Gemini LLM Response: {response_text}")
|
153 |
+
return response_text
|
154 |
+
except Exception as e:
|
155 |
+
logger.error(f"Error during Gemini LLM generation: {e}", exc_info=True)
|
156 |
+
return "Sorry, I encountered an error trying to respond."
|
157 |
+
|
158 |
+
|
159 |
+
async def synthesize_speech_streaming(text: str, description: str = "A clear, female voice speaking in English.", play_steps_in_s: float = 0.4, cancellation_event: Event = Event()):
|
160 |
+
if not tts_model or not tts_tokenizer:
|
161 |
+
logger.error("TTS model or tokenizer not loaded.")
|
162 |
+
if cancellation_event and cancellation_event.is_set(): logger.info("TTS cancelled before start."); yield b""; return
|
163 |
+
yield b""
|
164 |
+
return
|
165 |
+
|
166 |
+
if not text or not text.strip():
|
167 |
+
logger.warning("TTS input text is empty. Yielding empty audio.")
|
168 |
+
if cancellation_event and cancellation_event.is_set(): logger.info("TTS cancelled before start (empty text)."); yield b""; return
|
169 |
+
yield b""
|
170 |
+
return
|
171 |
+
|
172 |
+
streamer = None
|
173 |
+
thread = None
|
174 |
+
|
175 |
+
try:
|
176 |
+
logger.info(f"Starting TTS streaming with ParlerTTSStreamer for: \"{text[:50]}...\"")
|
177 |
+
|
178 |
+
# Ensure sampling_rate is correctly accessed from the model's config
|
179 |
+
# For ParlerTTS, it's usually under model.config.audio_encoder.sampling_rate
|
180 |
+
if hasattr(tts_model.config, 'audio_encoder') and hasattr(tts_model.config.audio_encoder, 'sampling_rate'):
|
181 |
+
sampling_rate = tts_model.config.audio_encoder.sampling_rate
|
182 |
+
else:
|
183 |
+
logger.warning("Could not find tts_model.config.audio_encoder.sampling_rate, defaulting to 24000")
|
184 |
+
sampling_rate = 24000 # A common default for ParlerTTS if not found
|
185 |
+
|
186 |
+
try:
|
187 |
+
frame_rate = getattr(tts_model.config.audio_encoder, 'frame_rate', 100)
|
188 |
+
except AttributeError:
|
189 |
+
logger.warning("frame_rate not found in tts_model.config.audio_encoder. Using default of 100 Hz for play_steps calculation.")
|
190 |
+
frame_rate = 100
|
191 |
+
|
192 |
+
play_steps = int(frame_rate * play_steps_in_s)
|
193 |
+
if play_steps == 0 : play_steps = 1
|
194 |
+
|
195 |
+
logger.info(f"Streamer params: sampling_rate={sampling_rate}, frame_rate={frame_rate}, play_steps_in_s={play_steps_in_s}, play_steps={play_steps}")
|
196 |
+
|
197 |
+
streamer = ParlerTTSStreamer(tts_model, device=DEVICE, play_steps=play_steps)
|
198 |
+
|
199 |
+
description_inputs = tts_tokenizer(description, return_tensors="pt")
|
200 |
+
prompt_inputs = tts_tokenizer(text, return_tensors="pt")
|
201 |
+
|
202 |
+
gen_config_dict = TTS_DEFAULT_PARAMS.copy()
|
203 |
+
# ParlerTTS generate method might not take a GenerationConfig object directly,
|
204 |
+
# but rather individual kwargs. The streamer example passes them as kwargs.
|
205 |
+
# We ensure pad_token_id and eos_token_id are set if the tokenizer has them.
|
206 |
+
if tts_tokenizer.pad_token_id is not None:
|
207 |
+
gen_config_dict["pad_token_id"] = tts_tokenizer.pad_token_id
|
208 |
+
# ParlerTTS might not use eos_token_id in the same way as text models.
|
209 |
+
# if tts_tokenizer.eos_token_id is not None:
|
210 |
+
# gen_config_dict["eos_token_id"] = tts_tokenizer.eos_token_id
|
211 |
+
|
212 |
+
|
213 |
+
thread_generation_kwargs = {
|
214 |
+
"input_ids": description_inputs.input_ids.to(DEVICE),
|
215 |
+
"prompt_input_ids": prompt_inputs.input_ids.to(DEVICE),
|
216 |
+
"attention_mask": description_inputs.attention_mask.to(DEVICE) if hasattr(description_inputs, 'attention_mask') else None,
|
217 |
+
"streamer": streamer,
|
218 |
+
**gen_config_dict # Spread the generation parameters
|
219 |
+
}
|
220 |
+
if thread_generation_kwargs["attention_mask"] is None:
|
221 |
+
del thread_generation_kwargs["attention_mask"]
|
222 |
+
|
223 |
+
def _generate_in_thread():
|
224 |
+
try:
|
225 |
+
logger.info(f"TTS generation thread started.")
|
226 |
+
with torch.no_grad():
|
227 |
+
tts_model.generate(**thread_generation_kwargs)
|
228 |
+
logger.info("TTS generation thread finished model.generate().")
|
229 |
+
except Exception as e_thread:
|
230 |
+
logger.error(f"Error in TTS generation thread: {e_thread}", exc_info=True)
|
231 |
+
finally:
|
232 |
+
if streamer: streamer.end()
|
233 |
+
logger.info("TTS generation thread called streamer.end().")
|
234 |
+
|
235 |
+
thread = Thread(target=_generate_in_thread)
|
236 |
+
thread.daemon = True
|
237 |
+
thread.start()
|
238 |
+
|
239 |
+
loop = asyncio.get_event_loop()
|
240 |
+
while True:
|
241 |
+
if cancellation_event and cancellation_event.is_set():
|
242 |
+
logger.info("TTS streaming cancelled by event.")
|
243 |
+
break
|
244 |
+
|
245 |
+
try:
|
246 |
+
# Run the blocking streamer.__next__() in an executor
|
247 |
+
audio_chunk_tensor = await loop.run_in_executor(None, streamer.__next__)
|
248 |
+
|
249 |
+
if audio_chunk_tensor is None:
|
250 |
+
logger.info("Streamer yielded None explicitly, ending stream.")
|
251 |
+
break
|
252 |
+
# This check for numel == 0 is important as streamer might yield empty tensors
|
253 |
+
if not isinstance(audio_chunk_tensor, torch.Tensor) or audio_chunk_tensor.numel() == 0:
|
254 |
+
# REMOVED: if streamer.is_done(): (AttributeError)
|
255 |
+
# Instead, rely on StopIteration or explicit None from streamer
|
256 |
+
await asyncio.sleep(0.01) # Small sleep if empty but not done
|
257 |
+
continue
|
258 |
+
|
259 |
+
audio_chunk_np = audio_chunk_tensor.cpu().to(torch.float32).numpy().squeeze()
|
260 |
+
if audio_chunk_np.size == 0:
|
261 |
+
continue
|
262 |
+
|
263 |
+
audio_chunk_int16 = np.clip(audio_chunk_np * 32767, -32768, 32767).astype(np.int16)
|
264 |
+
yield audio_chunk_int16.tobytes()
|
265 |
+
# No need for sleep here if chunks are substantial, client will process
|
266 |
+
# await asyncio.sleep(0.001) # Can be removed or made very small
|
267 |
+
|
268 |
+
except StopIteration:
|
269 |
+
logger.info("Streamer finished (StopIteration).")
|
270 |
+
break
|
271 |
+
except Exception as e_stream_iter:
|
272 |
+
logger.error(f"Error iterating streamer: {e_stream_iter}", exc_info=True)
|
273 |
+
break
|
274 |
+
|
275 |
+
logger.info(f"Finished TTS streaming iteration for: \"{text[:50]}...\"")
|
276 |
+
|
277 |
+
except Exception as e:
|
278 |
+
logger.error(f"Error in synthesize_speech_streaming function: {e}", exc_info=True)
|
279 |
+
yield b""
|
280 |
+
finally:
|
281 |
+
logger.info("Exiting synthesize_speech_streaming. Ensuring streamer is ended and thread is joined.")
|
282 |
+
if streamer:
|
283 |
+
streamer.end()
|
284 |
+
if thread and thread.is_alive():
|
285 |
+
logger.info("Waiting for TTS generation thread to complete in finally block...")
|
286 |
+
final_join_start_time = time.time()
|
287 |
+
thread.join(timeout=2.0)
|
288 |
+
if thread.is_alive():
|
289 |
+
logger.warning(f"TTS generation thread still alive after {time.time() - final_join_start_time:.2f}s in finally block.")
|
290 |
+
|
291 |
+
|
292 |
+
# --- FastAPI HTTP Endpoints ---
|
293 |
+
@app.post("/api/stt", summary="Speech to Text")
|
294 |
+
async def speech_to_text_endpoint(file: UploadFile = File(...)):
|
295 |
+
if not whisper_model:
|
296 |
+
return JSONResponse(content={"error": "Whisper model not loaded"}, status_code=503)
|
297 |
+
try:
|
298 |
+
audio_bytes = await file.read()
|
299 |
+
transcribed_text = await transcribe_audio_bytes(audio_bytes)
|
300 |
+
return {"transcription": transcribed_text}
|
301 |
+
except Exception as e:
|
302 |
+
return JSONResponse(content={"error": str(e)}, status_code=500)
|
303 |
+
|
304 |
+
@app.post("/api/llm", summary="LLM Response Generation (Gemini)")
|
305 |
+
async def llm_endpoint(payload: dict):
|
306 |
+
if not gemini_model_instance:
|
307 |
+
return JSONResponse(content={"error": "Gemini LLM not configured or API key missing"}, status_code=503)
|
308 |
+
try:
|
309 |
+
text = payload.get("text")
|
310 |
+
if not text:
|
311 |
+
return JSONResponse(content={"error": "No text provided"}, status_code=400)
|
312 |
+
response = await generate_gemini_response(text)
|
313 |
+
return {"response": response}
|
314 |
+
except Exception as e:
|
315 |
+
return JSONResponse(content={"error": str(e)}, status_code=500)
|
316 |
+
|
317 |
+
@app.post("/api/tts", summary="Text to Speech (Non-Streaming for HTTP)")
|
318 |
+
async def text_to_speech_endpoint(payload: dict):
|
319 |
+
if not tts_model or not tts_tokenizer:
|
320 |
+
return JSONResponse(content={"error": "TTS model/tokenizer not loaded"}, status_code=503)
|
321 |
+
try:
|
322 |
+
text = payload.get("text")
|
323 |
+
description = payload.get("description", "A clear, female voice speaking in English.")
|
324 |
+
if not text:
|
325 |
+
return JSONResponse(content={"error": "No text provided"}, status_code=400)
|
326 |
+
|
327 |
+
description_inputs = tts_tokenizer(description, return_tensors="pt")
|
328 |
+
prompt_inputs = tts_tokenizer(text, return_tensors="pt")
|
329 |
+
|
330 |
+
# Use a GenerationConfig object for clarity and consistency
|
331 |
+
gen_config_dict = TTS_DEFAULT_PARAMS.copy()
|
332 |
+
if tts_tokenizer.pad_token_id is not None:
|
333 |
+
gen_config_dict["pad_token_id"] = tts_tokenizer.pad_token_id
|
334 |
+
# if tts_tokenizer.eos_token_id is not None: # ParlerTTS might not use standard eos
|
335 |
+
# gen_config_dict["eos_token_id"] = tts_tokenizer.eos_token_id
|
336 |
+
|
337 |
+
# Create GenerationConfig from transformers
|
338 |
+
generation_config_obj = GenerationConfig(**gen_config_dict)
|
339 |
+
|
340 |
+
|
341 |
+
with torch.no_grad():
|
342 |
+
generation = tts_model.generate(
|
343 |
+
input_ids=description_inputs.input_ids.to(DEVICE),
|
344 |
+
prompt_input_ids=prompt_inputs.input_ids.to(DEVICE),
|
345 |
+
attention_mask=description_inputs.attention_mask.to(DEVICE) if hasattr(description_inputs, 'attention_mask') else None,
|
346 |
+
generation_config=generation_config_obj # Pass the config object
|
347 |
+
).cpu().to(torch.float32).numpy().squeeze()
|
348 |
+
|
349 |
+
audio_io = io.BytesIO()
|
350 |
+
scaled_generation = np.clip(generation * 32767, -32768, 32767).astype(np.int16)
|
351 |
+
|
352 |
+
current_sampling_rate = tts_model.config.audio_encoder.sampling_rate if hasattr(tts_model.config, 'audio_encoder') else 24000
|
353 |
+
sf.write(audio_io, scaled_generation, samplerate=current_sampling_rate, format='WAV', subtype='PCM_16')
|
354 |
+
audio_io.seek(0)
|
355 |
+
audio_bytes = audio_io.read()
|
356 |
+
|
357 |
+
if not audio_bytes:
|
358 |
+
return JSONResponse(content={"error": "TTS failed to generate audio"}, status_code=500)
|
359 |
+
audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
|
360 |
+
return {"audio_base64": audio_base64, "format": "wav", "sample_rate": current_sampling_rate}
|
361 |
+
except Exception as e:
|
362 |
+
logger.error(f"TTS endpoint error: {e}", exc_info=True)
|
363 |
+
return JSONResponse(content={"error": str(e)}, status_code=500)
|
364 |
+
|
365 |
+
# --- WebSocket Endpoint for Real-time Conversation ---
|
366 |
+
@app.websocket("/ws/conversation")
|
367 |
+
async def conversation_websocket(websocket: WebSocket):
|
368 |
+
await websocket.accept()
|
369 |
+
logger.info(f"WebSocket connection accepted from: {websocket.client}")
|
370 |
+
|
371 |
+
tts_cancellation_event = Event() # For this specific connection
|
372 |
+
|
373 |
+
try:
|
374 |
+
while True:
|
375 |
+
if websocket.client_state.name != 'CONNECTED': # Check if client disconnected before receive
|
376 |
+
logger.info(f"WebSocket client {websocket.client} disconnected before receive.")
|
377 |
+
break
|
378 |
+
|
379 |
+
audio_data = await websocket.receive_bytes()
|
380 |
+
logger.info(f"Received {len(audio_data)} bytes of user audio data from {websocket.client}.")
|
381 |
+
|
382 |
+
if not audio_data:
|
383 |
+
logger.warning(f"Received empty audio data from user {websocket.client}.")
|
384 |
+
continue
|
385 |
+
|
386 |
+
transcribed_text = await transcribe_audio_bytes(audio_data)
|
387 |
+
if not transcribed_text:
|
388 |
+
logger.warning(f"Transcription failed for {websocket.client}.")
|
389 |
+
await websocket.send_text("SYSTEM_ERROR: Transcription failed.")
|
390 |
+
continue
|
391 |
+
await websocket.send_text(f"USER_TRANSCRIPT: {transcribed_text}")
|
392 |
+
|
393 |
+
llm_response_text = await generate_gemini_response(transcribed_text)
|
394 |
+
if not llm_response_text or "Sorry, I encountered an error" in llm_response_text or "unavailable" in llm_response_text:
|
395 |
+
logger.warning(f"LLM (Gemini) failed for {websocket.client}: {llm_response_text}")
|
396 |
+
await websocket.send_text(f"SYSTEM_ERROR: LLM failed. ({llm_response_text})")
|
397 |
+
continue
|
398 |
+
await websocket.send_text(f"ASSISTANT_RESPONSE_TEXT: {llm_response_text}")
|
399 |
+
|
400 |
+
tts_description = "A clear, female voice speaking in English."
|
401 |
+
|
402 |
+
current_sampling_rate = tts_model.config.audio_encoder.sampling_rate if hasattr(tts_model.config, 'audio_encoder') else 24000
|
403 |
+
audio_params_msg = f"TTS_STREAM_START:{{\"sample_rate\": {current_sampling_rate}, \"channels\": 1, \"bit_depth\": 16}}"
|
404 |
+
await websocket.send_text(audio_params_msg)
|
405 |
+
logger.info(f"Sent to client {websocket.client}: {audio_params_msg}")
|
406 |
+
|
407 |
+
chunk_count = 0
|
408 |
+
tts_cancellation_event.clear() # Reset event for new TTS task
|
409 |
+
|
410 |
+
async for audio_chunk_bytes in synthesize_speech_streaming(llm_response_text, tts_description, cancellation_event=tts_cancellation_event):
|
411 |
+
if not audio_chunk_bytes:
|
412 |
+
logger.debug(f"Received empty bytes from streaming generator for {websocket.client}, might be end or error in generator.")
|
413 |
+
continue
|
414 |
+
try:
|
415 |
+
if websocket.client_state.name != 'CONNECTED':
|
416 |
+
logger.warning(f"Client {websocket.client} disconnected during TTS stream. Aborting TTS.")
|
417 |
+
tts_cancellation_event.set() # Signal TTS thread to stop
|
418 |
+
break
|
419 |
+
await websocket.send_bytes(audio_chunk_bytes)
|
420 |
+
chunk_count += 1
|
421 |
+
except Exception as send_err:
|
422 |
+
logger.warning(f"Error sending audio chunk to {websocket.client}: {send_err}. Client likely disconnected.")
|
423 |
+
tts_cancellation_event.set() # Signal TTS thread to stop
|
424 |
+
break
|
425 |
+
|
426 |
+
if not tts_cancellation_event.is_set(): # Only send END if not cancelled
|
427 |
+
logger.info(f"Sent {chunk_count} TTS audio chunks to client {websocket.client}.")
|
428 |
+
await websocket.send_text("TTS_STREAM_END")
|
429 |
+
logger.info(f"Sent TTS_STREAM_END to client {websocket.client}.")
|
430 |
+
else:
|
431 |
+
logger.info(f"TTS stream for {websocket.client} was cancelled. Sent {chunk_count} chunks before cancellation.")
|
432 |
+
|
433 |
+
|
434 |
+
except WebSocketDisconnect:
|
435 |
+
logger.info(f"WebSocket connection closed by client {websocket.client}.")
|
436 |
+
tts_cancellation_event.set() # Signal any active TTS to stop
|
437 |
+
except Exception as e:
|
438 |
+
logger.error(f"Error in WebSocket conversation with {websocket.client}: {e}", exc_info=True)
|
439 |
+
tts_cancellation_event.set() # Signal any active TTS to stop
|
440 |
+
try:
|
441 |
+
if websocket.client_state.name == 'CONNECTED':
|
442 |
+
await websocket.send_text(f"SYSTEM_ERROR: An unexpected error occurred: {str(e)}")
|
443 |
+
except Exception: pass
|
444 |
+
finally:
|
445 |
+
logger.info(f"Cleaning up WebSocket connection for {websocket.client}.")
|
446 |
+
tts_cancellation_event.set() # Ensure event is set on any exit path
|
447 |
+
if websocket.client_state.name == 'CONNECTED' or websocket.client_state.name == 'CONNECTING':
|
448 |
+
try: await websocket.close()
|
449 |
+
except Exception: pass
|
450 |
+
logger.info(f"WebSocket connection resources cleaned up for {websocket.client}.")
|
451 |
+
|
452 |
+
# ... (HTML serving and main execution block remain the same) ...
|
453 |
+
@app.get("/", response_class=HTMLResponse)
|
454 |
+
async def get_home():
|
455 |
+
html_content = """
|
456 |
+
<!DOCTYPE html>
|
457 |
+
<html>
|
458 |
+
<head>
|
459 |
+
<title>Conversational AI Chatbot (Streaming)</title>
|
460 |
+
<style>
|
461 |
+
body { font-family: Arial, sans-serif; margin: 20px; background-color: #f4f4f4; color: #333; }
|
462 |
+
#chatbox { width: 80%; max-width: 600px; margin: auto; background-color: #fff; padding: 20px; box-shadow: 0 0 10px rgba(0,0,0,0.1); border-radius: 8px; }
|
463 |
+
.message { padding: 10px; margin-bottom: 10px; border-radius: 5px; }
|
464 |
+
.user { background-color: #e1f5fe; text-align: right; }
|
465 |
+
.assistant { background-color: #f1f8e9; }
|
466 |
+
.system { background-color: #ffebee; color: #c62828; font-style: italic;}
|
467 |
+
#audioPlayerContainer { margin-top: 10px; }
|
468 |
+
#audioPlayer { display: none; width: 100%; }
|
469 |
+
button { padding: 10px 15px; background-color: #007bff; color: white; border: none; border-radius: 5px; cursor: pointer; margin-top:10px; }
|
470 |
+
button:disabled { background-color: #ccc; }
|
471 |
+
#status { margin-top: 10px; font-style: italic; color: #666; }
|
472 |
+
#transcriptionArea, #llmResponseArea { margin-top: 10px; padding: 5px; border: 1px solid #eee; background: #fafafa; word-wrap: break-word;}
|
473 |
+
</style>
|
474 |
+
</head>
|
475 |
+
<body>
|
476 |
+
<div id="chatbox">
|
477 |
+
<h2>Real-time AI Chatbot (Streaming TTS)</h2>
|
478 |
+
<div id="messages"></div>
|
479 |
+
<div id="transcriptionArea"><strong>You (transcribed):</strong> <span id="userTranscript">...</span></div>
|
480 |
+
<div id="llmResponseArea"><strong>Assistant (text):</strong> <span id="assistantTranscript">...</span></div>
|
481 |
+
|
482 |
+
<button id="startRecordButton">Start Recording</button>
|
483 |
+
<button id="stopRecordButton" disabled>Stop Recording</button>
|
484 |
+
<p id="status">Status: Idle</p>
|
485 |
+
<div id="audioPlayerContainer">
|
486 |
+
<audio id="audioPlayer" controls></audio>
|
487 |
+
</div>
|
488 |
+
</div>
|
489 |
+
|
490 |
+
<script>
|
491 |
+
const startRecordButton = document.getElementById('startRecordButton');
|
492 |
+
const stopRecordButton = document.getElementById('stopRecordButton');
|
493 |
+
const audioPlayer = document.getElementById('audioPlayer');
|
494 |
+
const messagesDiv = document.getElementById('messages');
|
495 |
+
const statusDiv = document.getElementById('status');
|
496 |
+
const userTranscriptSpan = document.getElementById('userTranscript');
|
497 |
+
const assistantTranscriptSpan = document.getElementById('assistantTranscript');
|
498 |
+
|
499 |
+
let websocket;
|
500 |
+
let mediaRecorder;
|
501 |
+
let userAudioChunks = [];
|
502 |
+
|
503 |
+
let assistantAudioBufferQueue = [];
|
504 |
+
let audioContext;
|
505 |
+
let expectedSampleRate;
|
506 |
+
let ttsStreaming = false;
|
507 |
+
let audioPlaying = false;
|
508 |
+
let sourceNode = null;
|
509 |
+
|
510 |
+
function initAudioContext() {
|
511 |
+
if (!audioContext || audioContext.state === 'closed') {
|
512 |
+
try {
|
513 |
+
audioContext = new (window.AudioContext || window.webkitAudioContext)();
|
514 |
+
console.log("AudioContext initialized or re-initialized.");
|
515 |
+
} catch (e) {
|
516 |
+
console.error("Web Audio API is not supported in this browser.", e);
|
517 |
+
addMessage("Error: Web Audio API not supported. Cannot play streamed audio.", "system");
|
518 |
+
audioContext = null;
|
519 |
+
}
|
520 |
+
}
|
521 |
+
}
|
522 |
+
|
523 |
+
|
524 |
+
function connectWebSocket() {
|
525 |
+
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
|
526 |
+
const wsUrl = `${protocol}//${window.location.host}/ws/conversation`;
|
527 |
+
websocket = new WebSocket(wsUrl);
|
528 |
+
websocket.binaryType = 'arraybuffer';
|
529 |
+
|
530 |
+
websocket.onopen = () => {
|
531 |
+
statusDiv.textContent = 'Status: Connected. Ready to record.';
|
532 |
+
startRecordButton.disabled = false;
|
533 |
+
initAudioContext();
|
534 |
+
};
|
535 |
+
|
536 |
+
websocket.onmessage = (event) => {
|
537 |
+
if (event.data instanceof ArrayBuffer) {
|
538 |
+
if (ttsStreaming && audioContext && expectedSampleRate) {
|
539 |
+
const pcmDataInt16 = new Int16Array(event.data);
|
540 |
+
if (pcmDataInt16.length > 0) {
|
541 |
+
assistantAudioBufferQueue.push(pcmDataInt16);
|
542 |
+
playNextChunkFromQueue();
|
543 |
+
}
|
544 |
+
} else {
|
545 |
+
console.warn("Received ArrayBuffer data but not in TTS streaming mode or AudioContext not ready.");
|
546 |
+
}
|
547 |
+
} else {
|
548 |
+
const messageText = event.data;
|
549 |
+
if (messageText.startsWith("USER_TRANSCRIPT:")) {
|
550 |
+
const transcript = messageText.substring("USER_TRANSCRIPT:".length).trim();
|
551 |
+
userTranscriptSpan.textContent = transcript;
|
552 |
+
} else if (messageText.startsWith("ASSISTANT_RESPONSE_TEXT:")) {
|
553 |
+
const llmResponse = messageText.substring("ASSISTANT_RESPONSE_TEXT:".length).trim();
|
554 |
+
assistantTranscriptSpan.textContent = llmResponse;
|
555 |
+
addMessage(`Assistant: ${llmResponse}`, 'assistant');
|
556 |
+
} else if (messageText.startsWith("TTS_STREAM_START:")) {
|
557 |
+
ttsStreaming = true;
|
558 |
+
assistantAudioBufferQueue = [];
|
559 |
+
audioPlaying = false;
|
560 |
+
if (sourceNode) {
|
561 |
+
try { sourceNode.stop(); } catch(e) { console.warn("Error stopping previous sourceNode:", e); }
|
562 |
+
sourceNode = null;
|
563 |
+
}
|
564 |
+
audioPlayer.style.display = 'none';
|
565 |
+
audioPlayer.src = "";
|
566 |
+
try {
|
567 |
+
const paramsText = messageText.substring("TTS_STREAM_START:".length);
|
568 |
+
const params = JSON.parse(paramsText);
|
569 |
+
expectedSampleRate = params.sample_rate;
|
570 |
+
initAudioContext();
|
571 |
+
statusDiv.textContent = 'Status: Receiving audio stream...';
|
572 |
+
addMessage('Assistant (Audio stream starting...)', 'assistant');
|
573 |
+
} catch (e) {
|
574 |
+
console.error("Could not parse TTS_STREAM_START params:", e);
|
575 |
+
statusDiv.textContent = 'Error: Could not parse audio stream parameters.';
|
576 |
+
ttsStreaming = false;
|
577 |
+
}
|
578 |
+
} else if (messageText === "TTS_STREAM_END") {
|
579 |
+
ttsStreaming = false;
|
580 |
+
if (!audioPlaying && assistantAudioBufferQueue.length === 0) {
|
581 |
+
statusDiv.textContent = 'Status: Audio stream finished (or was empty).';
|
582 |
+
} else if (!audioPlaying && assistantAudioBufferQueue.length > 0) {
|
583 |
+
playNextChunkFromQueue();
|
584 |
+
statusDiv.textContent = 'Status: Audio stream finished. Playing remaining...';
|
585 |
+
} else {
|
586 |
+
statusDiv.textContent = 'Status: Audio stream finished. Playing remaining...';
|
587 |
+
}
|
588 |
+
addMessage('Assistant (Audio stream ended)', 'assistant');
|
589 |
+
} else if (messageText.startsWith("SYSTEM_ERROR:")) {
|
590 |
+
const errorMsg = messageText.substring("SYSTEM_ERROR:".length).trim();
|
591 |
+
addMessage(`System Error: ${errorMsg}`, 'system');
|
592 |
+
statusDiv.textContent = `Error: ${errorMsg}`;
|
593 |
+
ttsStreaming = false;
|
594 |
+
assistantAudioBufferQueue = [];
|
595 |
+
} else {
|
596 |
+
addMessage(messageText, 'system');
|
597 |
+
}
|
598 |
+
}
|
599 |
+
};
|
600 |
+
websocket.onerror = (error) => {
|
601 |
+
console.error('WebSocket Error:', error);
|
602 |
+
statusDiv.textContent = 'Status: WebSocket error. Try reconnecting.';
|
603 |
+
addMessage('WebSocket Error. Check console.', 'system');
|
604 |
+
ttsStreaming = false;
|
605 |
+
};
|
606 |
+
|
607 |
+
websocket.onclose = () => {
|
608 |
+
statusDiv.textContent = 'Status: Disconnected. Please refresh to reconnect.';
|
609 |
+
startRecordButton.disabled = true;
|
610 |
+
stopRecordButton.disabled = true;
|
611 |
+
addMessage('Disconnected from server.', 'system');
|
612 |
+
ttsStreaming = false;
|
613 |
+
if (audioContext && audioContext.state !== 'closed') {
|
614 |
+
audioContext.close().catch(e => console.warn("Error closing AudioContext:", e));
|
615 |
+
audioContext = null;
|
616 |
+
console.log("AudioContext closed.");
|
617 |
+
}
|
618 |
+
};
|
619 |
+
}
|
620 |
+
|
621 |
+
function playNextChunkFromQueue() {
|
622 |
+
if (audioPlaying || assistantAudioBufferQueue.length === 0 || !audioContext || audioContext.state !== 'running' || !expectedSampleRate) {
|
623 |
+
if (assistantAudioBufferQueue.length === 0 && !ttsStreaming && !audioPlaying) {
|
624 |
+
console.log("Queue empty, not streaming, not playing: Playback complete.");
|
625 |
+
statusDiv.textContent = 'Status: Audio playback complete.';
|
626 |
+
}
|
627 |
+
return;
|
628 |
+
}
|
629 |
+
audioPlaying = true;
|
630 |
+
|
631 |
+
const pcmDataInt16 = assistantAudioBufferQueue.shift();
|
632 |
+
|
633 |
+
const float32Pcm = new Float32Array(pcmDataInt16.length);
|
634 |
+
for (let i = 0; i < pcmDataInt16.length; i++) {
|
635 |
+
float32Pcm[i] = pcmDataInt16[i] / 32768.0;
|
636 |
+
}
|
637 |
+
|
638 |
+
const audioBuffer = audioContext.createBuffer(1, float32Pcm.length, expectedSampleRate);
|
639 |
+
audioBuffer.getChannelData(0).set(float32Pcm);
|
640 |
+
|
641 |
+
sourceNode = audioContext.createBufferSource();
|
642 |
+
sourceNode.buffer = audioBuffer;
|
643 |
+
sourceNode.connect(audioContext.destination);
|
644 |
+
sourceNode.onended = () => {
|
645 |
+
audioPlaying = false;
|
646 |
+
if (ttsStreaming || assistantAudioBufferQueue.length > 0) {
|
647 |
+
playNextChunkFromQueue();
|
648 |
+
} else {
|
649 |
+
statusDiv.textContent = 'Status: Audio playback finished.';
|
650 |
+
console.log("All queued audio chunks played.");
|
651 |
+
}
|
652 |
+
};
|
653 |
+
sourceNode.start();
|
654 |
+
statusDiv.textContent = 'Status: Playing audio chunk...';
|
655 |
+
}
|
656 |
+
|
657 |
+
function addMessage(text, type) {
|
658 |
+
const messageElement = document.createElement('div');
|
659 |
+
messageElement.classList.add('message', type);
|
660 |
+
messageElement.textContent = text;
|
661 |
+
messagesDiv.appendChild(messageElement);
|
662 |
+
messagesDiv.scrollTop = messagesDiv.scrollHeight;
|
663 |
+
}
|
664 |
+
|
665 |
+
startRecordButton.onclick = async () => {
|
666 |
+
if (!websocket || websocket.readyState !== WebSocket.OPEN) {
|
667 |
+
alert("WebSocket is not connected. Please wait or refresh.");
|
668 |
+
return;
|
669 |
+
}
|
670 |
+
if (audioContext && audioContext.state === 'suspended') {
|
671 |
+
audioContext.resume().catch(e => console.error("Error resuming AudioContext:", e));
|
672 |
+
}
|
673 |
+
initAudioContext();
|
674 |
+
|
675 |
+
try {
|
676 |
+
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
|
677 |
+
let options = { mimeType: 'audio/webm;codecs=opus' };
|
678 |
+
if (!MediaRecorder.isTypeSupported(options.mimeType)) {
|
679 |
+
console.warn(`${options.mimeType} is not supported, trying default.`);
|
680 |
+
options = {};
|
681 |
+
}
|
682 |
+
mediaRecorder = new MediaRecorder(stream, options);
|
683 |
+
userAudioChunks = [];
|
684 |
+
|
685 |
+
mediaRecorder.ondataavailable = event => {
|
686 |
+
if (event.data.size > 0) userAudioChunks.push(event.data);
|
687 |
+
};
|
688 |
+
|
689 |
+
mediaRecorder.onstop = () => {
|
690 |
+
if (userAudioChunks.length === 0) {
|
691 |
+
console.log("No audio data recorded.");
|
692 |
+
statusDiv.textContent = 'Status: No audio data recorded. Try again.';
|
693 |
+
startRecordButton.disabled = false;
|
694 |
+
stopRecordButton.disabled = true;
|
695 |
+
return;
|
696 |
+
}
|
697 |
+
const audioBlob = new Blob(userAudioChunks, { type: mediaRecorder.mimeType });
|
698 |
+
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
699 |
+
websocket.send(audioBlob);
|
700 |
+
statusDiv.textContent = 'Status: Audio sent. Waiting for response...';
|
701 |
+
} else {
|
702 |
+
statusDiv.textContent = 'Status: WebSocket not open. Cannot send audio.';
|
703 |
+
}
|
704 |
+
userAudioChunks = [];
|
705 |
+
};
|
706 |
+
|
707 |
+
mediaRecorder.start(250);
|
708 |
+
startRecordButton.disabled = true;
|
709 |
+
stopRecordButton.disabled = false;
|
710 |
+
statusDiv.textContent = 'Status: Recording...';
|
711 |
+
userTranscriptSpan.textContent = "...";
|
712 |
+
assistantTranscriptSpan.textContent = "...";
|
713 |
+
audioPlayer.style.display = 'none';
|
714 |
+
audioPlayer.src = '';
|
715 |
+
assistantAudioBufferQueue = [];
|
716 |
+
if (sourceNode) { try {sourceNode.stop();} catch(e){} sourceNode = null; }
|
717 |
+
} catch (err) {
|
718 |
+
console.error('Error accessing microphone:', err);
|
719 |
+
statusDiv.textContent = 'Status: Error accessing microphone.';
|
720 |
+
alert('Could not access microphone: ' + err.message);
|
721 |
+
}
|
722 |
+
};
|
723 |
+
|
724 |
+
stopRecordButton.onclick = () => {
|
725 |
+
if (mediaRecorder && mediaRecorder.state === "recording") {
|
726 |
+
mediaRecorder.stop();
|
727 |
+
startRecordButton.disabled = false;
|
728 |
+
stopRecordButton.disabled = true;
|
729 |
+
}
|
730 |
+
};
|
731 |
+
|
732 |
+
connectWebSocket();
|
733 |
+
</script>
|
734 |
+
</body>
|
735 |
+
</html>
|
736 |
+
"""
|
737 |
+
return HTMLResponse(content=html_content)
|
738 |
+
|
739 |
+
if __name__ == "__main__":
|
740 |
+
uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")
|
parler-streaming.py
ADDED
@@ -0,0 +1,402 @@
1 |
+
import io
|
2 |
+
import math
|
3 |
+
from queue import Queue
|
4 |
+
from threading import Thread
|
5 |
+
from typing import Optional
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
# import spaces
|
9 |
+
import gradio as gr
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from parler_tts import ParlerTTSForConditionalGeneration
|
13 |
+
from pydub import AudioSegment
|
14 |
+
from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
|
15 |
+
from transformers.generation.streamers import BaseStreamer
|
16 |
+
|
17 |
+
device = "cuda:0" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
18 |
+
torch_dtype = torch.float16 if device != "cpu" else torch.float32
|
19 |
+
|
20 |
+
repo_id = "ai4bharat/indic-parler-tts"
|
21 |
+
jenny_repo_id = "ylacombe/parler-tts-mini-jenny-30H"
|
22 |
+
|
23 |
+
model = ParlerTTSForConditionalGeneration.from_pretrained(
|
24 |
+
repo_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True
|
25 |
+
).to(device)
|
26 |
+
# jenny_model = ParlerTTSForConditionalGeneration.from_pretrained(
|
27 |
+
# jenny_repo_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True
|
28 |
+
# ).to(device)
|
29 |
+
|
30 |
+
tokenizer = AutoTokenizer.from_pretrained(repo_id)
|
31 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
|
32 |
+
|
33 |
+
SAMPLE_RATE = feature_extractor.sampling_rate
|
34 |
+
SEED = 42
|
35 |
+
|
36 |
+
default_text = "Please surprise me and speak in whatever voice you enjoy."
|
37 |
+
examples = [
|
38 |
+
[
|
39 |
+
"Remember - this is only the first iteration of the model! To improve the prosody and naturalness of the speech further, we're scaling up the amount of training data by a factor of five times.",
|
40 |
+
"A male speaker with a low-pitched voice delivering his words at a fast pace in a small, confined space with a very clear audio and an animated tone.",
|
41 |
+
3.0,
|
42 |
+
],
|
43 |
+
[
|
44 |
+
"'This is the best time of my life, Bartley,' she said happily.",
|
45 |
+
"A female speaker with a slightly low-pitched, quite monotone voice delivers her words at a slightly faster-than-average pace in a confined space with very clear audio.",
|
46 |
+
3.0,
|
47 |
+
],
|
48 |
+
[
|
49 |
+
"Montrose also, after having experienced still more variety of good and bad fortune, threw down his arms, and retired out of the kingdom.",
|
50 |
+
"A male speaker with a slightly high-pitched voice delivering his words at a slightly slow pace in a small, confined space with a touch of background noise and a quite monotone tone.",
|
51 |
+
3.0,
|
52 |
+
],
|
53 |
+
[
|
54 |
+
"Montrose also, after having experienced still more variety of good and bad fortune, threw down his arms, and retired out of the kingdom.",
|
55 |
+
"A male speaker with a low-pitched voice delivers his words at a fast pace and an animated tone, in a very spacious environment, accompanied by noticeable background noise.",
|
56 |
+
3.0,
|
57 |
+
],
|
58 |
+
]
|
59 |
+
|
60 |
+
jenny_examples = [
|
61 |
+
[
|
62 |
+
"Remember, this is only the first iteration of the model! To improve the prosody and naturalness of the speech further, we're scaling up the amount of training data by a factor of five times.",
|
63 |
+
"Jenny speaks at an average pace with a slightly animated delivery in a very confined sounding environment with clear audio quality.",
|
64 |
+
3.0,
|
65 |
+
],
|
66 |
+
[
|
67 |
+
"'This is the best time of my life, Bartley,' she said happily.",
|
68 |
+
"Jenny speaks in quite a monotone voice at a slightly faster-than-average pace in a confined space with very clear audio.",
|
69 |
+
3.0,
|
70 |
+
],
|
71 |
+
[
|
72 |
+
"Montrose also, after having experienced still more variety of good and bad fortune, threw down his arms, and retired out of the kingdom.",
|
73 |
+
"Jenny delivers her words at a slightly slow pace in a small, confined space with a touch of background noise and a quite monotone tone.",
|
74 |
+
3.0,
|
75 |
+
],
|
76 |
+
[
|
77 |
+
"Montrose also, after having experienced still more variety of good and bad fortune, threw down his arms, and retired out of the kingdom.",
|
78 |
+
"Jenny delivers her words at a fast pace and an animated tone, in a very spacious environment, accompanied by noticeable background noise.",
|
79 |
+
3.0,
|
80 |
+
],
|
81 |
+
]
|
82 |
+
|
83 |
+
|
84 |
+
class ParlerTTSStreamer(BaseStreamer):
|
85 |
+
def __init__(
|
86 |
+
self,
|
87 |
+
model: ParlerTTSForConditionalGeneration,
|
88 |
+
device: Optional[str] = None,
|
89 |
+
play_steps: Optional[int] = 10,
|
90 |
+
stride: Optional[int] = None,
|
91 |
+
timeout: Optional[float] = None,
|
92 |
+
):
|
93 |
+
"""
|
94 |
+
Streamer that stores playback-ready audio in a queue, to be used by a downstream application as an iterator. This is
|
95 |
+
useful for applications that benefit from accessing the generated audio in a non-blocking way (e.g. in an interactive
|
96 |
+
Gradio demo).
|
97 |
+
Parameters:
|
98 |
+
model (`ParlerTTSForConditionalGeneration`):
|
99 |
+
The Parler-TTS model used to generate the audio waveform.
|
100 |
+
device (`str`, *optional*):
|
101 |
+
The torch device on which to run the computation. If `None`, will default to the device of the model.
|
102 |
+
play_steps (`int`, *optional*, defaults to 10):
|
103 |
+
The number of generation steps with which to return the generated audio array. Using fewer steps will
|
104 |
+
mean the first chunk is ready faster, but will require more codec decoding steps overall. This value
|
105 |
+
should be tuned to your device and latency requirements.
|
106 |
+
stride (`int`, *optional*):
|
107 |
+
The window (stride) between adjacent audio samples. Using a stride between adjacent audio samples reduces
|
108 |
+
the hard boundary between them, giving smoother playback. If `None`, will default to a value equivalent to
|
109 |
+
play_steps // 6 in the audio space.
|
110 |
+
timeout (`int`, *optional*):
|
111 |
+
The timeout for the audio queue. If `None`, the queue will block indefinitely. Useful to handle exceptions
|
112 |
+
in `.generate()`, when it is called in a separate thread.
|
113 |
+
"""
|
114 |
+
self.decoder = model.decoder
|
115 |
+
self.audio_encoder = model.audio_encoder
|
116 |
+
self.generation_config = model.generation_config
|
117 |
+
self.device = device if device is not None else model.device
|
118 |
+
|
119 |
+
# variables used in the streaming process
|
120 |
+
self.play_steps = play_steps
|
121 |
+
if stride is not None:
|
122 |
+
self.stride = stride
|
123 |
+
else:
|
124 |
+
hop_length = math.floor(self.audio_encoder.config.sampling_rate / self.audio_encoder.config.frame_rate)
|
125 |
+
self.stride = hop_length * (play_steps - self.decoder.num_codebooks) // 6
|
126 |
+
self.token_cache = None
|
127 |
+
self.to_yield = 0
|
128 |
+
|
129 |
+
# variables used in the thread process
|
130 |
+
self.audio_queue = Queue()
|
131 |
+
self.stop_signal = None
|
132 |
+
self.timeout = timeout
|
133 |
+
|
134 |
+
def apply_delay_pattern_mask(self, input_ids):
|
135 |
+
# build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Parler)
|
136 |
+
_, delay_pattern_mask = self.decoder.build_delay_pattern_mask(
|
137 |
+
input_ids[:, :1],
|
138 |
+
bos_token_id=self.generation_config.bos_token_id,
|
139 |
+
pad_token_id=self.generation_config.decoder_start_token_id,
|
140 |
+
max_length=input_ids.shape[-1],
|
141 |
+
)
|
142 |
+
# apply the pattern mask to the input ids
|
143 |
+
input_ids = self.decoder.apply_delay_pattern_mask(input_ids, delay_pattern_mask)
|
144 |
+
|
145 |
+
# revert the pattern delay mask by filtering the pad token id
|
146 |
+
mask = (delay_pattern_mask != self.generation_config.bos_token_id) & (delay_pattern_mask != self.generation_config.pad_token_id)
|
147 |
+
input_ids = input_ids[mask].reshape(1, self.decoder.num_codebooks, -1)
|
148 |
+
# append the frame dimension back to the audio codes
|
149 |
+
input_ids = input_ids[None, ...]
|
150 |
+
|
151 |
+
# send the input_ids to the correct device
|
152 |
+
input_ids = input_ids.to(self.audio_encoder.device)
|
153 |
+
|
154 |
+
decode_sequentially = (
|
155 |
+
self.generation_config.bos_token_id in input_ids
|
156 |
+
or self.generation_config.pad_token_id in input_ids
|
157 |
+
or self.generation_config.eos_token_id in input_ids
|
158 |
+
)
|
159 |
+
if not decode_sequentially:
|
160 |
+
output_values = self.audio_encoder.decode(
|
161 |
+
input_ids,
|
162 |
+
audio_scales=[None],
|
163 |
+
)
|
164 |
+
else:
|
165 |
+
sample = input_ids[:, 0]
|
166 |
+
sample_mask = (sample >= self.audio_encoder.config.codebook_size).sum(dim=(0, 1)) == 0
|
167 |
+
sample = sample[:, :, sample_mask]
|
168 |
+
output_values = self.audio_encoder.decode(sample[None, ...], [None])
|
169 |
+
|
170 |
+
audio_values = output_values.audio_values[0, 0]
|
171 |
+
return audio_values.cpu().float().numpy()
|
172 |
+
|
173 |
+
def put(self, value):
|
174 |
+
batch_size = value.shape[0] // self.decoder.num_codebooks
|
175 |
+
if batch_size > 1:
|
176 |
+
raise ValueError("ParlerTTSStreamer only supports batch size 1")
|
177 |
+
|
178 |
+
if self.token_cache is None:
|
179 |
+
self.token_cache = value
|
180 |
+
else:
|
181 |
+
self.token_cache = torch.concatenate([self.token_cache, value[:, None]], dim=-1)
|
182 |
+
|
183 |
+
if self.token_cache.shape[-1] % self.play_steps == 0:
|
184 |
+
audio_values = self.apply_delay_pattern_mask(self.token_cache)
|
185 |
+
self.on_finalized_audio(audio_values[self.to_yield : -self.stride])
|
186 |
+
self.to_yield += len(audio_values) - self.to_yield - self.stride
|
187 |
+
|
188 |
+
def end(self):
|
189 |
+
"""Flushes any remaining cache and appends the stop symbol."""
|
190 |
+
if self.token_cache is not None:
|
191 |
+
audio_values = self.apply_delay_pattern_mask(self.token_cache)
|
192 |
+
else:
|
193 |
+
audio_values = np.zeros(self.to_yield)
|
194 |
+
|
195 |
+
self.on_finalized_audio(audio_values[self.to_yield :], stream_end=True)
|
196 |
+
|
197 |
+
def on_finalized_audio(self, audio: np.ndarray, stream_end: bool = False):
|
198 |
+
"""Put the new audio in the queue. If the stream is ending, also put a stop signal in the queue."""
|
199 |
+
self.audio_queue.put(audio, timeout=self.timeout)
|
200 |
+
if stream_end:
|
201 |
+
self.audio_queue.put(self.stop_signal, timeout=self.timeout)
|
202 |
+
|
203 |
+
def __iter__(self):
|
204 |
+
return self
|
205 |
+
|
206 |
+
def __next__(self):
|
207 |
+
value = self.audio_queue.get(timeout=self.timeout)
|
208 |
+
if not isinstance(value, np.ndarray) and value == self.stop_signal:
|
209 |
+
raise StopIteration()
|
210 |
+
else:
|
211 |
+
return value
|
212 |
+
|
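# --- Editor's illustrative sketch (hedged; not part of the original file) ---------------
# Shows how the streamer above is meant to be consumed, mirroring generate_base() below
# but collecting the chunks into a WAV file instead of yielding MP3 bytes. The output
# path and the ~2-second chunk interval are arbitrary choices; soundfile is listed in
# requirements.txt. This helper is never called by the demo itself.
def _example_stream_to_wav(text: str, description: str, out_path: str = "out.wav"):
    import soundfile as sf  # local import, only needed for this example

    play_steps = int(model.audio_encoder.config.frame_rate * 2.0)  # ~2 s of audio per chunk
    streamer = ParlerTTSStreamer(model, device=device, play_steps=play_steps)
    generation_kwargs = dict(
        input_ids=tokenizer(description, return_tensors="pt").to(device).input_ids,
        prompt_input_ids=tokenizer(text, return_tensors="pt").to(device).input_ids,
        streamer=streamer,
        do_sample=True,
        min_new_tokens=10,
    )
    Thread(target=model.generate, kwargs=generation_kwargs).start()  # generate in background
    chunks = [chunk for chunk in streamer]  # iterating blocks until the stop signal arrives
    sf.write(out_path, np.concatenate(chunks), model.audio_encoder.config.sampling_rate)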
213 |
+
def numpy_to_mp3(audio_array, sampling_rate):
|
214 |
+
# Normalize audio_array if it's floating-point
|
215 |
+
if np.issubdtype(audio_array.dtype, np.floating):
|
216 |
+
max_val = np.max(np.abs(audio_array))
|
217 |
+
audio_array = (audio_array / max_val) * 32767 # Normalize to 16-bit range
|
218 |
+
audio_array = audio_array.astype(np.int16)
|
219 |
+
|
220 |
+
# Create an audio segment from the numpy array
|
221 |
+
audio_segment = AudioSegment(
|
222 |
+
audio_array.tobytes(),
|
223 |
+
frame_rate=sampling_rate,
|
224 |
+
sample_width=audio_array.dtype.itemsize,
|
225 |
+
channels=1
|
226 |
+
)
|
227 |
+
|
228 |
+
# Export the audio segment to MP3 bytes - use a high bitrate to maximise quality
|
229 |
+
mp3_io = io.BytesIO()
|
230 |
+
audio_segment.export(mp3_io, format="mp3", bitrate="320k")
|
231 |
+
|
232 |
+
# Get the MP3 bytes
|
233 |
+
mp3_bytes = mp3_io.getvalue()
|
234 |
+
mp3_io.close()
|
235 |
+
|
236 |
+
return mp3_bytes
|
237 |
+
|
238 |
+
sampling_rate = model.audio_encoder.config.sampling_rate
|
239 |
+
frame_rate = model.audio_encoder.config.frame_rate
|
240 |
+
|
241 |
+
# @spaces.GPU
|
242 |
+
def generate_base(text, description, play_steps_in_s=2.0):
|
243 |
+
play_steps = int(frame_rate * play_steps_in_s)
|
244 |
+
streamer = ParlerTTSStreamer(model, device=device, play_steps=play_steps)
|
245 |
+
|
246 |
+
inputs = tokenizer(description, return_tensors="pt").to(device)
|
247 |
+
prompt = tokenizer(text, return_tensors="pt").to(device)
|
248 |
+
|
249 |
+
generation_kwargs = dict(
|
250 |
+
input_ids=inputs.input_ids,
|
251 |
+
prompt_input_ids=prompt.input_ids,
|
252 |
+
streamer=streamer,
|
253 |
+
do_sample=True,
|
254 |
+
temperature=1.0,
|
255 |
+
min_new_tokens=10,
|
256 |
+
)
|
257 |
+
|
258 |
+
set_seed(SEED)
|
259 |
+
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
260 |
+
thread.start()
|
261 |
+
|
262 |
+
for new_audio in streamer:
|
263 |
+
print(f"Sample of length: {round(new_audio.shape[0] / sampling_rate, 2)} seconds")
|
264 |
+
yield numpy_to_mp3(new_audio, sampling_rate=sampling_rate)
|
265 |
+
|
266 |
+
# @spaces.GPU
|
267 |
+
def generate_jenny(text, description, play_steps_in_s=2.0):
|
268 |
+
play_steps = int(frame_rate * play_steps_in_s)
|
269 |
+
streamer = ParlerTTSStreamer(model, device=device, play_steps=play_steps)
|
270 |
+
|
271 |
+
inputs = tokenizer(description, return_tensors="pt").to(device)
|
272 |
+
prompt = tokenizer(text, return_tensors="pt").to(device)
|
273 |
+
|
274 |
+
generation_kwargs = dict(
|
275 |
+
input_ids=inputs.input_ids,
|
276 |
+
prompt_input_ids=prompt.input_ids,
|
277 |
+
streamer=streamer,
|
278 |
+
do_sample=True,
|
279 |
+
temperature=1.0,
|
280 |
+
min_new_tokens=10,
|
281 |
+
)
|
282 |
+
|
283 |
+
set_seed(SEED)
|
284 |
+
thread = Thread(target=model.generate, kwargs=generation_kwargs)  # jenny_model is commented out above, so fall back to the loaded base model (matches the streamer above)
|
285 |
+
thread.start()
|
286 |
+
|
287 |
+
for new_audio in streamer:
|
288 |
+
print(f"Sample of length: {round(new_audio.shape[0] / sampling_rate, 2)} seconds")
|
289 |
+
yield sampling_rate, new_audio
|
290 |
+
|
291 |
+
|
292 |
+
css = """
|
293 |
+
#share-btn-container {
|
294 |
+
display: flex;
|
295 |
+
padding-left: 0.5rem !important;
|
296 |
+
padding-right: 0.5rem !important;
|
297 |
+
background-color: #000000;
|
298 |
+
justify-content: center;
|
299 |
+
align-items: center;
|
300 |
+
border-radius: 9999px !important;
|
301 |
+
width: 13rem;
|
302 |
+
margin-top: 10px;
|
303 |
+
margin-left: auto;
|
304 |
+
flex: unset !important;
|
305 |
+
}
|
306 |
+
#share-btn {
|
307 |
+
all: initial;
|
308 |
+
color: #ffffff;
|
309 |
+
font-weight: 600;
|
310 |
+
cursor: pointer;
|
311 |
+
font-family: 'IBM Plex Sans', sans-serif;
|
312 |
+
margin-left: 0.5rem !important;
|
313 |
+
padding-top: 0.25rem !important;
|
314 |
+
padding-bottom: 0.25rem !important;
|
315 |
+
right:0;
|
316 |
+
}
|
317 |
+
#share-btn * {
|
318 |
+
all: unset !important;
|
319 |
+
}
|
320 |
+
#share-btn-container div:nth-child(-n+2){
|
321 |
+
width: auto !important;
|
322 |
+
min-height: 0px !important;
|
323 |
+
}
|
324 |
+
#share-btn-container .wrap {
|
325 |
+
display: none !important;
|
326 |
+
}
|
327 |
+
"""
|
328 |
+
with gr.Blocks(css=css) as block:
|
329 |
+
gr.HTML(
|
330 |
+
"""
|
331 |
+
<div style="text-align: center; max-width: 700px; margin: 0 auto;">
|
332 |
+
<div
|
333 |
+
style="
|
334 |
+
display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;
|
335 |
+
"
|
336 |
+
>
|
337 |
+
<h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
|
338 |
+
Parler-TTS 🗣️
|
339 |
+
</h1>
|
340 |
+
</div>
|
341 |
+
</div>
|
342 |
+
"""
|
343 |
+
)
|
344 |
+
gr.HTML(
|
345 |
+
f"""
|
346 |
+
<p><a href="https://github.com/huggingface/parler-tts"> Parler-TTS</a> is a training and inference library for
|
347 |
+
high-fidelity text-to-speech (TTS) models. Two models are demonstrated here: <a href="https://huggingface.co/parler-tts/parler_tts_mini_v0.1"> Parler-TTS Mini v0.1</a>,
|
348 |
+
the first iteration of the model, trained on 10k hours of narrated audiobooks, and <a href="https://huggingface.co/ylacombe/parler-tts-mini-jenny-30H"> Parler-TTS Jenny</a>,
|
349 |
+
a model fine-tuned on the <a href="https://huggingface.co/datasets/reach-vb/jenny_tts_dataset"> Jenny dataset</a>.
|
350 |
+
Both models generate high-quality speech with features that can be controlled using a simple text prompt (e.g. gender, background noise, speaking rate, pitch and reverberation).</p>
|
351 |
+
<p>Tips for ensuring good generation:
|
352 |
+
<ul>
|
353 |
+
<li>Include the term <b>"very clear audio"</b> to generate the highest quality audio, and "very noisy audio" for high levels of background noise</li>
|
354 |
+
<li>When using the fine-tuned model, include the term <b>"Jenny"</b> to pick out her voice</li>
|
355 |
+
<li>Punctuation can be used to control the prosody of the generations, e.g. use commas to add small breaks in speech</li>
|
356 |
+
<li>The remaining speech features (gender, speaking rate, pitch and reverberation) can be controlled directly through the prompt</li>
|
357 |
+
</ul>
|
358 |
+
</p>
|
359 |
+
"""
|
360 |
+
)
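# Editor's illustrative note (hedged; not part of the original app): the tips above translate
# into description prompts like the one below, which generate_base() consumes directly.
# The exact wording is an assumption modelled on the `examples` list defined earlier.
example_description = (
    "A female speaker delivers her words at a moderate pace "
    "in a confined space with very clear audio."
)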
|
361 |
+
with gr.Tab("Base"):
|
362 |
+
with gr.Row():
|
363 |
+
with gr.Column():
|
364 |
+
input_text = gr.Textbox(label="Input Text", lines=2, value=default_text, elem_id="input_text")
|
365 |
+
description = gr.Textbox(label="Description", lines=2, value="", elem_id="input_description")
|
366 |
+
play_seconds = gr.Slider(3.0, 7.0, value=3.0, step=2, label="Streaming interval in seconds", info="Lower = shorter chunks, lower latency, more codec steps")
|
367 |
+
run_button = gr.Button("Generate Audio", variant="primary")
|
368 |
+
with gr.Column():
|
369 |
+
audio_out = gr.Audio(label="Parler-TTS generation", format="mp3", elem_id="audio_out", streaming=True, autoplay=True)
|
370 |
+
|
371 |
+
inputs = [input_text, description, play_seconds]
|
372 |
+
outputs = [audio_out]
|
373 |
+
gr.Examples(examples=examples, fn=generate_base, inputs=inputs, outputs=outputs, cache_examples=False)
|
374 |
+
run_button.click(fn=generate_base, inputs=inputs, outputs=outputs, queue=True)
|
375 |
+
|
376 |
+
with gr.Tab("Jenny"):
|
377 |
+
with gr.Row():
|
378 |
+
with gr.Column():
|
379 |
+
input_text = gr.Textbox(label="Input Text", lines=2, value=jenny_examples[0][0], elem_id="input_text")
|
380 |
+
description = gr.Textbox(label="Description", lines=2, value=jenny_examples[0][1], elem_id="input_description")
|
381 |
+
play_seconds = gr.Slider(3.0, 7.0, value=jenny_examples[0][2], step=2, label="Streaming interval in seconds", info="Lower = shorter chunks, lower latency, more codec steps")
|
382 |
+
run_button = gr.Button("Generate Audio", variant="primary")
|
383 |
+
with gr.Column():
|
384 |
+
audio_out = gr.Audio(label="Parler-TTS generation", format="mp3", elem_id="audio_out", streaming=True, autoplay=True)
|
385 |
+
|
386 |
+
inputs = [input_text, description, play_seconds]
|
387 |
+
outputs = [audio_out]
|
388 |
+
gr.Examples(examples=jenny_examples, fn=generate_jenny, inputs=inputs, outputs=outputs, cache_examples=False)
|
389 |
+
run_button.click(fn=generate_jenny, inputs=inputs, outputs=outputs, queue=True)
|
390 |
+
|
391 |
+
gr.HTML(
|
392 |
+
"""
|
393 |
+
<p>To improve the prosody and naturalness of the speech further, we're scaling up the amount of training data to 50k hours of speech.
|
394 |
+
The v1 release of the model will be trained on this data and will add inference optimisations, such as flash attention
|
395 |
+
and torch compile, that should improve the latency by 2-4x. If you want to find out more about how this model was trained, or even fine-tune it yourself, check out the
|
396 |
+
<a href="https://github.com/huggingface/parler-tts"> Parler-TTS</a> repository on GitHub. The Parler-TTS codebase and its
|
397 |
+
associated checkpoints are licensed under <a href='https://github.com/huggingface/parler-tts?tab=Apache-2.0-1-ov-file#readme'> Apache 2.0</a>.</p>
|
398 |
+
"""
|
399 |
+
)
|
400 |
+
|
401 |
+
block.queue()
|
402 |
+
block.launch(share=True)
|
requirements.txt
ADDED
@@ -0,0 +1,22 @@
1 |
+
# requirements.txt
|
2 |
+
fastapi
|
3 |
+
uvicorn[standard]
|
4 |
+
websockets
|
5 |
+
openai-whisper
|
6 |
+
torch
|
7 |
+
torchaudio
|
8 |
+
transformers
|
9 |
+
accelerate # Often useful for transformers
|
10 |
+
python-multipart # For file uploads in traditional endpoints
|
11 |
+
soundfile # For handling audio files
|
12 |
+
librosa
|
13 |
+
parler-tts # For AI4Bharat's IndicParler-TTS
|
14 |
+
onnx
|
15 |
+
onnxruntime
|
16 |
+
# For specific hardware acceleration (optional, choose based on your setup)
|
17 |
+
# bitsandbytes # For 8-bit quantization of LLM (further RAM reduction)
|
18 |
+
# sentencepiece # Often a dependency for tokenizers
|
19 |
+
# For demo purposes
|
20 |
+
gradio
|
21 |
+
# For Gemini integration
|
22 |
+
google-generativeai
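# Example setup (editor's illustrative note; commands assume the files added in this commit):
#   pip install -r requirements.txt
#   python main.py              # FastAPI backend with the browser demo on http://localhost:8000
#   python parler-streaming.py  # standalone Gradio demo for streaming Parler-TTS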
|
streaming_nb.ipynb
ADDED
File without changes
|
test_notebook.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|