Spaces:
Sleeping
Sleeping
Commit
·
941931d
1
Parent(s):
e4558ca
Streaming by default TTS
Browse files- api/audio.py +28 -66
api/audio.py
CHANGED
|
@@ -171,101 +171,63 @@ class STTManager:
|
|
| 171 |
class TTSManager:
|
| 172 |
def __init__(self, config):
|
| 173 |
self.config = config
|
| 174 |
-
self.status = self.test_tts()
|
| 175 |
-
self.streaming = self.
|
| 176 |
-
self.read_last_message = self.rlm_stream if self.streaming else self.rlm
|
| 177 |
|
| 178 |
-
def test_tts(self) -> bool:
|
| 179 |
"""
|
| 180 |
Test if the TTS service is working correctly.
|
| 181 |
-
|
| 182 |
:return: True if the TTS service is working, False otherwise.
|
| 183 |
"""
|
| 184 |
try:
|
| 185 |
-
self.read_text("Handshake")
|
| 186 |
-
return True
|
| 187 |
-
except:
|
| 188 |
-
return False
|
| 189 |
-
|
| 190 |
-
def test_tts_stream(self) -> bool:
|
| 191 |
-
"""
|
| 192 |
-
Test if the TTS streaming service is working correctly.
|
| 193 |
-
|
| 194 |
-
:return: True if the TTS streaming service is working, False otherwise.
|
| 195 |
-
"""
|
| 196 |
-
try:
|
| 197 |
-
for _ in self.read_text_stream("Handshake"):
|
| 198 |
-
pass
|
| 199 |
return True
|
| 200 |
except:
|
| 201 |
return False
|
| 202 |
|
| 203 |
-
def read_text(self, text: str) -> bytes:
|
| 204 |
-
"""
|
| 205 |
-
Convert text to speech and return the audio bytes.
|
| 206 |
-
|
| 207 |
-
:param text: Text to convert to speech.
|
| 208 |
-
:return: Bytes representation of the audio.
|
| 209 |
-
"""
|
| 210 |
-
headers = {"Authorization": "Bearer " + self.config.tts.key}
|
| 211 |
-
try:
|
| 212 |
-
if self.config.tts.type == "OPENAI_API":
|
| 213 |
-
data = {"model": self.config.tts.name, "input": text, "voice": "alloy", "response_format": "opus", "speed": 1.5}
|
| 214 |
-
response = requests.post(self.config.tts.url + "/audio/speech", headers=headers, json=data)
|
| 215 |
-
elif self.config.tts.type == "HF_API":
|
| 216 |
-
response = requests.post(self.config.tts.url, headers=headers, json={"inputs": text})
|
| 217 |
-
if response.status_code != 200:
|
| 218 |
-
error_details = response.json().get("error", "No error message provided")
|
| 219 |
-
raise APIError(f"TTS Error: {self.config.tts.type} error", status_code=response.status_code, details=error_details)
|
| 220 |
-
except APIError:
|
| 221 |
-
raise
|
| 222 |
-
except Exception as e:
|
| 223 |
-
raise APIError(f"TTS Error: Unexpected error: {e}")
|
| 224 |
-
|
| 225 |
-
return response.content
|
| 226 |
-
|
| 227 |
-
def read_text_stream(self, text: str) -> Generator[bytes, None, None]:
|
| 228 |
"""
|
| 229 |
-
Convert text to speech
|
| 230 |
-
|
| 231 |
:param text: Text to convert to speech.
|
|
|
|
| 232 |
:return: Generator yielding chunks of audio bytes.
|
| 233 |
"""
|
| 234 |
-
if
|
| 235 |
-
|
|
|
|
| 236 |
headers = {"Authorization": "Bearer " + self.config.tts.key}
|
| 237 |
data = {"model": self.config.tts.name, "input": text, "voice": "alloy", "response_format": "opus"}
|
| 238 |
|
| 239 |
try:
|
| 240 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
if response.status_code != 200:
|
| 242 |
error_details = response.json().get("error", "No error message provided")
|
| 243 |
-
raise APIError("TTS Error:
|
| 244 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
yield from response.iter_content(chunk_size=1024)
|
| 246 |
-
except StopIteration:
|
| 247 |
-
pass
|
| 248 |
except APIError:
|
| 249 |
raise
|
| 250 |
except Exception as e:
|
| 251 |
raise APIError(f"TTS Error: Unexpected error: {e}")
|
| 252 |
|
| 253 |
-
def
|
| 254 |
"""
|
| 255 |
Read the last message in the chat history and convert it to speech.
|
| 256 |
-
|
| 257 |
-
:param chat_history: List of chat messages.
|
| 258 |
-
:return: Bytes representation of the audio.
|
| 259 |
-
"""
|
| 260 |
-
if len(chat_history) > 0 and chat_history[-1][1]:
|
| 261 |
-
return self.read_text(chat_history[-1][1])
|
| 262 |
-
|
| 263 |
-
def rlm_stream(self, chat_history: List[List[Optional[str]]]) -> Generator[bytes, None, None]:
|
| 264 |
-
"""
|
| 265 |
-
Read the last message in the chat history and convert it to speech using streaming.
|
| 266 |
-
|
| 267 |
:param chat_history: List of chat messages.
|
| 268 |
:return: Generator yielding chunks of audio bytes.
|
| 269 |
"""
|
| 270 |
if len(chat_history) > 0 and chat_history[-1][1]:
|
| 271 |
-
yield from self.
|
|
|
|
| 171 |
class TTSManager:
|
| 172 |
def __init__(self, config):
|
| 173 |
self.config = config
|
| 174 |
+
self.status = self.test_tts(stream=False)
|
| 175 |
+
self.streaming = self.test_tts(stream=True) if self.status else False
|
|
|
|
| 176 |
|
| 177 |
+
def test_tts(self, stream) -> bool:
|
| 178 |
"""
|
| 179 |
Test if the TTS service is working correctly.
|
|
|
|
| 180 |
:return: True if the TTS service is working, False otherwise.
|
| 181 |
"""
|
| 182 |
try:
|
| 183 |
+
list(self.read_text("Handshake", stream=stream))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
return True
|
| 185 |
except:
|
| 186 |
return False
|
| 187 |
|
| 188 |
+
def read_text(self, text: str, stream: Optional[bool] = None) -> Generator[bytes, None, None]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
"""
|
| 190 |
+
Convert text to speech and return the audio bytes, optionally streaming the response.
|
|
|
|
| 191 |
:param text: Text to convert to speech.
|
| 192 |
+
:param stream: Whether to use streaming or not.
|
| 193 |
:return: Generator yielding chunks of audio bytes.
|
| 194 |
"""
|
| 195 |
+
if stream is None:
|
| 196 |
+
stream = self.streaming
|
| 197 |
+
|
| 198 |
headers = {"Authorization": "Bearer " + self.config.tts.key}
|
| 199 |
data = {"model": self.config.tts.name, "input": text, "voice": "alloy", "response_format": "opus"}
|
| 200 |
|
| 201 |
try:
|
| 202 |
+
if not stream:
|
| 203 |
+
if self.config.tts.type == "OPENAI_API":
|
| 204 |
+
response = requests.post(self.config.tts.url + "/audio/speech", headers=headers, json=data)
|
| 205 |
+
elif self.config.tts.type == "HF_API":
|
| 206 |
+
response = requests.post(self.config.tts.url, headers=headers, json={"inputs": text})
|
| 207 |
+
|
| 208 |
if response.status_code != 200:
|
| 209 |
error_details = response.json().get("error", "No error message provided")
|
| 210 |
+
raise APIError(f"TTS Error: {self.config.tts.type} error", status_code=response.status_code, details=error_details)
|
| 211 |
+
yield response.content
|
| 212 |
+
else:
|
| 213 |
+
if self.config.tts.type != "OPENAI_API":
|
| 214 |
+
raise APIError("TTS Error: Streaming not supported for this TTS type")
|
| 215 |
+
|
| 216 |
+
with requests.post(self.config.tts.url + "/audio/speech", headers=headers, json=data, stream=True) as response:
|
| 217 |
+
if response.status_code != 200:
|
| 218 |
+
error_details = response.json().get("error", "No error message provided")
|
| 219 |
+
raise APIError("TTS Error: OPENAI API error", status_code=response.status_code, details=error_details)
|
| 220 |
yield from response.iter_content(chunk_size=1024)
|
|
|
|
|
|
|
| 221 |
except APIError:
|
| 222 |
raise
|
| 223 |
except Exception as e:
|
| 224 |
raise APIError(f"TTS Error: Unexpected error: {e}")
|
| 225 |
|
| 226 |
+
def read_last_message(self, chat_history: List[List[Optional[str]]]) -> Generator[bytes, None, None]:
|
| 227 |
"""
|
| 228 |
Read the last message in the chat history and convert it to speech.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
:param chat_history: List of chat messages.
|
| 230 |
:return: Generator yielding chunks of audio bytes.
|
| 231 |
"""
|
| 232 |
if len(chat_history) > 0 and chat_history[-1][1]:
|
| 233 |
+
yield from self.read_text(chat_history[-1][1])
|