# demorrha/core/text_to_speech.py
# coding: utf-8
import os
import tempfile
from typing import Dict, Optional, Tuple, Union
from base64 import b64encode

from openai import OpenAI
from pydub import AudioSegment
import streamlit as st
# from dotenv import load_dotenv

# Load environment variables from the .env file
# load_dotenv()

class openai_tts(object):
    """Small helper around the OpenAI text-to-speech endpoint."""

    def __init__(self,
                 tts_voice: Optional[str] = "nova",
                 tts_model: Optional[str] = "tts-1",
                 response_format: Optional[str] = "mp3",
                 speed: Optional[float] = 1.0):
        self.client = None
        self.init_supported_formats__()
        self.init_api_client()
        if response_format:
            self.set_response_format(response_format)
        if tts_voice:
            self.set_tts_voice(tts_voice)
        if tts_model:
            self.set_tts_model(tts_model)
        if speed:
            self.set_tts_speed(speed)

    def set_tts_speed(self, speed):
        if not (0.25 <= speed <= 4.0):
            raise ValueError(f"[TTS] - Speed must be between 0.25 and 4.0. Provided value: {speed}")
        self.speed = speed
        return self

    def set_tts_voice(self, voice):
        voix_valides = ["alloy", "echo", "fable", "onyx", "nova", "shimmer"]
        if voice not in voix_valides:
            raise ValueError(f"[TTS] - Invalid TTS voice: {voice}. Valid voices are: {', '.join(voix_valides)}.")
        self.tts_voice = voice
        return self

    def set_tts_model(self, model):
        if model not in ["tts-1", "tts-1-hd"]:
            raise ValueError(f"[TTS] - Invalid TTS model: {model}. Valid models are 'tts-1' and 'tts-1-hd'.")
        self.tts_model = model
        return self

    def init_supported_formats__(self):
        self.supported_formats = ['mp3', 'opus', 'aac', 'flac', 'wav', 'pcm']
        return self

    def set_response_format(self, format: str):
        if format not in self.supported_formats:
            raise ValueError(f"[TTS] - Unsupported format: {format}. Supported formats are: {', '.join(self.supported_formats)}")
        self.response_format = format
        return self

    def init_api_client(self):
        if not self.client:
            # OpenAI client configuration with API key
            self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        return self

    def text_to_speech(self, input_text: str) -> Dict[str, Union[float, str]]:
        """
        Convert text to speech using the OpenAI API.

        Args:
            input_text (str): The text to convert to speech.

        Returns:
            Dict[str, Union[float, str]]: A dictionary containing:
                - 'audio_duration' (float): The audio duration in seconds.
                - 'data_bytes' (str): The audio data encoded in base64.
        """
        response = self.client.audio.speech.create(
            model=self.tts_model,
            voice=self.tts_voice,
            input=input_text,
            response_format=self.response_format,
            speed=self.speed
        )
        data_output = response.read()
        # Measure the duration by loading the audio bytes through a temporary file.
        with tempfile.TemporaryFile() as tmp_file:
            tmp_file.write(data_output)
            tmp_file.seek(0)
            audio = AudioSegment.from_file(tmp_file, format=self.response_format)
            duration = len(audio) / 1000  # pydub length is in milliseconds
        return {
            "audio_duration": duration,
            "data_bytes": b64encode(data_output).decode()
        }
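
# Usage sketch (illustrative only, not part of the original module): text_to_speech()
# returns a plain dict, so callers read the two keys directly. Assumes OPENAI_API_KEY
# is set in the environment.
#
#   tts = openai_tts(tts_voice="nova", response_format="mp3", speed=1.0)
#   result = tts.text_to_speech("Hello, world!")
#   result["audio_duration"]   # duration in seconds (float)
#   result["data_bytes"]       # base64-encoded MP3 as a str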

def process_tts_message(text_response: str) -> Tuple[Optional[str], Optional[float]]:
    try:
        tts_output_ = openai_tts(
            tts_voice=st.session_state.tts_voice,
            tts_model="tts-1",
            response_format="mp3",
            speed=1.0
        ).text_to_speech(text_response)
        return tts_output_["data_bytes"], tts_output_["audio_duration"]
    except Exception as e:
        st.error(f"An error occurred during text-to-speech conversion: {e}")
        return None, None
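
# Playback sketch (an assumption about how callers consume the base64 payload; not
# part of the original module). The string returned by process_tts_message() can be
# fed to an HTML5 <audio> element through a data URI in Streamlit:
#
#   data_bytes, audio_duration = process_tts_message("Hello there")
#   if data_bytes:
#       st.markdown(
#           f'<audio controls autoplay src="data:audio/mp3;base64,{data_bytes}"></audio>',
#           unsafe_allow_html=True,
#       )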
"""
if __name__ == "__main__":
openai_tts().text_to_speech("Hello, I am an AI assistant. How can I help you?")
"""